In [1]:
import os
# XLA/JAX GPU-memory settings (see the JAX "GPU memory allocation" docs). They are
# assigned here, ahead of `import jax` further down this cell.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Machine switch: a truthy value selects the server checkout, a falsy one the local checkout.
SERVER = 1

if not SERVER:
    # Local checkout of the notebook directory.
    %cd /home/xabush/code/snet/moses-incons-pen-xp/notebooks/variable_selection/cancer/nn

else:
    # Server checkout of the notebook directory.
    %cd /home/abdu/bio_ai/moses-incons-pen-xp/notebooks/variable_selection/cancer/nn

import pandas as pd
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
tfd = tfp.distributions
import jax
import haiku as hk
import numpy as np
import optax
# Project helper modules; presumably resolved relative to the directory chosen by `%cd` above.
# NOTE(review): later cells use names such as `optuna`, `torch`, `preprocess_data`,
# `NumpyLoader`, `NumpyData`, `objective_resnet_bg`, `train_resnet_bg_model` and
# `eval_resnet_bg_model` that are never imported by name in the visible cells;
# presumably they arrive through these star imports -- confirm.
from nn_util import *
from optim_util import *
from bnn_models import *
from train_utils import *
from data_utils import *
from hpo_util import *
plt.style.use('ggplot')
# Load the autoreload extension; later cells run `%autoreload` before calling the helpers.
%load_ext autoreload

/home/abdu/bio_ai/moses-incons-pen-xp/notebooks/variable_selection/cancer/nn


In [2]:
# Root of the data tree on whichever machine `SERVER` selects.
data_dir = (
    "/home/abdu/bio_ai/moses-incons-pen-xp/data"
    if SERVER
    else "/home/xabush/code/snet/moses-incons-pen-xp/data"
)

### GDSC Cell Line

#### Tamoxifen

In [3]:
# GDSC2 cell-line folder under the data root.
gdsc_dir = f"{data_dir}/cell_line/gdsc2"
# Tamoxifen response / gene-expression table, one CSV for the whole cohort.
gdsc_exp_tamox_data = pd.read_csv(
    f"{gdsc_dir}/tamoxifen_response_gene_expr.csv"
)
gdsc_exp_tamox_data.shape

(406, 37265)

In [4]:
# Every column except the last is a feature; the last column is the response.
X = gdsc_exp_tamox_data.iloc[:, :-1]
target = gdsc_exp_tamox_data.iloc[:, -1]
# The stored response is the natural log of the raw IC_50: undo the log with exp,
# then express it as -log10(IC_50) so it is comparable.
target = -np.log10(np.exp(target))

In [5]:
# Cancer driver-gene table; its "symbol" column lists the genes to keep.
cancer_driver_genes_df = pd.read_csv(f"{data_dir}/cell_line/driver_genes_20221018.csv")
cols = X.columns.to_list()
driver_syms = cancer_driver_genes_df["symbol"].to_list()
# The expression table has tens of thousands of columns, so test membership against
# a set instead of scanning the driver list once per column.
driver_sym_set = set(driver_syms)
# Keep the expression columns whose name is a driver-gene symbol, in column order.
# NOTE(review): membership is tested on the raw column name while the stripped name
# is what gets stored; the two only agree when column names carry no surrounding
# whitespace -- confirm whether stripping should happen before the lookup.
sym_list = [sym.strip() for sym in cols if sym in driver_sym_set]

In [6]:
# Restrict the expression matrix to the driver-gene columns (all rows kept).
X_selected = X.loc[:, sym_list]
X_selected.shape

(406, 768)

#### Data Preprocessing

In [72]:
from sklearn.preprocessing import (
    QuantileTransformer, PowerTransformer, RobustScaler,
    MinMaxScaler, Normalizer, StandardScaler,
)

# One seed is shared by the data split, the hyper-parameter objective and the refits below.
seed = 745
# Feature scaler handed to `preprocess_data`.
# Alternative that was tried: QuantileTransformer(random_state=seed, output_distribution="normal")
transformer = MinMaxScaler()
# `preprocess_data` returns the outer-train / train / validation / test pieces for
# X and y, followed by the (train, validation) index pair.
(
    X_train_outer, X_train, X_val, X_test,
    y_train_outer, y_train, y_val, y_test,
    (train_indices, val_indices),
) = preprocess_data(seed, X_selected, target, transformer, val_size=0.2, test_size=0.2)

In [73]:
from scipy.sparse import csgraph

# Gene network matrix stored next to the cell-line data.
J = np.load(f"{data_dir}/cell_line/cancer_genes_net.npy")
# All-zero matrix with J's shape and dtype; passed to the model runs that set bg=False.
J_zeros = np.zeros_like(J)
# Normalised graph Laplacian of J; passed to the model runs that set bg=True.
L = csgraph.laplacian(J, normed=True)

##### NN Model

