In [1]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import TensorDataset, DataLoader
import yaml
import pprint
import os
import wandb
import sys
import random
from scipy.stats import mode, pearsonr
import pickle
import itertools
# sys.path.append("./src/")
import sys

from src.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion
from src.modules import MLPDiffusion
from src.vae_model import vae
from src.loss_functions import get_kld_loss, coral
from model_definition import *

In [3]:
# `fold` selects the config file, the patient split and the checkpoint file name;
# `drug2consider` is the single drug paired with every augmented cell line
# when the experiment id is neither 1A nor 1B.
fold, drug2consider = 0, "TEMOZOLOMIDE"

In [4]:
# global variables

CONFIG_PATH = f"experiment_settings_yaml/model_config_2A_annotated_mutations_v7_fold{fold}.yaml" # model config path
pretty_print = pprint.PrettyPrinter()
print(f"Loading config from {CONFIG_PATH}")
# Read the YAML inside a context manager so the file handle is closed deterministically.
with open(CONFIG_PATH) as config_file:
    config = yaml.safe_load(config_file)
model_config = config["model_hyperparams"]
folder_config = config["folder_config"]
wandb_config = config["wandb_config"]
wandb_config["project_name"] = wandb_config["project_name"] + f"-{model_config['experiment_id']}-{model_config['experiment_settings']}-fold{model_config['sample_id']}" # updates wandb project name for ease of monitoring and logging.
# Falls back to CPU when CUDA is unavailable.
device = torch.device(f"cuda:{model_config['device']}" if torch.cuda.is_available() else "cpu")
# First column of gene2ind.txt, in file order.
genes_324 = list(pd.read_csv(f"{folder_config['data_folder']}/raw/metadata/gene2ind.txt", header=None)[0])
# Drug fingerprints indexed by the first CSV column (used later as the drug-name join key).
drug_fp = pd.read_csv(f"{folder_config['data_folder']}/raw/metadata/drug_morgan_fingerprints.csv", index_col=0)
suffixes = ["_piu_max", "_piu_sum", "_piu_mean", "_piu_count",
            "_lu_max", "_lu_sum", "_lu_mean", "_lu_count",
            "_ncu_max", "_ncu_sum", "_ncu_mean", '_ncu_count',
            "_pathogenic_max", "_pathogenic_sum", "_pathogenic_mean", "_pathogenic_count",
            "_vus_max", "_vus_sum", "_vus_mean", "_vus_count",
            "_benign_max", "_benign_sum", "_benign_mean", "_benign_count"
           ]
# Every gene name combined with every suffix, suffix-major (all genes for the first
# suffix, then all genes for the second, ...). Built from the already-loaded
# `genes_324` instead of re-reading gene2ind.txt once per suffix.
genes_7776 = [f"{gene}{suffix}" for suffix in suffixes for gene in genes_324]

# # setting up wandb
# os.environ["WANDB_CACHE_DIR"] = wandb_config["wandb_cache_dir"]
# os.environ["WANDB_DIR"] = wandb_config["wandb_cache_dir"]
# wandb.login(key=wandb_config["api_key"])

# seeding: apply the configured seed to torch, Python's `random` and NumPy
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(model_config["seed"])
# reproducibility in data loading - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed NumPy and Python's `random` from torch's initial seed.

    The seed is `torch.initial_seed()` reduced modulo 2**32. `worker_id` is
    accepted so the function can be used as a `worker_init_fn`, but the body
    does not read it.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

# Dedicated torch generator seeded from the config; the DataLoaders below receive it via `generator=`.
g = torch.Generator()
# Last expression of the cell, so the notebook also displays the returned generator.
g.manual_seed(model_config["seed"])



Loading config from experiment_settings_yaml/model_config_2A_annotated_mutations_v7_fold0.yaml


<torch._C.Generator at 0x7f8f3f053ad0>

In [5]:
# pass samples through the VAE and DDPM network, till just before VAE decoder
def vae_decoder_input(df, vae, diff_model, timestep=700):
    """
    Run a forward pass through a pretrained VAE and diffusion model, stopping before the VAE decoder.

    Steps (all under `torch.no_grad()`):
      1. convert `df.values` to a float32 tensor on `device`;
      2. call `vae(...)` and keep the first of its four outputs as the latent;
      3. noise the latent to step `timestep` with `diff_model.gaussian_q_sample`;
      4. predict the noise with `diff_model._denoise_fn`;
      5. recover the predicted clean latent with `diff_model._predict_xstart_from_eps`.

    Args:
        df: DataFrame whose `.values` are the model inputs, one row per sample.
        vae: callable model returning a 4-tuple whose first element is the latent.
            (The parameter name shadows the imported `vae` class inside this function only.)
        diff_model: diffusion wrapper exposing `gaussian_q_sample`, `_denoise_fn`
            and `_predict_xstart_from_eps`.
        timestep: diffusion step applied to every sample. Defaults to 700, the value
            that was previously hard-coded.

    Returns:
        NumPy array with the predicted clean latent, one row per input row.

    NOTE(review): neither model is switched to eval mode here. If they contain
    dropout and are in training mode, the output is stochastic beyond the
    injected noise - confirm this is intended.
    """
    with torch.no_grad():
        features = torch.tensor(df.values).to(device, dtype=torch.float32)
        latent, _, _, _ = vae(features)  # first output of the VAE forward pass

        noise = torch.randn_like(latent)
        # The same integer step for every sample in the batch.
        t = torch.full((latent.shape[0],), timestep, dtype=torch.long, device=device)
        noised_latent = diff_model.gaussian_q_sample(latent, t, noise)

        predicted_eps = diff_model._denoise_fn(noised_latent, t)

        # Invert the forward process using the predicted noise.
        latent_pred = diff_model._predict_xstart_from_eps(noised_latent, t, predicted_eps)

    return latent_pred.detach().cpu().numpy()