In [74]:
%autoreload
optuna.logging.set_verbosity(optuna.logging.INFO)
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(sampler=sampler)
init_fn = hk.initializers.VarianceScaling(2.0, "fan_in",  "truncated_normal")
study.optimize(lambda trial: objective_resnet_bg(trial, seed, X_train, X_val, y_train, y_val,
                                              [1, 1, 1, 1], [128, 128, 128, 128], init_fn, "swish", J_zeros, bg=False), timeout=300)

[32m[I 2023-02-16 08:19:37,433][0m A new study created in memory with name: no-name-7df1ff8e-084c-4c14-be88-9b5dfba8581f[0m
[32m[I 2023-02-16 08:20:04,892][0m Trial 0 finished with value: 0.5058743953704834 and parameters: {'lr_0': 0.001, 'disc_lr_0': 0.1, 'weight_decay': 0.0001053258080235261, 'block_type': 'PreActResNet', 'mu': 20.563110144504563}. Best is trial 0 with value: 0.5058743953704834.[0m
[32m[I 2023-02-16 08:20:32,052][0m Trial 1 finished with value: 0.5054532885551453 and parameters: {'lr_0': 0.001, 'disc_lr_0': 0.1, 'weight_decay': 0.0005957036566130343, 'block_type': 'PreActResNet', 'mu': 69.07861485369}. Best is trial 1 with value: 0.5054532885551453.[0m
[32m[I 2023-02-16 08:20:54,378][0m Trial 2 finished with value: 0.5296558141708374 and parameters: {'lr_0': 0.01, 'disc_lr_0': 0.5, 'weight_decay': 6.148191361822473e-08, 'block_type': 'ResNet', 'mu': 58.368903161134796}. Best is trial 1 with value: 0.5054532885551453.[0m
[32m[I 2023-02-16 08:21:20,961][0

In [75]:
# Best hyper-parameters found by the baseline study; the refit cell below reads them.
resnet_config = study.best_params
print(resnet_config)

{'lr_0': 0.1, 'disc_lr_0': 0.01, 'weight_decay': 0.035718927454022706, 'block_type': 'PreActResNet', 'mu': 96.36166945673435}


In [76]:
from optuna.visualization import plot_param_importances
# Optuna's hyper-parameter importance plot for the baseline study.
plot_param_importances(study)

In [78]:
%autoreload
rng_key = jax.random.PRNGKey(seed)
epochs = 200
num_cycles = 10
lr_0 = resnet_config["lr_0"]
disc_lr_0 = resnet_config["disc_lr_0"]
# lr_0 = 0.005
hidden_sizes = [128, 128, 128, 128]
num_blocks = [1, 1, 1, 1]
weight_decay = resnet_config["weight_decay"]
block_type = resnet_config["block_type"]
# dropout_rate = resnet_config["dropout_rate"]
# dropout_rate = 0.5
dropout_rate = 0.0
eta, mu = 1.0, resnet_config["mu"]


init_fn = hk.initializers.VarianceScaling(2.0, "fan_in",  "truncated_normal")
torch.manual_seed(seed)
data_loader = NumpyLoader(NumpyData(X_train_outer, y_train_outer), batch_size=32,
                          shuffle=True)
bnn_model, state, val_losses = train_resnet_bg_model(rng_key, data_loader, epochs, num_cycles, 1, lr_0, disc_lr_0,
                                                     block_type, num_blocks, hidden_sizes,
                                                     init_fn, weight_decay, "swish", dropout_rate,
                                                     eta, mu, J_zeros)

print(len(state))
rmse_train, r2_train = eval_resnet_bg_model(rng_key, bnn_model, X_train_outer, y_train_outer, state, False, True)
rmse_test, r2_test = eval_resnet_bg_model(rng_key, bnn_model, X_test, y_test, state, False, True)
print(f"Trian RMSE: {rmse_train}, r2_score: {r2_train}")
print(f"Test RMSE: {rmse_test}, r2_score: {r2_test}")

100%|█████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:36<00:00,  5.41it/s]


10
Trian RMSE: 0.4657616913318634, r2_score: 0.2594581730163559
Test RMSE: 0.5990235209465027, r2_score: 0.011090956030352261


In [None]:
%autoreload
optuna.logging.set_verbosity(optuna.logging.INFO)
sampler = optuna.samplers.TPESampler()
study_bg= optuna.create_study(sampler=sampler)
init_fn = hk.initializers.VarianceScaling(2.0, "fan_in",  "truncated_normal")
study_bg.optimize(lambda trial: objective_resnet_bg(trial, seed, X_train, X_val, y_train, y_val,
                                              init_fn, "swish", L, bg=True), timeout=600)

[32m[I 2023-02-16 08:50:25,589][0m A new study created in memory with name: no-name-1c005cdf-fa29-4626-aa4a-7784afb1df9c[0m
[32m[I 2023-02-16 08:50:51,873][0m Trial 0 finished with value: 0.5483154654502869 and parameters: {'lr_0': 0.01, 'disc_lr_0': 0.01, 'weight_decay': 0.21648857848192798, 'block_type': 'PreActResNet', 'num_blocks': 4, 'block_size': 128, 'eta': 10.537290417020714, 'mu': 82.47332490357681}. Best is trial 0 with value: 0.5483154654502869.[0m
[32m[I 2023-02-16 08:51:12,696][0m Trial 1 finished with value: 0.5278322696685791 and parameters: {'lr_0': 0.01, 'disc_lr_0': 0.01, 'weight_decay': 1.7834261300484228e-06, 'block_type': 'ResNet', 'num_blocks': 4, 'block_size': 128, 'eta': -79.89897934906696, 'mu': 26.339595686732242}. Best is trial 1 with value: 0.5278322696685791.[0m
[32m[I 2023-02-16 08:51:42,630][0m Trial 2 finished with value: 0.517764151096344 and parameters: {'lr_0': 0.01, 'disc_lr_0': 0.5, 'weight_decay': 0.00024633140167672636, 'block_type': 'P

In [95]:
# Best hyper-parameters found by the graph-prior study; the refit cell below reads them.
resnet_bg_config = study_bg.best_params
resnet_bg_config

{'lr_0': 0.1,
 'disc_lr_0': 0.1,
 'weight_decay': 0.07186679207779957,
 'block_type': 'PreActResNet',
 'num_blocks': 3,
 'block_size': 64,
 'eta': 62.91174876310573,
 'mu': 79.7892485098757}

In [96]:
from optuna.visualization import plot_param_importances
# Optuna's hyper-parameter importance plot for the graph-prior study.
plot_param_importances(study_bg)

In [102]:
%autoreload
rng_key = jax.random.PRNGKey(seed)
epochs = 200
num_cycles = 20
lr_0 = 0.01
disc_lr_0 = 0.5
# lr_0 = 0.005
blocks = [1 for _ in range(resnet_bg_config["num_blocks"])]
hidden_sizes = [resnet_bg_config["block_size"] for _ in range(resnet_bg_config["num_blocks"])]
weight_decay = resnet_bg_config["weight_decay"]
block_type = resnet_bg_config["block_type"]
# dropout_rate = resnet_bg_config["dropout_rate"]
# dropout_rate = 0.5
dropout_rate = 0.0
eta, mu = -resnet_bg_config["eta"], resnet_bg_config["mu"]


init_fn = hk.initializers.VarianceScaling(2.0, "fan_in",  "truncated_normal")
torch.manual_seed(seed)
data_loader = NumpyLoader(NumpyData(X_train_outer, y_train_outer), batch_size=32,
                          shuffle=True)
bnn_model, state, val_losses = train_resnet_bg_model(rng_key, data_loader, epochs, num_cycles, 1, lr_0, disc_lr_0,
                                              block_type, num_blocks, hidden_sizes,
                                              init_fn, weight_decay, "swish", dropout_rate,
                                              eta, mu, L)

print(len(state))
rmse_train, r2_train = eval_resnet_bg_model(rng_key, bnn_model, X_train_outer, y_train_outer, state, False, True)
rmse_test, r2_test = eval_resnet_bg_model(rng_key, bnn_model, X_test, y_test, state, False, True)
print(f"Trian RMSE: {rmse_train}, r2_score: {r2_train}")
print(f"Test RMSE: {rmse_test}, r2_score: {r2_test}")

100%|█████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:28<00:00,  7.05it/s]


20
Trian RMSE: 0.3809589743614197, r2_score: 0.5045742883769408
Test RMSE: 0.5647943019866943, r2_score: 0.12087810678531419


In [16]:

def find_feats_on_graph(feat_idx, J):
    """Extract the sub-matrix of ``J`` induced by the given feature indices.

    Args:
        feat_idx: 1-D sequence (list or array) of integer indices into ``J``.
        J: 2-D matrix indexable as ``J[i, j]``, e.g. a NumPy adjacency matrix.

    Returns:
        A ``(len(feat_idx), len(feat_idx))`` float64 array ``G`` where
        ``G[i, j] == J[feat_idx[i], feat_idx[j]]``. Entries whose two positions
        name the same feature are 0; that covers the diagonal and any repeated
        index in ``feat_idx``.
    """
    idx = np.asarray(feat_idx)
    if idx.size == 0:
        # Nothing selected: same empty float result as the loop-based form.
        return np.zeros((0, 0))

    # Pull the whole block out in one indexing operation and copy it as float64,
    # the dtype of the zero-initialised result this function has always returned.
    G = np.array(J[np.ix_(idx, idx)], dtype=np.float64)
    # Compare the index values, not the positions, so duplicated features are
    # zeroed off the diagonal as well.
    G[idx[:, None] == idx[None, :]] = 0.0
    return G

In [17]:
# Count the non-zero entries of the J sub-matrix spanned by the 30 features with the
# largest `disc_mean` (ascending argsort, reversed, first 30). If J is symmetric,
# each linked pair contributes two entries to this count.
# NOTE(review): `disc_mean` is not defined in any cell visible here, so this line
# relies on earlier kernel state -- confirm where it comes from.
np.count_nonzero(find_feats_on_graph(np.argsort(disc_mean)[::-1][:30], J))

10