def _build_vae_and_diffusion(k_list_key, actf_list_key, is_real):
    """Instantiate one (VAE, diffusion) pair on `device` from the `model_config` entries named by the two keys."""
    k_list = model_config[k_list_key]
    latent_dim = k_list[-1] // 2
    vae_model = vae(input_dim=model_config["feature_num"], k_list=k_list, actf_list=model_config[actf_list_key], is_real=is_real).to(device)
    denoiser = MLPDiffusion(d_in=latent_dim, num_classes=0, is_y_cond=False, rtdl_params={"d_layers": [k_list[-1] // 4], "dropout": model_config["dropout"]}).to(device)
    diffusion = GaussianMultinomialDiffusion(num_classes=np.array([0]), num_numerical_features=latent_dim, denoise_fn=denoiser, device=device)
    return vae_model, diffusion


def load_pretrained_models():
    """
    Build the patient and cell-line VAE + diffusion models and load their pretrained weights.

    The checkpoint path is derived from `folder_config` / `model_config`. The file must
    contain the state-dict keys `patient_diff_model`, `patient_vae_conditioned`,
    `cl_diff_model` and `cl_vae_conditioned`.

    Returns:
        (cl_diff_model, cl_vae, tcga_diff_model, patient_vae)
    """
    checkpoint_path = (
        f"{folder_config['model_checkpoint_folder']}/best_pretrained_validation_loss_"
        f"{model_config['model_save_criteria']}_{model_config['experiment_id']}_"
        f"{model_config['experiment_settings']}_fold{model_config['sample_id']}.pth"
    )
    # map_location remaps tensors saved on another device (e.g. a different GPU index,
    # or a GPU when this session fell back to CPU) onto the active `device`.
    pretrained = torch.load(checkpoint_path, map_location=device)
    is_real = model_config["input_data_type"] == "binary_mutations"

    # patients (built first, as before, so module construction order is unchanged)
    patient_vae, tcga_diff_model = _build_vae_and_diffusion("patient_vae_k_list", "patient_vae_actf_list", is_real)
    tcga_diff_model.load_state_dict(pretrained["patient_diff_model"])
    patient_vae.load_state_dict(pretrained["patient_vae_conditioned"])

    # cell lines
    cl_vae, cl_diff_model = _build_vae_and_diffusion("cl_vae_k_list", "cl_vae_actf_list", is_real)
    cl_diff_model.load_state_dict(pretrained["cl_diff_model"])
    cl_vae.load_state_dict(pretrained["cl_vae_conditioned"])

    return cl_diff_model, cl_vae, tcga_diff_model, patient_vae

def _merge_with_fingerprints(split_df, split_label):
    """Join one split with `drug_fp` on the drug name and assert the row count is unchanged."""
    merged = split_df.merge(drug_fp, left_on="drug_name", right_on=drug_fp.index)
    assert merged.shape[0] == split_df.shape[0], f"{split_label} data loss after merge!"
    return merged


def _latent_frame(merged_df, feature_cols, vae_model, diff_model, latent_dim, index):
    """Run one split through `vae_decoder_input` and wrap the result as a `vae_feat{i}` DataFrame."""
    latent = vae_decoder_input(merged_df[feature_cols], vae_model, diff_model)
    return pd.DataFrame(latent, columns=[f"vae_feat{i}" for i in range(latent_dim)], index=index)


def load_datasets(sample_id):
    """
    Load the source (cell line) and target (patient) train/val/test splits for one fold.

    The data directory is chosen from `model_config["input_data_type"]` and
    `model_config["experiment_id"]`. Cell lines always use fold 0; patients use
    `sample_id`. Each split is merged with the drug fingerprints in `drug_fp`, and its
    genomic features are passed through the matching pretrained VAE + diffusion model
    (see `vae_decoder_input`).

    Args:
        sample_id: fold number of the patient split to load.

    Returns:
        A 24-tuple, in this order:
          * for source train, val, test: (latent DataFrame, fingerprint array, "auc" array);
          * for target train, val, test: (latent DataFrame, fingerprint array, "recist" array);
          * the merged target DataFrames (train, val, test);
          * the merged source DataFrames (train, val, test).
        Source latent frames keep the merged frame's index; target latent frames are
        indexed by `sample_id`.
        Returns None (after printing a message) when the input type or experiment id
        is not supported.

    Raises:
        AssertionError: if merging with the fingerprints changes a split's row count.
    """
    # input type -> (sub-directory, genomic feature columns fed to the VAE).
    # The fingerprint columns are selected separately further down.
    input_type_layout = {
        "binary_mutations": ("raw_mutations", genes_324),
        "annotated_mutations": ("annotated_mutations", genes_7776),
        # processed by PREDICT-AI transformer embedder
        "transformer_inputs": ("transformer_inputs_transformed_797", [f"transformer_embedded_{i}" for i in range(797)]),
    }
    experiment_layout = {
        "1A": ("Experiment1", "SettingA"),
        "1B": ("Experiment1", "SettingB"),
        "2A": ("Experiment2", "SettingA"),
        "2B": ("Experiment2", "SettingB"),
    }

    if model_config["input_data_type"] not in input_type_layout:
        print("Unsupported input type!")
        return
    input_subdir, features2select = input_type_layout[model_config["input_data_type"]]

    experiment_id = model_config["experiment_id"]
    if experiment_id not in experiment_layout:
        print("Unsupported experiment ID!")
        return

    # os.path.join works whether or not `data_folder` ends with a path separator.
    data_dir = os.path.join(folder_config["data_folder"], "input_types", input_subdir, *experiment_layout[experiment_id])

    # load the fold based on sample_id - Note: cell lines have only 1 fold (fold 0)
    # NOTE: pickle.load can execute arbitrary code; only use with trusted, locally produced files.
    with open(os.path.join(data_dir, "cell_lines_fold0_processed.pkl"), "rb") as f:
        source_data = pickle.load(f)
    with open(os.path.join(data_dir, f"patients_fold{sample_id}_processed.pkl"), "rb") as f:
        target_data = pickle.load(f)

    # load pretrained VAE and diffusion models for both domains
    cl_diff_model, cl_vae, tcga_diff_model, patient_vae = load_pretrained_models()

    # select data based on experiment settings
    # Can be CISPLATIN, PACLITAXEL, FLUOROURACIL, SORAFENIB for 1A, CISPLATIN, TCGA-CESC; CISPLATIN, TCGA-HNSC; PACLITAXEL, TCGA-BRCA; FLUOROURACIL, TCGA-STAD for 1B
    # ALL for 2A, TCGA-BRCA, TCGA-CESC, TCGA-HNSC, TCGA-STAD for 2B
    if experiment_id == "2A":
        setting = None  # splits are used as-is
    elif experiment_id == "1B":
        setting_parts = model_config["experiment_settings"].split(", ")
        setting = (setting_parts[0], setting_parts[1], "TCGA")
    else:  # "1A" and "2B": splits are keyed directly by the settings string
        setting = model_config["experiment_settings"]

    def pick(data, split):
        """Return one split, narrowed to `setting` unless the experiment is 2A."""
        return data[split] if experiment_id == "2A" else data[split][setting]

    split_names = ("train", "val", "test")

    # merge dataframes with drug Morgan fingerprint dataframes
    source_merged = [_merge_with_fingerprints(pick(source_data, name), f"{name.capitalize()} source") for name in split_names]
    target_merged = [_merge_with_fingerprints(pick(target_data, name), f"{name.capitalize()} target") for name in split_names]

    fingerprint_cols = [str(i) for i in range(0, 2048)]
    cl_latent_dim = model_config["cl_vae_k_list"][-1] // 2
    patient_latent_dim = model_config["patient_vae_k_list"][-1] // 2

    outputs = []
    # Splits are embedded in the same order as before (source train/val/test, then
    # target train/val/test) so the random draws inside `vae_decoder_input` line up.
    for merged in source_merged:
        outputs.append(_latent_frame(merged, features2select, cl_vae, cl_diff_model, cl_latent_dim, merged.index))
        outputs.append(merged[fingerprint_cols].values)
        outputs.append(merged["auc"].values)
    for merged in target_merged:
        outputs.append(_latent_frame(merged, features2select, patient_vae, tcga_diff_model, patient_latent_dim, merged.sample_id))
        outputs.append(merged[fingerprint_cols].values)
        outputs.append(merged["recist"].values)

    return tuple(outputs) + tuple(target_merged) + tuple(source_merged)

def load_augmented_cl_dataset(sample_id):
    """
    Load the augmented cell-line table for one fold from the model checkpoint folder.

    Args:
        sample_id: fold number used in the CSV file name. Previously this argument was
            ignored and `model_config['sample_id']` was always used; the visible caller
            passes that same value, so its result is unchanged.

    Returns:
        DataFrame read from the CSV, with the first column as index. Its shape is printed.
    """
    csv_path = (
        f"{folder_config['model_checkpoint_folder']}/augmented_cl_clconditioned_uda_v2_vaeinput_"
        f"{model_config['model_save_criteria']}_{model_config['experiment_id']}_"
        f"{model_config['experiment_settings']}_fold{sample_id}.csv"
    )
    augmented_cl_df = pd.read_csv(csv_path, index_col=0)
    print(f"Loaded augmented CL data: {augmented_cl_df.shape}")
    return augmented_cl_df

In [6]:
# Pretrained models kept at notebook scope. `load_datasets` calls `load_pretrained_models()`
# again internally, so the instances it uses for the latent inputs are separate from these.
cl_diff_model, cl_vae, tcga_diff_model, patient_vae = load_pretrained_models()

U: encoder 
Sequential(
  (enc-0): Linear(in_features=797, out_features=512, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=512, out_features=128, bias=True)
  (act-1): ReLU()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=128, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=128, out_features=512, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=512, out_features=797, bias=True)
  (act-1): Sigmoid()
)
U: encoder 
Sequential(
  (enc-0): Linear(in_features=797, out_features=1024, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=1024, out_features=128, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=

In [7]:
# Unpack the 24 values returned by `load_datasets` for the configured fold.
(
    # source (cell lines): latent inputs, drug fingerprints, labels per split
    train_source_inputs_vae, train_source_drugs, train_source_labels,
    val_source_inputs_vae, val_source_drugs, val_source_labels,
    test_source_inputs_vae, test_source_drugs, test_source_labels,
    # target (patients): latent inputs, drug fingerprints, labels per split
    train_target_inputs_vae, train_target_drugs, train_target_labels,
    val_target_inputs_vae, val_target_drugs, val_target_labels,
    test_target_inputs_vae, test_target_drugs, test_target_labels,
    # merged frames (features + fingerprints), target first and then source
    train_target_data_merged, val_target_data_merged, test_target_data_merged,
    train_source_data_merged, val_source_data_merged, test_source_data_merged,
) = load_datasets(sample_id=fold)

U: encoder 
Sequential(
  (enc-0): Linear(in_features=797, out_features=512, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=512, out_features=128, bias=True)
  (act-1): ReLU()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=128, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=128, out_features=512, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=512, out_features=797, bias=True)
  (act-1): Sigmoid()
)
U: encoder 
Sequential(
  (enc-0): Linear(in_features=797, out_features=1024, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=1024, out_features=128, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=

In [8]:
# Inspect the merged cell-line training split (feature columns followed by the fingerprint columns).
train_source_data_merged

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,0.000031,-0.000008,0.000022,0.000049,0.000002,0.000032,0.000032,0.000025,0.000010,0.000030,...,0,0,0,0,0,0,0,0,0,0
1,0.000035,0.000021,0.000010,0.000037,0.000002,-0.000016,-0.000001,0.000030,0.000013,0.000022,...,0,0,0,0,0,0,0,0,0,0
2,0.000010,0.000035,0.000026,0.000008,0.000032,0.000002,0.000018,0.000027,0.000017,0.000007,...,0,0,0,0,0,0,0,0,0,0
3,0.000010,0.000026,0.000003,0.000023,0.000007,0.000021,0.000018,0.000034,0.000015,0.000018,...,0,0,0,0,0,0,0,0,0,0
4,0.000025,0.000006,-0.000016,0.000047,0.000020,0.000010,0.000020,0.000047,0.000035,0.000031,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
156436,0.000035,0.000006,0.000002,0.000030,0.000012,0.000031,0.000036,0.000023,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
156437,0.000021,0.000045,0.000008,0.000016,0.000011,0.000022,0.000019,-0.000004,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
156438,0.000010,0.000035,0.000039,0.000011,0.000032,0.000035,0.000024,0.000053,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
156439,0.000024,0.000035,0.000046,0.000033,0.000024,0.000026,0.000027,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0


In [9]:
# Sample id, drug name and AUC label for each cell-line training row.
train_source_data_merged[["sample_id", "drug_name", "auc"]]

Unnamed: 0,sample_id,drug_name,auc
0,PR-132fPs,DOCETAXEL,0.191876
1,PR-L3QLdq,ELEPHANTIN,0.940458
2,PR-NxSV8u,MITOXANTRONE,0.921925
3,PR-oLPbwB,DACTINOMYCIN,0.179515
4,PR-4ngqZx,CCT007093,0.989986
...,...,...,...
156436,PR-M4505H,PFI-1,0.919051
156437,PR-Bz57NU,NILOTINIB,0.995489
156438,PR-6SyWYo,SAPITINIB,0.492491
156439,PR-wGySam,TASELISIB,0.901939


In [10]:
# Inspect the merged patient training split (feature columns followed by the fingerprint columns).
train_target_data_merged

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,0.000022,0.000007,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
1,0.000025,0.000014,0.000004,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
2,0.000031,-0.000001,0.000038,0.000021,0.000044,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,1,0,0,0,0
3,0.000025,0.000035,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
4,0.000039,-0.000018,0.000027,0.000035,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,0.000038,0.000026,0.000044,0.000035,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
484,0.000029,-0.000009,0.000038,0.000033,0.000031,0.000044,0.000031,0.000032,0.000020,-0.000008,...,0,0,0,0,0,0,0,0,0,0
485,0.000031,0.000035,0.000048,0.000001,0.000003,0.000002,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,0,0,0,0,0
486,-0.000016,0.000032,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0,0,0,0,0,1,0,0,0,0


In [11]:
# Only the transformer-embedded feature columns. `filter(regex=...)` applies a regular
# expression, not a glob, so the pattern is anchored to `transformer_embedded_<digits>`.
train_target_data_merged.filter(regex=r"^transformer_embedded_\d+$")

Unnamed: 0,transformer_embedded_0,transformer_embedded_1,transformer_embedded_2,transformer_embedded_3,transformer_embedded_4,transformer_embedded_5,transformer_embedded_6,transformer_embedded_7,transformer_embedded_8,transformer_embedded_9,...,transformer_embedded_787,transformer_embedded_788,transformer_embedded_789,transformer_embedded_790,transformer_embedded_791,transformer_embedded_792,transformer_embedded_793,transformer_embedded_794,transformer_embedded_795,transformer_embedded_796
0,0.000022,0.000007,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
1,0.000025,0.000014,0.000004,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
2,0.000031,-0.000001,0.000038,0.000021,0.000044,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
3,0.000025,0.000035,-0.000017,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
4,0.000039,-0.000018,0.000027,0.000035,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,0.000038,0.000026,0.000044,0.000035,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
484,0.000029,-0.000009,0.000038,0.000033,0.000031,0.000044,0.000031,0.000032,0.000020,-0.000008,...,0.000022,0.000022,0.000022,0.000022,0.000022,0.000022,0.000022,0.000022,0.000022,0.000022
485,0.000031,0.000035,0.000048,0.000001,0.000003,0.000002,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024
486,-0.000016,0.000032,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,0.000031,...,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024,0.000024


In [12]:
# Sample id, drug name and RECIST label for each patient training row.
train_target_data_merged[["sample_id", "drug_name", "recist"]]

Unnamed: 0,sample_id,drug_name,recist
0,TCGA-DB-A64P,TEMOZOLOMIDE,0
1,TCGA-S9-A89V,TEMOZOLOMIDE,0
2,P-0001324-T01-IM3,SORAFENIB,0
3,TCGA-S9-A6U8,CARMUSTINE,0
4,TCGA-CN-4731,CETUXIMAB,0
...,...,...,...
483,s_DS_bkm_008_T,BUPARLISIB,0
484,TCGA-GN-A8LK,CARBOPLATIN,0
485,TCGA-VS-A8EJ,CISPLATIN,0
486,P-0002719-T01-IM3,SORAFENIB,0


In [13]:
# create datasets
def _to_tensor_dataset(merged_df, label_col):
    """Build a TensorDataset of (transformer-embedded features, drug fingerprint, label) from a merged split.

    Feature columns are those named `transformer_embedded_<digits>`. `DataFrame.filter`
    takes a regular expression, so the pattern is anchored rather than written as a glob.
    """
    features = merged_df.filter(regex=r"^transformer_embedded_\d+$").values
    fingerprints = merged_df[[str(i) for i in range(0, 2048)]].values
    labels = merged_df[label_col].values
    return TensorDataset(torch.FloatTensor(features), torch.FloatTensor(fingerprints), torch.FloatTensor(labels))

# Cell Lines
source_dataset_train = _to_tensor_dataset(train_source_data_merged, "auc")
source_dataset_val = _to_tensor_dataset(val_source_data_merged, "auc")
source_dataset_test = _to_tensor_dataset(test_source_data_merged, "auc")

# Patients
target_dataset_train = _to_tensor_dataset(train_target_data_merged, "recist")
target_dataset_val = _to_tensor_dataset(val_target_data_merged, "recist")
target_dataset_test = _to_tensor_dataset(test_target_data_merged, "recist")


In [14]:
# data loaders
def _make_loader(dataset, shuffle):
    """DataLoader with the notebook-wide batch size (512), worker seeding hook and shared generator `g`."""
    return DataLoader(dataset, batch_size=512, shuffle=shuffle, worker_init_fn=seed_worker, generator=g)

# training loaders are shuffled
source_dataloader_train = _make_loader(source_dataset_train, True)
target_dataloader_train = _make_loader(target_dataset_train, True)

# validation and test loaders keep dataset order
source_dataloader_val = _make_loader(source_dataset_val, False)
target_dataloader_val = _make_loader(target_dataset_val, False)

source_dataloader_test = _make_loader(source_dataset_test, False)
target_dataloader_test = _make_loader(target_dataset_test, False)



In [15]:
class MTL(nn.Module):
    """Multi-task model: two encoders, a shared drug embedder, and one prediction head per domain.

    `cl_vae` and `patient_vae` are stored by reference (not copied). Each must return a
    4-tuple whose first element is the sample embedding; the heads take that embedding
    concatenated with the 64-dim drug embedding (`64 * 2` inputs).
    """

    def __init__(self, cl_vae, patient_vae):
        super().__init__()
        self.cl_vae = cl_vae
        self.patient_vae = patient_vae
        # 2048-dim fingerprint -> 64-dim drug embedding, shared by both domains
        self.drug_embedder = nn.Sequential(
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
        )
        # one scalar head per task
        self.audrc_predictor = nn.Sequential(
            nn.Linear(64 * 2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )
        self.recist_predictor = nn.Sequential(
            nn.Linear(64 * 2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, cl_inp, cl_drug, patient_inp, patient_drug, audrc, recist):
        """Encode both domains and predict one value per sample.

        `audrc` and `recist` are accepted but not used in the forward pass.

        Returns:
            (cl_cat, patient_cat, audrc_pred, recist_pred), where the `*_cat` tensors are
            the concatenated [sample embedding, drug embedding] features.
        """
        # cl_inp and patient_inp are 797 dim, both drugs are 2048 dim
        cl_latent = self.cl_vae(cl_inp)[0]
        patient_latent = self.patient_vae(patient_inp)[0]

        cl_cat = torch.cat([cl_latent, self.drug_embedder(cl_drug)], dim=1)
        patient_cat = torch.cat([patient_latent, self.drug_embedder(patient_drug)], dim=1)

        # recist and audrc prediction
        return cl_cat, patient_cat, self.audrc_predictor(cl_cat), self.recist_predictor(patient_cat)
        

In [16]:
# train MTL
# Wraps the pretrained VAEs loaded at notebook scope; MTL stores them by reference.
mtl_model = MTL(cl_vae, patient_vae).to(device)

In [17]:
# Adam over every parameter registered on `mtl_model`, which includes the two VAE sub-modules.
optimizer = torch.optim.Adam(params = mtl_model.parameters(), lr=1e-5)

In [18]:
def testing_loop(model, cl_val_loader, patient_val_loader):
    """Return the BCE-with-logits loss of the RECIST head over all of `patient_val_loader`.

    The model is put in eval mode. Each batch is read as (patient features, drug
    fingerprint, label). Predictions and labels are collected across batches and the
    loss is computed once over the whole set. `cl_val_loader` is accepted but not used.
    """
    model.eval()
    logits_per_batch, labels_per_batch = [], []
    for batch in patient_val_loader:
        features = batch[0].to(device)
        fingerprints = batch[1].to(device)
        targets = batch[2].to(device)
        with torch.no_grad():
            # first output of the patient VAE is the sample embedding
            sample_emb = model.patient_vae(features)[0]
            joint = torch.cat([sample_emb, model.drug_embedder(fingerprints)], dim=1)
            logits_per_batch.append(model.recist_predictor(joint))
            labels_per_batch.append(targets)

    criterion = nn.BCEWithLogitsLoss()
    all_logits = torch.cat(logits_per_batch).view(-1, 1)
    all_labels = torch.cat(labels_per_batch).view(-1, 1)
    return criterion(all_logits, all_labels)
            

In [19]:
# training loop
MAX_EPOCHS = 300  # upper bound on epochs
PATIENCE = 3      # stop after this many consecutive epochs without a new best validation loss

# Loss modules are stateless, so build them once instead of on every iteration.
audrc_criterion = nn.MSELoss()
recist_criterion = nn.BCEWithLogitsLoss()

count = 0
patient_val_losses = []
for epoch in range(MAX_EPOCHS):
    mtl_model.train()
    epoch_loss = []
    # Nested iteration: each cell-line batch is paired with every patient batch.
    for idx0, batch0 in enumerate(source_dataloader_train):
        for idx1, batch1 in enumerate(target_dataloader_train):
            optimizer.zero_grad()
            cl_inp, cl_drug, patient_inp, patient_drug, audrc, recist = batch0[0].to(device), batch0[1].to(device), batch1[0].to(device), batch1[1].to(device), batch0[2].to(device), batch1[2].to(device)
            cl_cat, patient_cat, audrc_pred, recist_pred = mtl_model(cl_inp, cl_drug, patient_inp, patient_drug, audrc, recist)

            # align both
            coral_loss = coral(cl_cat, patient_cat)

            # losses
            audrc_loss = audrc_criterion(audrc_pred.view(-1, 1), audrc.view(-1, 1))
            recist_loss = recist_criterion(recist_pred.view(-1, 1), recist.view(-1, 1))

            total_loss = coral_loss + audrc_loss + recist_loss
            total_loss.backward()
            optimizer.step()

            epoch_loss.append(total_loss.item())

    # get val loss
    patient_val_loss = testing_loop(mtl_model, source_dataloader_val, target_dataloader_val)
    current_val_loss = patient_val_loss.item()
    patient_val_losses.append(current_val_loss)
    print(f"Epoch {epoch}: train loss = {np.mean(epoch_loss)}, val patient loss = {current_val_loss}")

    # The first epoch initialises the running best.
    if len(patient_val_losses) ==  1:
        best_val_loss = current_val_loss

    print(f"Best val loss = {best_val_loss}")
    print(f"Current val loss = {current_val_loss}")

    if current_val_loss <= best_val_loss: # minimize val loss
        torch.save(mtl_model.state_dict(), f"MTL_model_fold{fold}.pth")
        # Keep a plain float; previously the loss tensor itself was stored here, so later
        # comparisons and prints used a tensor.
        best_val_loss = current_val_loss
        print("Saved!")
        count = 0
    else:
        print("Increased count")
        count += 1

    if count >= PATIENCE:
        print("Converged")
        break
            
    



Epoch 0: train loss = 1.2585787461474052, val patient loss = 0.6694322228431702
Best val loss = 0.6694322228431702
Current val loss = 0.6694322228431702
Saved!
Epoch 1: train loss = 0.839504068193872, val patient loss = 0.6343693137168884
Best val loss = 0.6694322228431702
Current val loss = 0.6343693137168884
Saved!
Epoch 2: train loss = 0.7013783930173887, val patient loss = 0.6083987951278687
Best val loss = 0.6343693137168884
Current val loss = 0.6083987951278687
Saved!
Epoch 3: train loss = 0.6349455527230805, val patient loss = 0.5585436224937439
Best val loss = 0.6083987951278687
Current val loss = 0.5585436224937439
Saved!
Epoch 4: train loss = 0.5890230041703367, val patient loss = 0.5131404399871826
Best val loss = 0.5585436224937439
Current val loss = 0.5131404399871826
Saved!
Epoch 5: train loss = 0.5622706758041008, val patient loss = 0.5273934602737427
Best val loss = 0.5131404399871826
Current val loss = 0.5273934602737427
Increased count
Epoch 6: train loss = 0.54312438

In [20]:
# run inference on cell line, drug pairs to get pseudolabels
# Built from the same `cl_vae` / `patient_vae` objects as `mtl_model`, so the VAE
# sub-modules are shared between the two wrappers; only the drug embedder and heads are new.
mtl_model_trained = MTL(cl_vae, patient_vae).to(device)

In [21]:
# Restore the weights written by the training loop whenever the validation loss reached a new best.
mtl_model_trained.load_state_dict(torch.load(f"MTL_model_fold{fold}.pth"))

<All keys matched successfully>

In [22]:
# Switch to evaluation mode; as the cell's last expression this also displays the module structure.
mtl_model_trained.eval()

MTL(
  (cl_vae): vae(
    (mu_layer): Linear(in_features=128, out_features=64, bias=True)
    (sigma_layer): Linear(in_features=128, out_features=64, bias=True)
    (encoder): Sequential(
      (enc-0): Linear(in_features=797, out_features=1024, bias=True)
      (act-0): Tanh()
      (enc-1): Linear(in_features=1024, out_features=128, bias=True)
      (act-1): Tanh()
    )
    (decoder): Sequential(
      (-dec-0): Linear(in_features=64, out_features=128, bias=True)
      (-act-0): Tanh()
      (dec-0): Linear(in_features=128, out_features=1024, bias=True)
      (act-0): Tanh()
      (dec-1): Linear(in_features=1024, out_features=797, bias=True)
      (act-1): Sigmoid()
    )
  )
  (patient_vae): vae(
    (mu_layer): Linear(in_features=128, out_features=64, bias=True)
    (sigma_layer): Linear(in_features=128, out_features=64, bias=True)
    (encoder): Sequential(
      (enc-0): Linear(in_features=797, out_features=512, bias=True)
      (act-0): Tanh()
      (enc-1): Linear(in_features

In [23]:
# Load augmented cell lines and drug combos
cl_augmented_df = load_augmented_cl_dataset(model_config["sample_id"])
train_val_cell_lines = list(cl_augmented_df.index)

# Choose the drug(s) to pair with every augmented cell line.
if model_config["experiment_id"] not in ("1A", "1B"):
    # in 2A and 2B include all available drugs with fp
    # drugs_with_fp = list(drug_fp.index)
    drugs_with_fp = [drug2consider]
elif model_config["experiment_id"] == "1A":
    drugs_with_fp = [model_config["experiment_settings"]] # has only drug name
else:
    drugs_with_fp = [model_config["experiment_settings"].split(", ")[0]] # extract out drug name

# Cartesian product, cell line varying slowest (same order as itertools.product).
possible_cl_drug_combinations = [
    (cell_line, drug_name)
    for cell_line in train_val_cell_lines
    for drug_name in drugs_with_fp
]
possible_cl_drug_combinations_df = pd.DataFrame(
    possible_cl_drug_combinations,
    columns=["sample_id", "drug_name"],
)

Loaded augmented CL data: (1193, 64)


In [24]:
# Number of candidate (cell line, drug) pairs.
len(possible_cl_drug_combinations_df)

1193

In [25]:
# using data loaders to prevent excessive memory usage
class CustomCellLineDataSetUnlabelled(TensorDataset):
    """Unlabelled dataset that builds each (cell-line profile, drug fingerprint) pair on demand.

    `possible_combinations` is a sequence of (sample name, drug name) pairs. Every drug
    name must be present in the index of `drug_fp`, and every sample name in the index
    of `cl_augmented_df`. The parent initialiser is not called; only the three
    attributes below are set.
    """

    def __init__(self, cl_augmented_df, drug_fp, possible_combinations): # possible_combinations must only consist of samples with drug name with a fingerprint
        self.augmented_cl_df = cl_augmented_df
        self.drug_fp = drug_fp
        self.possible_combinations = possible_combinations

    def __len__(self):
        return len(self.possible_combinations)

    def __getitem__(self, idx):
        """Return (profile, fingerprint) as float tensors for the pair at position `idx`."""
        pair = self.possible_combinations[idx]
        sample_name, drug_name = pair
        profile_row = self.augmented_cl_df.loc[sample_name]
        fingerprint_row = self.drug_fp.loc[drug_name]
        return torch.FloatTensor(profile_row.values), torch.FloatTensor(fingerprint_row.values)

In [26]:
# Wrap all candidate (cell line, drug) pairs for batched pseudo-label inference.
cl_aug_train_dataset = CustomCellLineDataSetUnlabelled(
    cl_augmented_df,
    drug_fp,
    possible_cl_drug_combinations,
)
print(
    "Number of possible cl drug combos before pseudo label based filtering: ",
    len(cl_aug_train_dataset),
    sep="\n",
)
# shuffle=False keeps batches in the order of possible_cl_drug_combinations,
# which the later subset selection relies on.
cl_aug_train_dataloader = DataLoader(
    cl_aug_train_dataset,
    batch_size=model_config["source_batch_size"],
    shuffle=False,
)


Number of possible cl drug combos before pseudo label based filtering: 
1193


In [27]:
def inference_mtl(model, cl_aug_train_dataloader):
    """
    Score every batch of the loader with the model's drug embedder and RECIST head.

    The first element of each batch is concatenated as-is with the drug
    embedding (it is not passed through ``model.patient_vae``), so it is
    presumably already a latent embedding -- TODO confirm against the loader.

    Returns a single tensor of sigmoid outputs with one column, stacked in
    loader order. Tensors are moved to the module-level ``device``.
    """
    model.eval()
    batch_scores = []
    with torch.no_grad():
        for batch in cl_aug_train_dataloader:
            sample_emb = batch[0].to(device)
            drug_fingerprint = batch[1].to(device)

            drug_emb = model.drug_embedder(drug_fingerprint)
            joint = torch.cat((sample_emb, drug_emb), dim=1)

            logits = model.recist_predictor(joint)
            batch_scores.append(torch.sigmoid(logits).view(-1, 1))

    return torch.cat(batch_scores)

In [28]:
# Start the pseudo-label table with one row per (cell line, drug) pair, in the same
# order as possible_cl_drug_combinations; the score column is attached in a later cell.
pseudolabels_df = pd.DataFrame()
pseudolabels_df[["sample_id", "drug_name"]] = possible_cl_drug_combinations
pseudolabels_df

Unnamed: 0,sample_id,drug_name
0,PR-132fPs,TEMOZOLOMIDE
1,PR-L3QLdq,TEMOZOLOMIDE
2,PR-NxSV8u,TEMOZOLOMIDE
3,PR-oLPbwB,TEMOZOLOMIDE
4,PR-4ngqZx,TEMOZOLOMIDE
...,...,...
1188,PR-68VQ73,TEMOZOLOMIDE
1189,PR-BqenKD,TEMOZOLOMIDE
1190,PR-7yEowu,TEMOZOLOMIDE
1191,PR-4gQ8AD,TEMOZOLOMIDE


In [29]:
# Get pseudo labels: sigmoid scores from the trained MTL model for every (cell line, drug) pair, in loader order.
pseudolabels = inference_mtl(mtl_model_trained, cl_aug_train_dataloader)

In [30]:
# Move the scores to host memory and attach them as a column. Rows line up by position
# because the loader was built with shuffle=False over the same combination list.
pseudolabels_df["pseudolabels"] = pseudolabels.cpu().detach().numpy()

In [31]:
# Inspect the table with the raw pseudo-label scores attached.
pseudolabels_df

Unnamed: 0,sample_id,drug_name,pseudolabels
0,PR-132fPs,TEMOZOLOMIDE,0.142517
1,PR-L3QLdq,TEMOZOLOMIDE,0.206564
2,PR-NxSV8u,TEMOZOLOMIDE,0.112779
3,PR-oLPbwB,TEMOZOLOMIDE,0.191980
4,PR-4ngqZx,TEMOZOLOMIDE,0.117007
...,...,...,...
1188,PR-68VQ73,TEMOZOLOMIDE,0.179152
1189,PR-BqenKD,TEMOZOLOMIDE,0.141167
1190,PR-7yEowu,TEMOZOLOMIDE,0.231329
1191,PR-4gQ8AD,TEMOZOLOMIDE,0.188726


In [32]:
# Summary statistics of the raw pseudo-label scores.
pseudolabels_df.describe()

Unnamed: 0,pseudolabels
count,1193.0
mean,0.153348
std,0.041275
min,0.059506
25%,0.122279
50%,0.148939
75%,0.177953
max,0.317732


In [33]:
def convert_binary(prediction, lower_threshold, upper_threshold):
    """
    Map a score to a hard label with an abstain band.

    Returns 1 when ``prediction >= upper_threshold``, 0 when
    ``prediction < lower_threshold``, and -1 (abstain) otherwise.
    """
    if prediction >= upper_threshold:
        return 1
    return 0 if prediction < lower_threshold else -1

In [34]:
# threshold and select confident samples
# (lower, upper) cut-offs are chosen per fold group and drug; anything not listed
# falls back to the default pair.
_default_thresholds = (0.1, 0.7)
if fold in [0, 1]:
    _thresholds_by_drug = {
        "TEMOZOLOMIDE": (0.12, 0.18),  # earlier setting tried here: (0.1, 0.18)
    }
else:
    _thresholds_by_drug = {
        "CISPLATIN": (0.1, 0.6),
        "TEMOZOLOMIDE": (0.07, 0.15),
    }
_lower_cut, _upper_cut = _thresholds_by_drug.get(drug2consider, _default_thresholds)

# 1 = confident positive, 0 = confident negative, -1 = abstain
pseudolabels_df["pseudolabels_binary"] = pseudolabels_df["pseudolabels"].apply(
    lambda score: convert_binary(score, _lower_cut, _upper_cut)
)

In [35]:
# Class balance of the confident pseudo-labels (rows where convert_binary abstained with -1 are excluded).
pseudolabels_df[pseudolabels_df.pseudolabels_binary != -1]["pseudolabels_binary"].value_counts()

pseudolabels_binary
1    280
0    274
Name: count, dtype: int64

In [36]:
# Rows are looked up lazily per index so that the full input tensor is never held in memory at once.
class CustomCombinedDataSetLabelled(TensorDataset):
    """Labelled (sample, drug, response) triples drawn from cell-line and patient tables.

    Parameters
    ----------
    combined_df : DataFrame with ``sample_id``, ``drug_name`` and ``recist``
        columns; its index is discarded and rows are addressed by position.
    cl_augmented_df : DataFrame of cell-line features, indexed by sample id.
    train_target_inputs_vae : DataFrame of patient features, indexed by
        sample id. Duplicated index entries are dropped, keeping the first.
    drug_fp : DataFrame of drug fingerprints, indexed by drug name; every
        ``drug_name`` in ``combined_df`` must have a row here.

    Note: ``TensorDataset.__init__`` is not called; ``__getitem__`` and
    ``__len__`` are both overridden here.
    """

    def __init__(self, combined_df, cl_augmented_df, train_target_inputs_vae, drug_fp):
        self.sample_df = combined_df.reset_index(drop=True)
        self.augmented_cl_df = cl_augmented_df
        # De-duplicate so that .loc[sample] always yields a single row.
        self.tcga_vae_df = train_target_inputs_vae[~train_target_inputs_vae.index.duplicated(keep="first")]
        self.drug_fp = drug_fp

    def __getitem__(self, idx):
        """Return ``(features, drug_fingerprint, response)`` for row ``idx``.

        Raises
        ------
        KeyError
            If the sample id is present in neither feature table (or the drug
            name is missing from ``drug_fp``).
        """
        row = self.sample_df.iloc[idx]
        sample_name = row["sample_id"]
        drug_name = row["drug_name"]
        # A sample present in both tables resolves to the cell-line features.
        if sample_name in self.augmented_cl_df.index:
            mut_profile = self.augmented_cl_df.loc[sample_name].values
        elif sample_name in self.tcga_vae_df.index:  # using VAE version instead of mutation profiles
            mut_profile = self.tcga_vae_df.loc[sample_name].values
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise KeyError(
                f"sample_id {sample_name!r} (row {idx}) not found in the cell-line or patient feature tables"
            )
        drug_inp = self.drug_fp.loc[drug_name].values
        response = row["recist"]
        return torch.FloatTensor(mut_profile), torch.FloatTensor(drug_inp), response

    def __len__(self):
        """Number of labelled rows."""
        return len(self.sample_df)

In [37]:

# Keep only the pairs on which the thresholding did not abstain (label != -1).
_is_confident = pseudolabels_df["pseudolabels_binary"] != -1
confident_pseudolabels_df = pseudolabels_df[_is_confident]
confident_pseudolabels_df_idx = confident_pseudolabels_df.index  # row labels used to filter the combinations table

# Select the same rows from the combinations table and attach the hard label as "recist".
_keep_rows = possible_cl_drug_combinations_df.index.isin(confident_pseudolabels_df_idx)
confident_cl_drug_combinations_df = possible_cl_drug_combinations_df[_keep_rows].copy()
confident_cl_drug_combinations_df["recist"] = list(confident_pseudolabels_df["pseudolabels_binary"])

print("Number of confident cl drug combinations with pseudolabels: ")
print(confident_cl_drug_combinations_df.shape)
# NOTE(review): the message below says "majority vote", but the labels in this
# notebook come from score thresholding -- confirm the wording is intended.
print("Pseudo label distribution after majority vote:")
print(confident_cl_drug_combinations_df.recist.value_counts())

# Stack the pseudo-labelled cell-line rows on top of the labelled patient training rows.
_shared_columns = confident_cl_drug_combinations_df.columns
combined_dataset_df = pd.concat(
    [confident_cl_drug_combinations_df, train_target_data_merged[_shared_columns]],
    axis=0,
)
combined_dataset = CustomCombinedDataSetLabelled(
    combined_dataset_df, cl_augmented_df, train_target_inputs_vae, drug_fp
)
combined_dataloader = DataLoader(
    combined_dataset,
    batch_size=model_config["drp_batch_size"],
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
)



Number of confident cl drug combinations with pseudolabels: 
(554, 3)
Pseudo label distribution after majority vote:
recist
1    280
0    274
Name: count, dtype: int64


In [38]:
class DRP(nn.Module):
    """
    Drug-response predictor: embeds a 2048-dim drug fingerprint to 64 dims,
    concatenates it with a 64-dim sample embedding and outputs one logit.
    """

    def __init__(self):
        super().__init__()
        fingerprint_dim, hidden_dim, emb_dim = 2048, 256, 64
        self.drug_embedder = nn.Sequential(
            nn.Linear(fingerprint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        )
        # Input is [sample embedding | drug embedding], hence 2 * emb_dim.
        self.recist_predictor = nn.Sequential(
            nn.Linear(2 * emb_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, patient_inp, patient_drug):
        """Return the raw (pre-sigmoid) response logit for each row."""
        drug_emb = self.drug_embedder(patient_drug)
        joint = torch.cat((patient_inp, drug_emb), dim=1)
        return self.recist_predictor(joint)
        

In [39]:
def inference_drp_model(model, patient_val_dataloader):
    """
    Run the model over a labelled loader without gradient tracking.

    Each batch is expected to be ``(features, drug_fingerprint, label)``.
    Returns ``(probabilities, labels)``, both stacked as single-column
    tensors in loader order; probabilities are the sigmoid of the model
    output. Tensors are moved to the module-level ``device``.
    """
    model.eval()
    probabilities, targets = [], []
    with torch.no_grad():
        for batch in patient_val_dataloader:
            features = batch[0].to(device)
            fingerprints = batch[1].to(device)
            labels = batch[2].to(device)
            logits = model(features, fingerprints)
            probabilities.append(torch.sigmoid(logits).view(-1, 1))
            targets.append(labels.view(-1, 1))
    return torch.cat(probabilities), torch.cat(targets)

In [40]:
def train_drp_model(model, train_dataloader, patient_val_dataloader, num_epochs=100, lr=1e-3):
    """
    Train a drug-response model with BCE-with-logits loss, checkpointing the best epoch.

    After every epoch the model is scored on ``patient_val_dataloader`` with
    the Pearson correlation between labels and predicted probabilities,
    shifted by +1 so that it lies in [0, 2]. Whenever the score matches or
    beats the best seen so far, the state dict is saved to the checkpoint
    folder. Training stops early after 3 consecutive epochs without such an
    improvement.

    Note that the validation score is always this correlation, whichever of
    the supported ``model_config["model_save_criteria"]`` values is set; the
    configured name only appears in the checkpoint filename.

    Parameters
    ----------
    model : module called as ``model(features, drug_fingerprint)``; must have
        a ``model_name`` attribute (used in the checkpoint filename).
    train_dataloader : yields ``(features, drug_fingerprint, label)`` batches.
    patient_val_dataloader : loader passed to ``inference_drp_model``.
    num_epochs : maximum number of epochs.
    lr : initial Adam learning rate; divided by 10 at epochs 20, 30, 40, ...

    Returns
    -------
    None. If the save criterion is unsupported, a message is printed and the
    function returns before any training happens.
    """
    # Validate the configuration up front rather than after a wasted epoch.
    if model_config["model_save_criteria"] not in ["val_AUROC", "val_AUPRC", "val_corr"]: # maximise values
        print("Unsupported metric for optimising")
        return

    # Loop-invariant checkpoint location.
    checkpoint_path = f"{folder_config['model_checkpoint_folder']}/{model.model_name}_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{model_config['sample_id']}_tuned4{drug2consider}.pth"

    criterion = nn.BCEWithLogitsLoss()
    optim = torch.optim.Adam(model.parameters(), lr = lr)

    # -inf (rather than the first epoch's score) so that a NaN first score,
    # e.g. from constant predictions, cannot block every later checkpoint.
    best_val_score = float("-inf")
    count = 0 # consecutive epochs without improvement
    for i in range(num_epochs):
        model.train()
        train_losses = []
        if i > 10 and i % 10 == 0:
            # NOTE(review): re-creating Adam also discards its running moment
            # estimates; kept as-is to preserve training numerics. Confirm
            # whether only the learning rate was meant to change.
            lr = lr/10
            optim = torch.optim.Adam(model.parameters(), lr = lr)
        for batch in train_dataloader:
            optim.zero_grad()
            patient_inp = batch[0].to(device)
            patient_drug = batch[1].to(device)
            label = batch[2].to(device)
            y_pred = model(patient_inp, patient_drug).view(-1, 1)
            loss = criterion(y_pred, label.view(-1, 1).to(device, dtype=torch.float32))
            loss.backward()
            optim.step()
            train_losses.append(loss.item())

        y_test_pred, test_y = inference_drp_model(model, patient_val_dataloader)
        patient_corr = pearsonr(test_y.detach().cpu().numpy().reshape(-1), y_test_pred.detach().cpu().numpy().reshape(-1)).statistic + 1 # range in [0, 2]

        print(f"Epoch {i}: Training loss: {np.mean(train_losses)} |  Validation correlation: {patient_corr}")

        # A NaN score compares False here and therefore counts as "no improvement".
        if patient_corr >= best_val_score:
            best_val_score = patient_corr
            # save model
            print("Best model")
            torch.save(model.state_dict(), checkpoint_path)
            count = 0 # reset count
        else:
            count += 1 # declining performance on validation data

        if count >= 3:
            print("Converged")
            break

In [41]:
# Patient validation split, wrapped in the same labelled dataset class used for training.
# NOTE(review): shuffle=True here draws from the same generator `g` that combined_dataloader uses,
# so changing this flag would presumably change the training batch order as well -- confirm before editing.
target_val_vae_dataset = CustomCombinedDataSetLabelled(val_target_data_merged, cl_augmented_df, val_target_inputs_vae, drug_fp)
target_dataloader_val_vae = DataLoader(target_val_vae_dataset, batch_size=model_config["drp_batch_size"], shuffle=True, worker_init_fn = seed_worker, generator = g)

In [42]:
# Number of labelled rows in the validation split.
len(target_val_vae_dataset)

53

In [43]:
# Initialise the DRP network on the configured device.
nn_drp = DRP().to(device)
nn_drp.model_name = "DRP_model"  # read by train_drp_model when composing the checkpoint filename

# Train on the combined (pseudo-labelled cell line + patient) loader, validating on the patient validation loader.
train_drp_model(nn_drp, combined_dataloader, target_dataloader_val_vae, num_epochs=model_config["drp_epochs"], lr=1e-4)

Epoch 0: Training loss: 0.6849306225776672 |  Validation correlation: 1.1352738046749338
Best model
Epoch 1: Training loss: 0.6943975289662679 |  Validation correlation: 1.1407907399534682
Best model
Epoch 2: Training loss: 0.6952331860860189 |  Validation correlation: 1.145223064240115
Best model
Epoch 3: Training loss: 0.7091027696927389 |  Validation correlation: 1.1495237158723213
Best model
Epoch 4: Training loss: 0.6789979736010233 |  Validation correlation: 1.1531087171855072
Best model
Epoch 5: Training loss: 0.7011725505193075 |  Validation correlation: 1.156751928767962
Best model
Epoch 6: Training loss: 0.6761032144228617 |  Validation correlation: 1.1602044164118768
Best model
Epoch 7: Training loss: 0.6894575158754984 |  Validation correlation: 1.1637547811001683
Best model
Epoch 8: Training loss: 0.670441210269928 |  Validation correlation: 1.1677072728395927
Best model
Epoch 9: Training loss: 0.6933155059814453 |  Validation correlation: 1.1722914855563396
Best model
Epo

In [44]:
# Rebuild the architecture and restore the best checkpoint written by train_drp_model.
nn_drp_trained = DRP().to(device)
nn_drp_trained.model_name = "DRP_model"
_drp_checkpoint_path = f"{folder_config['model_checkpoint_folder']}/{nn_drp_trained.model_name}_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{model_config['sample_id']}_tuned4{drug2consider}.pth"
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU still loads when that GPU is not available.
nn_drp_trained.load_state_dict(torch.load(_drp_checkpoint_path, map_location=device))

<All keys matched successfully>

In [45]:
# Switch the restored DRP model to evaluation mode; the module repr is echoed below.
nn_drp_trained.eval()

DRP(
  (drug_embedder): Sequential(
    (0): Linear(in_features=2048, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
  )
  (recist_predictor): Sequential(
    (0): Linear(in_features=128, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [46]:
# Inspect the patient training features (one row per sample id, columns vae_feat0..vae_feat63).
train_target_inputs_vae

Unnamed: 0_level_0,vae_feat0,vae_feat1,vae_feat2,vae_feat3,vae_feat4,vae_feat5,vae_feat6,vae_feat7,vae_feat8,vae_feat9,...,vae_feat54,vae_feat55,vae_feat56,vae_feat57,vae_feat58,vae_feat59,vae_feat60,vae_feat61,vae_feat62,vae_feat63
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TCGA-DB-A64P,-0.219352,2.517061,-0.193680,0.221351,0.395150,-0.134675,-2.918349,5.108014,-2.359812,0.029782,...,2.111310,-2.426718,-3.463559,2.894500,3.217860,0.323152,-1.507236,-1.561500,-1.303110,2.956490
TCGA-S9-A89V,-3.428875,-0.008528,-1.819188,1.852396,-0.628296,2.462200,1.170798,2.624692,0.891387,-0.572865,...,-1.886937,-0.624494,2.970707,3.218815,1.927609,-2.655958,0.688890,-3.869405,-4.874074,-1.364088
P-0001324-T01-IM3,-1.106438,-2.314227,3.557765,0.184880,1.189587,-1.151612,0.538875,-4.448138,-1.198766,0.421414,...,-1.131532,-5.738163,-3.353961,-0.501270,-3.265127,-2.054276,1.511532,3.272763,-0.729759,1.041497
TCGA-S9-A6U8,-2.193700,-2.502175,1.185150,0.902696,2.849489,2.877264,0.662432,-0.440605,3.555880,-4.229028,...,0.845808,-0.134975,-1.616295,1.131865,-1.768301,-0.962572,0.297010,-0.605425,-1.986692,3.704942
TCGA-CN-4731,-1.953081,0.621742,0.856724,0.775729,1.064387,-2.186353,-0.084217,2.687279,-4.898835,-0.854763,...,4.714758,0.444722,1.375533,-3.258552,3.249691,2.548977,-0.548509,2.286129,0.750298,-1.068531
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
s_DS_bkm_008_T,-3.186148,-1.952735,0.761529,1.318602,-0.996578,-1.446638,1.205739,2.742002,-0.140069,-1.261682,...,0.094275,1.562647,-4.283792,2.880412,-1.333724,2.645974,1.869026,1.431277,1.517711,-2.709214
TCGA-GN-A8LK,-1.477847,2.011506,2.229980,0.311986,2.401161,-0.266892,1.614032,3.796042,2.365186,3.143873,...,-0.762206,6.970009,0.331667,-2.466131,2.497218,1.668934,-0.541309,-3.674359,-2.692982,3.219409
TCGA-VS-A8EJ,-4.536828,-1.235438,-1.589547,1.477945,-0.506021,3.648753,-0.122534,0.614050,1.355840,-3.233706,...,2.861136,0.833473,0.194434,1.164148,0.171118,0.882531,2.083854,0.381664,0.364062,-2.319056
P-0002719-T01-IM3,1.216025,-1.669143,2.676366,2.229857,-0.272450,-2.233571,0.934828,0.265331,-7.029023,3.315549,...,-0.245005,2.844542,0.836247,-5.291439,2.137110,-1.807111,-1.278555,1.867993,2.062442,-0.477907


In [47]:
# Target (patient) test split.
# shuffle=False keeps predictions in the row order of test_target_data_merged,
# which the later column-wise concat with the metadata relies on.
target_test_vae_dataset = CustomCombinedDataSetLabelled(test_target_data_merged, cl_augmented_df, test_target_inputs_vae, drug_fp)
target_dataloader_test_vae = DataLoader(target_test_vae_dataset, batch_size=model_config["drp_batch_size"], shuffle=False)

In [48]:
# Number of labelled rows in the test split.
len(target_test_vae_dataset)

115

In [49]:
# Predicted probabilities and true labels for the test split, in loader (i.e. row) order.
y_test_pred, test_y = inference_drp_model(nn_drp_trained, target_dataloader_test_vae)

In [50]:
# Collect the test predictions and true labels side by side as flat host arrays.
res_df = pd.DataFrame(
    {
        "y_pred": y_test_pred.cpu().detach().numpy().reshape(-1),
        "y_true": test_y.cpu().detach().numpy().reshape(-1),
    }
)

In [51]:
# Inspect predicted probability vs. true label per test row.
res_df

Unnamed: 0,y_pred,y_true
0,0.441515,0
1,0.471793,0
2,0.453884,0
3,0.451931,0
4,0.413706,0
...,...,...
110,0.428446,0
111,0.529848,0
112,0.480600,1
113,0.412201,0


In [52]:
# Attach sample/drug metadata to the predictions by position.
# concat(axis=1) aligns on index labels, so the metadata index is reset to match the
# positional order the predictions were produced in.
_test_metadata = test_target_data_merged[["sample_id", "drug_name", "recist"]].reset_index(drop=True)
# Selecting only the prediction columns keeps this cell idempotent when it is re-run.
res_df = pd.concat([_test_metadata, res_df[["y_pred", "y_true"]]], axis = 1)
# The label fed to the model comes from the same "recist" column, so any mismatch means misaligned rows.
assert (res_df["recist"] == res_df["y_true"]).all(), "predictions are not aligned with test_target_data_merged rows"
res_df

Unnamed: 0,sample_id,drug_name,recist,y_pred,y_true
0,s_DS_bkm_001_T,BUPARLISIB,0,0.441515,0
1,s_DS_bkm_006_T,BUPARLISIB,0,0.471793,0
2,s_DS_bkm_013_T,BUPARLISIB,0,0.453884,0
3,s_DS_bkm_020_T,BUPARLISIB,0,0.451931,0
4,s_DS_bkm_021_T,BUPARLISIB,0,0.413706,0
...,...,...,...,...,...
110,TCGA-S9-A6WM,TEMOZOLOMIDE,0,0.428446,0
111,TCGA-S9-A6WN,TEMOZOLOMIDE,0,0.529848,0
112,TCGA-DB-A4XD,TEMOZOLOMIDE,1,0.480600,1
113,TCGA-FG-A4MW,TEMOZOLOMIDE,0,0.412201,0


In [53]:
# Restrict the test results to the drug this model was tuned for.
res_df_drug = res_df[res_df["drug_name"] == drug2consider]
res_df_drug

Unnamed: 0,sample_id,drug_name,recist,y_pred,y_true
85,TCGA-HW-A5KL,TEMOZOLOMIDE,0,0.429723,0
86,TCGA-DU-A5TS,TEMOZOLOMIDE,0,0.51722,0
87,TCGA-DU-A7TA,TEMOZOLOMIDE,0,0.453721,0
88,TCGA-QH-A6X8,TEMOZOLOMIDE,0,0.497031,0
89,TCGA-S9-A7R4,TEMOZOLOMIDE,0,0.485207,0
90,TCGA-E1-A7YW,TEMOZOLOMIDE,0,0.507778,0
91,TCGA-FG-7638,TEMOZOLOMIDE,0,0.315213,0
92,TCGA-DB-A64V,TEMOZOLOMIDE,0,0.425121,0
93,TCGA-VW-A7QS,TEMOZOLOMIDE,1,0.478257,1
94,TCGA-S9-A6TY,TEMOZOLOMIDE,0,0.486741,0


In [54]:
from sklearn.metrics import roc_auc_score, average_precision_score

In [55]:
# Area under the ROC curve on the test rows of the selected drug.
roc_auc_score(res_df_drug["y_true"], res_df_drug["y_pred"])

0.6923076923076923

In [56]:
# Average precision on the test rows of the selected drug.
average_precision_score(res_df_drug["y_true"], res_df_drug["y_pred"])

0.2541208791208791

In [58]:
# Persist the per-patient predictions for the selected drug and fold.
# NOTE(review): the output location is a hardcoded absolute path and the file name embeds "val_corr_2A_ALL"
# literally instead of reading folder_config / model_config -- confirm and consider deriving both from config.
res_df_drug.to_csv(f"/data/ajayago/papers_data/DiffDRP_v7/run_files/saved_model_annotated_mutations/prediction_patients_val_corr_2A_ALL_fold{fold}_tuned4{drug2consider}.csv")