Remove MTL in the downstream task, i.e., train an STL model on patient data only.

Remove attention in pretraining, and remove the transformer inputs and diffusion; use the `ablation_noattn_notransformer_nodiffusion` folder for the pretrained models.

In [1]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import TensorDataset, DataLoader
import yaml
import pprint
import os
import wandb
import sys
import random
from scipy.stats import mode, pearsonr
import pickle
import itertools
import sys
sys.path.append("../")

from src.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion
from src.modules import MLPDiffusion
from src.vae_model import vae
from src.loss_functions import get_kld_loss, coral
from model_definition import *

In [3]:
# Fold index for this run; it is interpolated into CONFIG_PATH and into the checkpoint filename used when saving/loading the trained model.
fold = 2

In [4]:
# global variables

CONFIG_PATH = f"../experiment_settings_yaml/ablation/model_config_2A_annotated_mutations_v7_fold{fold}_noattn_notransformer_nodiffusion.yaml" # model config path
pretty_print = pprint.PrettyPrinter()
print(f"Loading config from {CONFIG_PATH}")
# read the YAML inside a context manager so the file handle is always closed
with open(CONFIG_PATH) as config_file:
    config = yaml.safe_load(config_file)
model_config = config["model_hyperparams"]
folder_config = config["folder_config"]
wandb_config = config["wandb_config"]
wandb_config["project_name"] = wandb_config["project_name"] + f"-{model_config['experiment_id']}-{model_config['experiment_settings']}-fold{model_config['sample_id']}" # updates wandb project name for ease of monitoring and logging.
# falls back to CPU when CUDA is unavailable
device = torch.device(f"cuda:{model_config['device']}" if torch.cuda.is_available() else "cpu")
genes_324 = list(pd.read_csv(f"{folder_config['data_folder']}/raw/metadata/gene2ind.txt", header=None)[0])
drug_fp = pd.read_csv(f"{folder_config['data_folder']}/raw/metadata/drug_morgan_fingerprints.csv", index_col=0)
suffixes = ["_piu_max", "_piu_sum", "_piu_mean", "_piu_count",
            "_lu_max", "_lu_sum", "_lu_mean", "_lu_count",
            "_ncu_max", "_ncu_sum", "_ncu_mean", '_ncu_count',
            "_pathogenic_max", "_pathogenic_sum", "_pathogenic_mean", "_pathogenic_count",
            "_vus_max", "_vus_sum", "_vus_mean", "_vus_count",
            "_benign_max", "_benign_sum", "_benign_mean", "_benign_count"
           ]
# one feature name per (suffix, gene) pair, suffix-major; reuses the gene list loaded above
# instead of re-reading gene2ind.txt from disk for every suffix
genes_7776 = [f"{gene}{suffix}" for suffix in suffixes for gene in genes_324]

# setting up wandb
os.environ["WANDB_CACHE_DIR"] = wandb_config["wandb_cache_dir"]
os.environ["WANDB_DIR"] = wandb_config["wandb_cache_dir"]
wandb.login(key=wandb_config["api_key"])

# seeding
torch.manual_seed(model_config["seed"])
random.seed(model_config["seed"])
np.random.seed(model_config["seed"])
# reproducibility in data loading - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed numpy and `random` inside a DataLoader worker from torch's per-worker initial seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# generator handed to the DataLoaders so shuffling is reproducible
g = torch.Generator()
g.manual_seed(model_config["seed"])



Loading config from ../experiment_settings_yaml/ablation/model_config_2A_annotated_mutations_v7_fold1_noattn_notransformer_nodiffusion.yaml


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjayagopalaishwarya[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ajayago/.netrc


<torch._C.Generator at 0x7f34659b3e70>

In [5]:
# Folder holding the pretrained checkpoints for this ablation.
# NOTE(review): this variable is not referenced in the visible cells (checkpoints are read via folder_config['model_checkpoint_folder']) and the path is machine-specific — confirm whether it is still needed.
pretrained_folder = "/data/ajayago/papers_data/DiffDRP_v7/run_files/saved_model_annotated_mutations/ablation/ablation_noattn_notransformer_nodiffusion//"

In [6]:
# encode samples with a pretrained VAE, stopping just before the VAE decoder
def vae_decoder_input(df, vae):
    """
    Run the rows of `df` through `vae` and return the first element of its output as a numpy array.

    df: DataFrame whose `.values` are fed to the model as float32 tensors on `device`.
    vae: callable model returning a 4-tuple; only the first element (the output of the
         encoder + reparameterization step, per the original author's note) is kept.
    """
    with torch.no_grad():
        features = torch.tensor(df.values).to(device, dtype=torch.float32)
        latent, _, _, _ = vae(features)

    return latent.detach().cpu().numpy()

def load_pretrained_models():
    """
    Build the patient and cell-line VAEs and load their pretrained weights.

    The checkpoint path is derived from folder_config / model_config; the checkpoint is expected
    to hold the state dicts under the keys "patient_vae_conditioned" and "cl_vae_conditioned".

    Returns (cl_vae, patient_vae), both moved to `device`.
    """
    checkpoint_path = f"{folder_config['model_checkpoint_folder']}/best_pretrained_validation_loss_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{model_config['sample_id']}.pth"
    # map_location lets a checkpoint written from a GPU load when `device` is the CPU fallback
    # (or a different GPU index) instead of failing inside torch.load
    pretrained = torch.load(checkpoint_path, map_location=device)
    is_real = model_config["input_data_type"] == "binary_mutations"
    # patients
    patient_vae = vae(input_dim=model_config["feature_num"], k_list=model_config["patient_vae_k_list"], actf_list=model_config["patient_vae_actf_list"], is_real=is_real).to(device)
    patient_vae.load_state_dict(pretrained["patient_vae_conditioned"])
    # cell lines
    cl_vae = vae(input_dim=model_config["feature_num"], k_list=model_config["cl_vae_k_list"], actf_list=model_config["cl_vae_actf_list"], is_real=is_real).to(device)
    cl_vae.load_state_dict(pretrained["cl_vae_conditioned"])
    return cl_vae, patient_vae

def load_datasets(sample_id):
    """
    Load the cell-line (source) and patient (target) train/val/test splits, merge them with the
    drug fingerprints and encode the mutation features with the pretrained VAEs.

    sample_id: fold index used to pick the patient pickle (`patients_fold{sample_id}_processed.pkl`);
               cell lines always come from fold 0.

    Returns a 24-tuple, in this order:
      - for source train, val, test: (VAE-encoded inputs DataFrame, drug fingerprint array, "auc" label array)
      - for target train, val, test: (VAE-encoded inputs DataFrame, drug fingerprint array, "recist" label array)
      - merged target DataFrames (train, val, test)
      - merged source DataFrames (train, val, test)

    Raises ValueError if model_config["input_data_type"] or model_config["experiment_id"] is not supported
    (previously this printed a message and returned None, which made the caller's tuple-unpacking fail
    with an unrelated TypeError).
    """
    data_dir = folder_config["data_folder"] + "input_types/"

    # navigate based on input type; features2select holds the mutation feature columns only
    input_type = model_config["input_data_type"]
    if input_type == "binary_mutations":
        data_dir = data_dir + "raw_mutations/"
        features2select = genes_324
    elif input_type == "annotated_mutations":
        data_dir = data_dir + "annotated_mutations/"
        features2select = genes_7776
    else:
        raise ValueError(f"Unsupported input type: {input_type!r}")

    # navigate based on experiment id
    experiment_dirs = {
        "1A": "Experiment1/SettingA/",
        "1B": "Experiment1/SettingB/",
        "2A": "Experiment2/SettingA/",
        "2B": "Experiment2/SettingB/",
    }
    experiment_id = model_config["experiment_id"]
    if experiment_id not in experiment_dirs:
        raise ValueError(f"Unsupported experiment ID: {experiment_id!r}")
    data_dir = data_dir + experiment_dirs[experiment_id]

    # load the fold based on sample_id - Note: cell lines have only 1 fold (fold 0)
    # pickle.load can execute arbitrary code; only point this at trusted, locally produced files
    with open(f"{data_dir}/cell_lines_fold0_processed.pkl", "rb") as f:
        source_data = pickle.load(f)

    with open(f"{data_dir}/patients_fold{sample_id}_processed.pkl", "rb") as f:
        target_data = pickle.load(f)

    # pretrained VAEs used to encode the mutation features
    cl_vae, patient_vae = load_pretrained_models()

    # select data based on experiment settings
    # Can be CISPLATIN, PACLITAXEL, FLUOROURACIL, SORAFENIB for 1A, CISPLATIN, TCGA-CESC; CISPLATIN, TCGA-HNSC; PACLITAXEL, TCGA-BRCA; FLUOROURACIL, TCGA-STAD for 1B
    # ALL for 2A, TCGA-BRCA, TCGA-CESC, TCGA-HNSC, TCGA-STAD for 2B
    if experiment_id in ("1A", "2B"):
        setting = model_config["experiment_settings"]
    elif experiment_id == "1B":
        setting_parts = model_config["experiment_settings"].split(", ")
        setting = (setting_parts[0], setting_parts[1], "TCGA")
    else:  # "2A": each split is used as a whole, no per-setting lookup
        setting = None

    split_names = ("train", "val", "test")
    fingerprint_columns = [str(i) for i in range(0, 2048)]

    def _select(splits, split_name):
        # pick one split, narrowed to the configured setting when there is one
        return splits[split_name] if setting is None else splits[split_name][setting]

    def _merge_with_fingerprints(frame, description):
        # attach the Morgan fingerprint columns and check that no row was lost or duplicated
        merged = frame.merge(drug_fp, left_on="drug_name", right_on=drug_fp.index)
        assert merged.shape[0] == frame.shape[0], f"{description} data loss after merge!"
        return merged

    def _encode(merged, model, k_list_key, index):
        # encode the mutation features with the given VAE into a DataFrame of latent features
        latent_dim = model_config[k_list_key][-1] // 2
        return pd.DataFrame(
            vae_decoder_input(merged[features2select], model),
            columns=[f"vae_feat{i}" for i in range(latent_dim)],
            index=index,
        )

    source_merged = [_merge_with_fingerprints(_select(source_data, name), f"{name.capitalize()} source") for name in split_names]
    target_merged = [_merge_with_fingerprints(_select(target_data, name), f"{name.capitalize()} target") for name in split_names]

    # Encoding order (source train/val/test, then target train/val/test) is kept as in the original
    # implementation so that any randomness consumed by the VAE forward pass is drawn in the same sequence.
    outputs = []
    for merged in source_merged:
        outputs.extend([
            _encode(merged, cl_vae, "cl_vae_k_list", merged.index),
            merged[fingerprint_columns].values,
            merged["auc"].values,
        ])
    for merged in target_merged:
        outputs.extend([
            # NOTE(review): target encodings are indexed by sample_id whereas source encodings keep the row index
            _encode(merged, patient_vae, "patient_vae_k_list", merged.sample_id),
            merged[fingerprint_columns].values,
            merged["recist"].values,
        ])

    return (*outputs, *target_merged, *source_merged)

def load_augmented_cl_dataset(sample_id):
    """
    Read the augmented cell-line CSV for the given fold from the model checkpoint folder.

    sample_id: fold index used in the filename (previously the argument was ignored and
               model_config['sample_id'] was always used; the visible caller passes that same value).

    Returns the DataFrame with the first CSV column as its index.
    """
    augmented_cl_df = pd.read_csv(f"{folder_config['model_checkpoint_folder']}/augmented_cl_clconditioned_uda_v2_vaeinput_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{sample_id}.csv", index_col=0)
    print(f"Loaded augmented CL data: {augmented_cl_df.shape}")
    return augmented_cl_df

In [7]:
# Pretrained VAEs kept at notebook scope; they are passed to the STL constructor further down.
# load_datasets() calls load_pretrained_models() again internally for its own copies.
cl_vae, patient_vae = load_pretrained_models()

U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=512, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=512, out_features=128, bias=True)
  (act-1): ReLU()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=128, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=128, out_features=512, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=512, out_features=7776, bias=True)
  (act-1): Sigmoid()
)
U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=1024, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=1024, out_features=128, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_featur

In [8]:
# unpack the 24 values returned by load_datasets, in the order it returns them
(
    train_source_inputs_vae, train_source_drugs, train_source_labels,
    val_source_inputs_vae, val_source_drugs, val_source_labels,
    test_source_inputs_vae, test_source_drugs, test_source_labels,
    train_target_inputs_vae, train_target_drugs, train_target_labels,
    val_target_inputs_vae, val_target_drugs, val_target_labels,
    test_target_inputs_vae, test_target_drugs, test_target_labels,
    train_target_data_merged, val_target_data_merged, test_target_data_merged,
    train_source_data_merged, val_source_data_merged, test_source_data_merged,
) = load_datasets(model_config["sample_id"])

U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=512, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=512, out_features=128, bias=True)
  (act-1): ReLU()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_features=128, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=128, out_features=512, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=512, out_features=7776, bias=True)
  (act-1): Sigmoid()
)
U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=1024, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=1024, out_features=128, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
sigma_layer: 
Linear(in_features=128, out_features=64, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=64, out_featur

In [9]:
# Inspect the merged cell-line training frame (metadata, mutation features and fingerprint columns).
train_source_data_merged

Unnamed: 0,sample_id,drug_name,auc,ic50,drug_category,response_label,ABL1_piu_max,ACVR1B_piu_max,AKT1_piu_max,AKT2_piu_max,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,PR-132fPs,DOCETAXEL,0.191876,-4.662091,1,1,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
1,PR-L3QLdq,ELEPHANTIN,0.940458,5.730421,3,0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2,PR-NxSV8u,MITOXANTRONE,0.921925,4.070582,1,0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
3,PR-oLPbwB,DACTINOMYCIN,0.179515,-6.588337,1,1,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
4,PR-4ngqZx,CCT007093,0.989986,3.724712,3,0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
156436,PR-M4505H,PFI-1,0.919051,3.534174,3,1,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
156437,PR-Bz57NU,NILOTINIB,0.995489,4.073733,1,0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
156438,PR-6SyWYo,SAPITINIB,0.492491,-1.567439,2,1,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
156439,PR-wGySam,TASELISIB,0.901939,2.716776,2,0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0


In [10]:
# Identifier, drug and AUC label columns of the cell-line training frame.
train_source_data_merged[["sample_id", "drug_name", "auc"]]

Unnamed: 0,sample_id,drug_name,auc
0,PR-132fPs,DOCETAXEL,0.191876
1,PR-L3QLdq,ELEPHANTIN,0.940458
2,PR-NxSV8u,MITOXANTRONE,0.921925
3,PR-oLPbwB,DACTINOMYCIN,0.179515
4,PR-4ngqZx,CCT007093,0.989986
...,...,...,...
156436,PR-M4505H,PFI-1,0.919051
156437,PR-Bz57NU,NILOTINIB,0.995489
156438,PR-6SyWYo,SAPITINIB,0.492491
156439,PR-wGySam,TASELISIB,0.901939


In [11]:
# Inspect the merged patient training frame.
train_target_data_merged

Unnamed: 0,sample_id,drug_name,recist,mappedProject,dataset_name,ABL1_piu_max,ACVR1B_piu_max,AKT1_piu_max,AKT2_piu_max,AKT3_piu_max,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,TCGA-FD-A6TC,GEMCITABINE,1,TCGA-BLCA,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,1,0,0,0,0,0
1,TCGA-S9-A6TS,CARMUSTINE,0,TCGA-LGG,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
2,TCGA-VR-A8EQ,FLUOROURACIL,1,TCGA-ESCA,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
3,s_DS_bkm_034_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
4,TCGA-YU-A90Q,CARBOPLATIN,1,TCGA-TGCT,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,s_DS_bkm_013_T,BUPARLISIB,0,TCGA-BRCA,CBIO_brca_mskcc_2019,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
484,TCGA-GN-A8LK,CARBOPLATIN,0,TCGA-SKCM,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
485,TCGA-VS-A8EJ,CISPLATIN,0,TCGA-CESC,TCGA,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
486,P-0021780-T01-IM6,SORAFENIB,0,TCGA-LIHC,CBIO_hcc_mskimpact_2018,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,1,0,0,0,0


In [12]:
# Mutation feature columns (genes_7776) of the patient training frame — the VAE input.
train_target_data_merged[genes_7776]

Unnamed: 0,ABL1_piu_max,ACVR1B_piu_max,AKT1_piu_max,AKT2_piu_max,AKT3_piu_max,ALK_piu_max,ALOX12B_piu_max,APC_piu_max,AR_piu_max,ARAF_piu_max,...,U2AF1_benign_count,VEGFA_benign_count,VHL_benign_count,WHSC1_benign_count,WHSC1L1_benign_count,WT1_benign_count,XPO1_benign_count,XRCC2_benign_count,ZNF217_benign_count,ZNF703_benign_count
0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.882353,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
483,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
484,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.882353,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
485,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
486,0.0,0.0,0.0,0.0,0.0,0.294118,0.0,0.000000,0.000000,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [13]:
# Identifier, drug and RECIST label columns of the patient training frame.
train_target_data_merged[["sample_id", "drug_name", "recist"]]

Unnamed: 0,sample_id,drug_name,recist
0,TCGA-FD-A6TC,GEMCITABINE,1
1,TCGA-S9-A6TS,CARMUSTINE,0
2,TCGA-VR-A8EQ,FLUOROURACIL,1
3,s_DS_bkm_034_T,BUPARLISIB,0
4,TCGA-YU-A90Q,CARBOPLATIN,1
...,...,...,...
483,s_DS_bkm_013_T,BUPARLISIB,0
484,TCGA-GN-A8LK,CARBOPLATIN,0
485,TCGA-VS-A8EJ,CISPLATIN,0
486,P-0021780-T01-IM6,SORAFENIB,0


In [14]:
# create datasets
fingerprint_columns = [str(i) for i in range(0, 2048)]

def _build_tensor_dataset(merged_df, label_column):
    """Wrap mutation features (genes_7776), drug fingerprints and the given label column as float tensors."""
    return TensorDataset(
        torch.FloatTensor(merged_df[genes_7776].values),
        torch.FloatTensor(merged_df[fingerprint_columns].values),
        torch.FloatTensor(merged_df[label_column].values),
    )

# Cell Lines (label: auc)
source_dataset_train = _build_tensor_dataset(train_source_data_merged, "auc")
source_dataset_val = _build_tensor_dataset(val_source_data_merged, "auc")
source_dataset_test = _build_tensor_dataset(test_source_data_merged, "auc")

# Patients (label: recist)
target_dataset_train = _build_tensor_dataset(train_target_data_merged, "recist")
target_dataset_val = _build_tensor_dataset(val_target_data_merged, "recist")
target_dataset_test = _build_tensor_dataset(test_target_data_merged, "recist")


In [15]:
# data loaders: every loader shares the batch size, worker seeding and generator; only the training loaders shuffle
loader_kwargs = dict(batch_size=512, worker_init_fn=seed_worker, generator=g)

source_dataloader_train = DataLoader(source_dataset_train, shuffle=True, **loader_kwargs)
target_dataloader_train = DataLoader(target_dataset_train, shuffle=True, **loader_kwargs)

source_dataloader_val = DataLoader(source_dataset_val, shuffle=False, **loader_kwargs)
target_dataloader_val = DataLoader(target_dataset_val, shuffle=False, **loader_kwargs)

source_dataloader_test = DataLoader(source_dataset_test, shuffle=False, **loader_kwargs)
target_dataloader_test = DataLoader(target_dataset_test, shuffle=False, **loader_kwargs)



In [16]:
class STL(nn.Module):
    """
    Single-task model: embeds a patient profile with `patient_vae`, embeds the drug fingerprint
    with a small MLP, concatenates both and predicts one RECIST logit.
    """

    def __init__(self, cl_vae, patient_vae):
        super().__init__()
        # cl_vae is accepted but not stored: this single-task variant has no cell-line branch
        self.patient_vae = patient_vae
        # 2048-dim fingerprint -> 64-dim drug embedding
        self.drug_embedder = nn.Sequential(
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
        )
        # concatenated (sample embedding, drug embedding) -> single logit
        self.recist_predictor = nn.Sequential(
            nn.Linear(64 * 2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, cl_inp, cl_drug, patient_inp, patient_drug, audrc, recist):
        """
        Only `patient_inp` and `patient_drug` are used; the remaining arguments are ignored.

        Returns (concatenated patient/drug embedding, RECIST logit).
        """
        sample_embedding, _, _, _ = self.patient_vae(patient_inp)
        drug_embedding = self.drug_embedder(patient_drug)
        joint_embedding = torch.cat([sample_embedding, drug_embedding], dim=1)
        return joint_embedding, self.recist_predictor(joint_embedding)
        

In [17]:
# train STL
# The variable keeps the `mtl_model` name although it holds the single-task STL model.
mtl_model = STL(cl_vae, patient_vae).to(device)

In [18]:
# Adam over all STL parameters; this includes the patient VAE weights, since it is a submodule of the model.
optimizer = torch.optim.Adam(params = mtl_model.parameters(), lr=1e-5)

In [19]:
def testing_loop(model, cl_val_loader, patient_val_loader):
    """
    Compute the BCE-with-logits loss of `model` over every batch of `patient_val_loader`.

    model: object exposing `patient_vae`, `drug_embedder` and `recist_predictor`; it is switched to eval mode.
    cl_val_loader: unused.
    patient_val_loader: yields batches indexed as (patient input, drug input, label).

    Returns a scalar loss tensor computed on the concatenated predictions and labels.
    """
    model.eval()
    logits, targets = [], []
    with torch.no_grad():
        for batch in patient_val_loader:
            features = batch[0].to(device)
            fingerprints = batch[1].to(device)
            labels = batch[2].to(device)
            sample_embedding, _, _, _ = model.patient_vae(features)
            joint = torch.cat((sample_embedding, model.drug_embedder(fingerprints)), dim=1)
            logits.append(model.recist_predictor(joint))
            targets.append(labels)
    all_logits = torch.cat(logits).view(-1, 1)
    all_targets = torch.cat(targets).view(-1, 1)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(all_logits, all_targets)
            

In [20]:
# training loop
# NOTE(review): the outer loop over source_dataloader_train only determines how many passes over the
# patient loader make up one "epoch"; the cell-line tensors are handed to the model, but STL.forward ignores them.
count = 0
patient_val_losses = []
for epoch in range(300):
    mtl_model.train()
    epoch_loss = []
    for idx0, batch0 in enumerate(source_dataloader_train):
        # invariant with respect to the inner loop, so move the cell-line batch to the device once
        cl_inp, cl_drug, audrc = batch0[0].to(device), batch0[1].to(device), batch0[2].to(device)
        for idx1, batch1 in enumerate(target_dataloader_train):
            optimizer.zero_grad()
            patient_inp, patient_drug, recist = batch1[0].to(device), batch1[1].to(device), batch1[2].to(device)
            patient_cat, recist_pred = mtl_model(cl_inp, cl_drug, patient_inp, patient_drug, audrc, recist)

            # ablation: no CORAL alignment and no AUDRC loss — only the RECIST loss is optimised
            recist_loss = nn.BCEWithLogitsLoss()(recist_pred.view(-1, 1), recist.view(-1, 1))
            total_loss = recist_loss
            total_loss.backward()
            optimizer.step()

            epoch_loss.append(total_loss.item())

    # get val loss
    patient_val_loss = testing_loop(mtl_model, source_dataloader_val, target_dataloader_val)
    current_val_loss = patient_val_loss.item()
    patient_val_losses.append(current_val_loss)
    print(f"Epoch {epoch}: train loss = {np.mean(epoch_loss)}, val patient loss = {current_val_loss}")

    # the first epoch's validation loss seeds the best-so-far value
    if len(patient_val_losses) ==  1:
        best_val_loss = current_val_loss

    print(f"Best val loss = {best_val_loss}")
    print(f"Current val loss = {current_val_loss}")

    if current_val_loss <= best_val_loss: # minimize val loss
        torch.save(mtl_model.state_dict(), f"{folder_config['model_checkpoint_folder']}MTL_model_fold{fold}.pth")
        # keep a plain float (not the loss tensor) so later comparisons and log lines stay float-typed
        best_val_loss = current_val_loss
        print("Saved!")
        count = 0
    else:
        print("Increased count")
        count += 1

    # early stopping: stop after 3 consecutive epochs without improvement
    if count >= 3:
        print("Converged")
        break
            
    

Epoch 0: train loss = 0.6820839010422526, val patient loss = 0.6803722381591797
Best val loss = 0.6803722381591797
Current val loss = 0.6803722381591797
Saved!
Epoch 1: train loss = 0.6116710215612174, val patient loss = 0.6553341746330261
Best val loss = 0.6803722381591797
Current val loss = 0.6553341746330261
Saved!
Epoch 2: train loss = 0.4863982717780506, val patient loss = 0.7100379467010498
Best val loss = 0.6553341746330261
Current val loss = 0.7100379467010498
Increased count
Epoch 3: train loss = 0.3222259510965908, val patient loss = 0.8854402899742126
Best val loss = 0.6553341746330261
Current val loss = 0.8854402899742126
Increased count
Epoch 4: train loss = 0.17480620215920842, val patient loss = 1.2623865604400635
Best val loss = 0.6553341746330261
Current val loss = 1.2623865604400635
Increased count
Converged


In [21]:
# run inference on cell line, drug pairs to get pseudolabels
# Fresh STL instance; the weights saved by the training loop are loaded into it in the next cell.
mtl_model_trained = STL(cl_vae, patient_vae).to(device)

In [22]:
# map_location keeps the load working when the checkpoint was written from a different device than the current `device`
mtl_model_trained.load_state_dict(torch.load(f"{folder_config['model_checkpoint_folder']}MTL_model_fold{fold}.pth", map_location=device))

<All keys matched successfully>

In [23]:
# Switch the reloaded model to evaluation mode; the cell output is the module repr returned by eval().
mtl_model_trained.eval()

STL(
  (patient_vae): vae(
    (mu_layer): Linear(in_features=128, out_features=64, bias=True)
    (sigma_layer): Linear(in_features=128, out_features=64, bias=True)
    (encoder): Sequential(
      (enc-0): Linear(in_features=7776, out_features=512, bias=True)
      (act-0): Tanh()
      (enc-1): Linear(in_features=512, out_features=128, bias=True)
      (act-1): ReLU()
    )
    (decoder): Sequential(
      (-dec-0): Linear(in_features=64, out_features=128, bias=True)
      (-act-0): Tanh()
      (dec-0): Linear(in_features=128, out_features=512, bias=True)
      (act-0): Tanh()
      (dec-1): Linear(in_features=512, out_features=7776, bias=True)
      (act-1): Sigmoid()
    )
  )
  (drug_embedder): Sequential(
    (0): Linear(in_features=2048, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
  )
  (recist_predictor): Sequential(
    (0): Linear(in_features=128, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_fea

In [24]:
# Load augmented cell lines and enumerate every (cell line, drug) pair to be pseudo-labelled
cl_augmented_df = load_augmented_cl_dataset(model_config["sample_id"])
train_val_cell_lines = list(cl_augmented_df.index)

# choose the drugs according to the experiment
if model_config["experiment_id"] == "1A":
    # settings hold only the drug name
    drugs_with_fp = [model_config["experiment_settings"]]
elif model_config["experiment_id"] == "1B":
    # settings look like "<drug>, <project>": keep the drug name
    drugs_with_fp = [model_config["experiment_settings"].split(", ")[0]]
else:
    # in 2A and 2B include all available drugs with fp
    drugs_with_fp = list(drug_fp.index)

possible_cl_drug_combinations = list(itertools.product(train_val_cell_lines, drugs_with_fp))
possible_cl_drug_combinations_df = pd.DataFrame(
    possible_cl_drug_combinations,
    columns = ["sample_id", "drug_name"],
)

Loaded augmented CL data: (1193, 64)


In [25]:
# Number of (cell line, drug) pairs, i.e. len(train_val_cell_lines) * len(drugs_with_fp).
len(possible_cl_drug_combinations_df)

571447

In [26]:
# using data loaders to prevent excessive memory usage
class CustomCellLineDataSetUnlabelled(TensorDataset):
    """
    Unlabelled dataset over (sample, drug) pairs; each item is looked up on access.

    cl_augmented_df: DataFrame indexed by sample name, one row of features per sample.
    drug_fp: DataFrame indexed by drug name, one fingerprint row per drug.
    possible_combinations: sequence of (sample name, drug name) pairs; every drug name
                           must have a fingerprint in `drug_fp`.
    """

    def __init__(self, cl_augmented_df, drug_fp, possible_combinations): # possible_combinations must only consist of samples with drug name with a fingerprint
        self.augmented_cl_df = cl_augmented_df
        self.drug_fp = drug_fp
        self.possible_combinations = possible_combinations

    def __len__(self):
        return len(self.possible_combinations)

    def __getitem__(self, idx):
        """Return (sample feature tensor, drug fingerprint tensor) for pair number `idx`."""
        pair = self.possible_combinations[idx]
        sample_name, drug_name = pair
        sample_row = self.augmented_cl_df.loc[sample_name].values
        fingerprint_row = self.drug_fp.loc[drug_name].values
        return torch.FloatTensor(sample_row), torch.FloatTensor(fingerprint_row)

In [27]:
# Wrap all (cell line, drug) pairs in a dataset/dataloader for batched pseudo-labelling.
# The loader must not shuffle: predictions are later matched back to the pairs by position.
cl_aug_train_dataset = CustomCellLineDataSetUnlabelled(cl_augmented_df, drug_fp, possible_cl_drug_combinations)
print("Number of possible cl drug combos before pseudo label based filtering: ")
print(len(cl_aug_train_dataset))
cl_aug_train_dataloader = DataLoader(cl_aug_train_dataset, batch_size=model_config["source_batch_size"], shuffle=False) # to preserve order for later subset selection


Number of possible cl drug combos before pseudo label based filtering: 
571447


In [28]:
def inference_mtl(model, cl_aug_train_dataloader):
    """
    Produce pseudo-label probabilities for every sample served by the dataloader.

    The first element of each batch is concatenated as-is with the drug embedding
    (it is NOT passed through `model.patient_vae`); the concatenation is fed to
    `model.recist_predictor` and squashed with a sigmoid.

    Parameters
    ----------
    model : module exposing `drug_embedder` and `recist_predictor`.
    cl_aug_train_dataloader : iterable of batches where `batch[0]` is the
        cell-line input and `batch[1]` the drug input.

    Returns
    -------
    Tensor of shape (num_samples, 1) with the sigmoid outputs, in dataloader order.
    """
    model.eval()
    batch_probabilities = []
    with torch.no_grad():
        for batch in cl_aug_train_dataloader:
            cell_line_emb = batch[0].to(device)
            drug_emb = model.drug_embedder(batch[1].to(device))
            logits = model.recist_predictor(torch.cat((cell_line_emb, drug_emb), dim=1))
            batch_probabilities.append(torch.sigmoid(logits).view(-1, 1))

    return torch.cat(batch_probabilities)

In [29]:
# One row per (cell line, drug) pair, in the same order the dataloader serves them.
pseudolabels_df = pd.DataFrame(possible_cl_drug_combinations, columns=["sample_id", "drug_name"])
pseudolabels_df

Unnamed: 0,sample_id,drug_name
0,PR-132fPs,JW-7-24-1
1,PR-132fPs,KIN001-260
2,PR-132fPs,NSC-87877
3,PR-132fPs,GNE-317
4,PR-132fPs,NAVITOCLAX
...,...,...
571442,PR-2AxAKM,SB590885
571443,PR-2AxAKM,STAUROSPORINE
571444,PR-2AxAKM,TW 37
571445,PR-2AxAKM,ULIXERTINIB


In [30]:
# get pseudo labels: sigmoid outputs for every (cell line, drug) pair, in dataloader order
pseudolabels = inference_mtl(mtl_model_trained, cl_aug_train_dataloader)

In [31]:
# Attach the scores to their pairs; rows line up because the dataloader was built with shuffle=False.
pseudolabels_df["pseudolabels"] = pseudolabels.cpu().detach().numpy()

In [32]:
# Inspect the scored (cell line, drug) pairs.
pseudolabels_df

Unnamed: 0,sample_id,drug_name,pseudolabels
0,PR-132fPs,JW-7-24-1,0.319525
1,PR-132fPs,KIN001-260,0.370869
2,PR-132fPs,NSC-87877,0.371434
3,PR-132fPs,GNE-317,0.349483
4,PR-132fPs,NAVITOCLAX,0.260132
...,...,...,...
571442,PR-2AxAKM,SB590885,0.432868
571443,PR-2AxAKM,STAUROSPORINE,0.394447
571444,PR-2AxAKM,TW 37,0.373616
571445,PR-2AxAKM,ULIXERTINIB,0.380857


In [33]:
# Summary statistics of the pseudolabel scores (useful when choosing the confidence thresholds).
pseudolabels_df.describe()

Unnamed: 0,pseudolabels
count,571447.0
mean,0.346632
std,0.072632
min,0.080943
25%,0.298081
50%,0.347908
75%,0.393497
max,0.646188


In [34]:
def convert_binary(prediction, lower_threshold, upper_threshold):
    """
    Map a score to a hard label, abstaining in the uncertain band.

    Returns
    -------
    1  if `prediction >= upper_threshold`,
    0  if `prediction < lower_threshold`,
    -1 otherwise (abstain; this also covers NaN, for which both comparisons are False).
    """
    # The upper threshold is tested first, so it wins if the thresholds overlap.
    if prediction >= upper_threshold:
        return 1
    if prediction < lower_threshold:
        return 0
    return -1

In [35]:
# threshold and select confident samples
# Folds [0, 1] and all other folds use the same (lower=0.1, upper=0.7) cut-offs,
# so a single assignment covers every fold.
# NOTE(review): the largest pseudolabel shown by describe() is ~0.646, below the 0.7 upper
# cut-off, so no pair can receive label 1 with these thresholds — confirm this is intended.
pseudolabels_df["pseudolabels_binary"] = pseudolabels_df["pseudolabels"].apply(convert_binary, args=(0.1, 0.7))

In [36]:
# Class balance among the confident (non-abstained) pseudolabels.
pseudolabels_df[pseudolabels_df.pseudolabels_binary != -1]["pseudolabels_binary"].value_counts()

pseudolabels_binary
0    63
Name: count, dtype: int64

In [37]:
# using data loaders to prevent excessive memory usage
class CustomCombinedDataSetLabelled(TensorDataset):
    """
    Labelled dataset over (sample, drug, recist) rows whose samples may come from
    either the augmented cell-line table or the patient VAE-embedding table.

    Parameters
    ----------
    combined_df : DataFrame with columns "sample_id", "drug_name" and "recist";
        its index is discarded (rows are addressed positionally).
    cl_augmented_df : DataFrame of cell-line features indexed by sample name.
    train_target_inputs_vae : DataFrame of patient features indexed by sample name;
        duplicated index labels are dropped, keeping the first occurrence.
    drug_fp : DataFrame of drug features indexed by drug name; every drug in
        `combined_df` must be present.
    """

    def __init__(self, combined_df, cl_augmented_df, train_target_inputs_vae, drug_fp):
        # NOTE(review): TensorDataset.__init__ is not called, so `self.tensors` is never set;
        # only the methods overridden below are usable.
        self.sample_df = combined_df.reset_index(drop=True)
        self.augmented_cl_df = cl_augmented_df
        self.tcga_vae_df = train_target_inputs_vae[~train_target_inputs_vae.index.duplicated(keep="first")]
        self.drug_fp = drug_fp

    def __getitem__(self, idx):
        """
        Return `(sample_features, drug_features, response)` for row `idx`.

        Raises
        ------
        KeyError if the sample is in neither feature table (or the drug is missing
        from `drug_fp`).
        """
        row = self.sample_df.iloc[idx]
        sample_name = row["sample_id"]
        drug_name = row["drug_name"]
        # The cell-line table takes precedence when a name exists in both tables.
        if sample_name in self.augmented_cl_df.index:
            mut_profile = self.augmented_cl_df.loc[sample_name].values
        elif sample_name in self.tcga_vae_df.index: # using VAE version instead of mutation profiles
            mut_profile = self.tcga_vae_df.loc[sample_name].values
        else:
            # Fail with a descriptive error instead of an UnboundLocalError on `mut_profile`.
            raise KeyError(f"Sample '{sample_name}' not found in the cell line or patient feature tables")
        drug_inp = self.drug_fp.loc[drug_name].values
        response = row["recist"]
        return torch.FloatTensor(mut_profile), torch.FloatTensor(drug_inp), response

    def __len__(self):
        """Number of labelled (sample, drug) rows."""
        return len(self.sample_df)

In [38]:

# Keep only the pairs whose pseudolabel is confident, i.e. not the abstain value -1.
is_confident = pseudolabels_df["pseudolabels_binary"].ne(-1)
confident_pseudolabels_df = pseudolabels_df.loc[is_confident]
confident_pseudolabels_df_idx = confident_pseudolabels_df.index  # row labels shared with possible_cl_drug_combinations_df

# Select the same rows from the combinations frame and attach the hard label as "recist".
in_confident_rows = possible_cl_drug_combinations_df.index.isin(confident_pseudolabels_df_idx)
confident_cl_drug_combinations_df = possible_cl_drug_combinations_df.loc[in_confident_rows].copy()
confident_cl_drug_combinations_df["recist"] = confident_pseudolabels_df["pseudolabels_binary"].tolist()
print("Number of confident cl drug combinations with pseudolabels: ", confident_cl_drug_combinations_df.shape, sep="\n")
print("Pseudo label distribution after majority vote:", confident_cl_drug_combinations_df.recist.value_counts(), sep="\n")

# Stack the pseudo-labelled cell-line rows on top of the labelled patient training rows
# (restricted to the same columns) and serve them through one shuffled loader.
shared_columns = confident_cl_drug_combinations_df.columns
combined_dataset_df = pd.concat(
    [confident_cl_drug_combinations_df, train_target_data_merged[shared_columns]],
    axis=0,
)
combined_dataset = CustomCombinedDataSetLabelled(combined_dataset_df, cl_augmented_df, train_target_inputs_vae, drug_fp)
combined_dataloader = DataLoader(
    combined_dataset,
    batch_size=model_config["drp_batch_size"],
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
)



Number of confident cl drug combinations with pseudolabels: 
(63, 3)
Pseudo label distribution after majority vote:
recist
0    63
Name: count, dtype: int64


In [39]:
class DRP(nn.Module):
    """
    Drug response predictor: embeds the drug input with an MLP, concatenates it with
    the (already embedded) patient input and outputs one logit per sample.

    Parameters
    ----------
    patient_dim : width of the patient input (default 64).
    drug_dim : width of the drug input (default 2048).
    drug_hidden_dim : hidden width of the drug embedder (default 256).
    drug_emb_dim : output width of the drug embedder (default 64).
    predictor_hidden_dim : hidden width of the response head (default 16).

    The defaults create exactly the same layers, in the same order, as the previously
    hard-coded architecture, so existing checkpoints still load.
    """

    def __init__(self, patient_dim=64, drug_dim=2048, drug_hidden_dim=256, drug_emb_dim=64, predictor_hidden_dim=16):
        super().__init__()
        self.drug_embedder = nn.Sequential(
            nn.Linear(drug_dim, drug_hidden_dim),
            nn.ReLU(),
            nn.Linear(drug_hidden_dim, drug_emb_dim),
        )
        # The head consumes the patient input concatenated with the drug embedding.
        self.recist_predictor = nn.Sequential(
            nn.Linear(patient_dim + drug_emb_dim, predictor_hidden_dim),
            nn.ReLU(),
            nn.Linear(predictor_hidden_dim, 1),
        )

    def forward(self, patient_inp, patient_drug):
        """
        Parameters
        ----------
        patient_inp : tensor of shape (batch, patient_dim).
        patient_drug : tensor of shape (batch, drug_dim).

        Returns
        -------
        Tensor of shape (batch, 1) holding raw logits (no sigmoid applied).
        """
        patient_drug_emb = self.drug_embedder(patient_drug)
        patient_cat = torch.cat((patient_inp, patient_drug_emb), dim=1)

        # recist prediction
        recist_pred = self.recist_predictor(patient_cat)

        return recist_pred
        

In [40]:
def inference_drp_model(model, patient_val_dataloader):
    """
    Run the model over a labelled dataloader without tracking gradients.

    Parameters
    ----------
    model : callable module taking `(patient_inp, patient_drug)` and returning logits.
    patient_val_dataloader : iterable of batches `(patient_inp, patient_drug, label)`.

    Returns
    -------
    (y_pred, y_true) : two tensors of shape (num_samples, 1); `y_pred` holds the
    sigmoid of the model outputs, `y_true` the labels, both in dataloader order.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for batch in patient_val_dataloader:
            features = batch[0].to(device)
            drug = batch[1].to(device)
            target = batch[2].to(device)
            predictions.append(torch.sigmoid(model(features, drug)).view(-1, 1))
            targets.append(target.view(-1, 1))
    return torch.cat(predictions), torch.cat(targets)

In [41]:
def train_drp_model(model, train_dataloader, patient_val_dataloader, num_epochs=100, lr=1e-3):
    """
    Train a drug response model with BCE-with-logits loss and Adam, checkpointing the
    best epoch according to the validation score.

    The validation score is always the Pearson correlation between validation labels
    and predicted probabilities, shifted by +1 (range [0, 2]), regardless of which of
    the supported names is configured in `model_config["model_save_criteria"]`.

    Parameters
    ----------
    model : module called as `model(patient_inp, patient_drug)`; must expose a
        `model_name` attribute, which is used in the checkpoint file name.
    train_dataloader : iterable of `(patient_inp, patient_drug, label)` batches.
    patient_val_dataloader : validation loader forwarded to `inference_drp_model`.
    num_epochs : maximum number of epochs.
    lr : initial learning rate; divided by 10 at epochs 20, 30, 40, ...

    Returns
    -------
    None. The best weights are written under `folder_config["model_checkpoint_folder"]`.
    Training stops early after 3 consecutive epochs without a validation improvement.
    """
    # Fail fast: reject an unsupported criterion before any weights are updated.
    if model_config["model_save_criteria"] not in ["val_AUROC", "val_AUPRC", "val_corr"]: # maximise values
        print("Unsupported metric for optimising")
        return

    checkpoint_path = f"{folder_config['model_checkpoint_folder']}/{model.model_name}_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{model_config['sample_id']}.pth"
    criterion = nn.BCEWithLogitsLoss()
    optim = torch.optim.Adam(model.parameters(), lr = lr)
    # -inf (rather than the first epoch's score) so that a NaN first epoch cannot
    # poison every later `>=` comparison and prevent any checkpoint from being saved.
    best_val_score = float("-inf")
    count = 0 # consecutive epochs without a validation improvement
    for i in range(num_epochs):
        model.train()
        train_losses = []
        if i > 10 and i % 10 == 0:
            # NOTE(review): re-creating Adam also discards its moment estimates; kept as-is
            # to preserve the existing training dynamics — confirm whether this is intended.
            lr = lr/10
            optim = torch.optim.Adam(model.parameters(), lr = lr)
        for batch in train_dataloader:
            optim.zero_grad()
            patient_inp = batch[0].to(device)
            patient_drug = batch[1].to(device)
            label = batch[2].to(device)
            y_pred = model(patient_inp, patient_drug).view(-1, 1)
            loss = criterion(y_pred, label.view(-1, 1).to(device, dtype=torch.float32))
            loss.backward()
            optim.step()
            train_losses.append(loss.item())

        y_test_pred, test_y = inference_drp_model(model, patient_val_dataloader)
        patient_corr = pearsonr(test_y.detach().cpu().numpy().reshape(-1), y_test_pred.detach().cpu().numpy().reshape(-1)).statistic + 1 # range in [0, 2]

        print(f"Epoch {i}: Training loss: {np.mean(train_losses)} |  Validation correlation: {patient_corr}")

        # pearsonr yields NaN when either input is constant; a NaN score never counts as an improvement.
        if not np.isnan(patient_corr) and patient_corr >= best_val_score:
            best_val_score = patient_corr
            # save model
            print("Best model")
            torch.save(model.state_dict(), checkpoint_path)
            count = 0 # reset count
        else:
            count += 1 # declining performance on validation data

        if count >= 3:
            print("Converged")
            break

In [42]:
# Validation set served through the same labelled dataset class as the training data.
# NOTE(review): shuffling is not needed for evaluation, but this loader draws from the shared
# generator `g`; left unchanged so the random stream (and hence the run) is not altered.
target_val_vae_dataset = CustomCombinedDataSetLabelled(val_target_data_merged, cl_augmented_df, val_target_inputs_vae, drug_fp)
target_dataloader_val_vae = DataLoader(target_val_vae_dataset, batch_size=model_config["drp_batch_size"], shuffle=True, worker_init_fn = seed_worker, generator = g)

In [43]:
# Number of labelled validation rows.
len(target_val_vae_dataset)

54

In [44]:
# initialise the DRP NN
# model_name is read by train_drp_model when it builds the checkpoint file name
nn_drp = DRP().to(device)
nn_drp.model_name = "DRP_model"

# Train DRP model on the combined (pseudo-labelled cell line + patient) loader,
# selecting the best epoch on the patient validation loader
train_drp_model(nn_drp, combined_dataloader, target_dataloader_val_vae, num_epochs=model_config["drp_epochs"], lr=1e-4)

Epoch 0: Training loss: 0.6728090643882751 |  Validation correlation: 1.0274126707105196
Best model
Epoch 1: Training loss: 0.6786614060401917 |  Validation correlation: 1.0328541187008453
Best model
Epoch 2: Training loss: 0.6786597371101379 |  Validation correlation: 1.0377776252856274
Best model
Epoch 3: Training loss: 0.679039865732193 |  Validation correlation: 1.04215750559677
Best model
Epoch 4: Training loss: 0.6790867745876312 |  Validation correlation: 1.0467989933450679
Best model
Epoch 5: Training loss: 0.6655319631099701 |  Validation correlation: 1.0520437702703345
Best model
Epoch 6: Training loss: 0.6601867973804474 |  Validation correlation: 1.0578614191463471
Best model
Epoch 7: Training loss: 0.6714925765991211 |  Validation correlation: 1.0636468053665586
Best model
Epoch 8: Training loss: 0.6560789942741394 |  Validation correlation: 1.0692175215636899
Best model
Epoch 9: Training loss: 0.6649776101112366 |  Validation correlation: 1.0752125579105527
Best model
Epo

In [45]:
# Rebuild the architecture and restore the best checkpoint (same file name pattern as train_drp_model writes).
nn_drp_trained = DRP().to(device)
nn_drp_trained.model_name = "DRP_model"
# map_location remaps the stored tensors onto the configured device, so the checkpoint
# still loads when it was saved on a GPU that is unavailable here.
nn_drp_trained.load_state_dict(torch.load(f"{folder_config['model_checkpoint_folder']}/{nn_drp_trained.model_name}_{model_config['model_save_criteria']}_{model_config['experiment_id']}_{model_config['experiment_settings']}_fold{model_config['sample_id']}.pth", map_location=device))

<All keys matched successfully>

In [46]:
# Switch the restored model to evaluation mode; eval() returns the module itself, so its repr is displayed.
nn_drp_trained.eval()

DRP(
  (drug_embedder): Sequential(
    (0): Linear(in_features=2048, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
  )
  (recist_predictor): Sequential(
    (0): Linear(in_features=128, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [47]:
# Inspect the patient training inputs (64 vae_feat* columns indexed by sample_id).
train_target_inputs_vae

Unnamed: 0_level_0,vae_feat0,vae_feat1,vae_feat2,vae_feat3,vae_feat4,vae_feat5,vae_feat6,vae_feat7,vae_feat8,vae_feat9,...,vae_feat54,vae_feat55,vae_feat56,vae_feat57,vae_feat58,vae_feat59,vae_feat60,vae_feat61,vae_feat62,vae_feat63
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
TCGA-FD-A6TC,0.827066,0.829374,-0.537645,-0.268865,-0.173853,0.343706,-0.805370,-1.572857,-0.061689,-0.202185,...,1.307518,-0.349245,-1.422783,2.172912,-0.723374,0.348981,0.084173,-1.352109,0.891706,0.161657
TCGA-S9-A6TS,-1.168053,1.027555,-0.579022,-1.026176,-0.102448,-0.453966,0.598107,0.896829,0.488383,1.427711,...,1.247069,-0.229513,-0.194794,-1.079690,0.161465,-0.194728,0.522575,0.228321,-1.479867,0.966528
TCGA-VR-A8EQ,-0.090583,0.591559,-0.455180,-1.390575,-0.111734,-0.278430,0.602076,-1.167248,-0.446833,0.249855,...,0.524843,-0.406802,-1.334923,-0.526341,1.106051,-1.428958,-0.719490,0.416061,-0.597750,-0.981432
s_DS_bkm_034_T,-0.078913,0.721081,0.075001,1.141696,0.283392,-0.559048,-1.048769,-2.030850,0.068298,0.297521,...,-1.453830,-0.796550,-0.059814,0.056333,0.547968,-0.462261,0.684337,-0.810644,0.058771,0.348621
TCGA-YU-A90Q,0.749110,-0.045222,-0.673927,1.151813,-0.113674,-2.515810,1.071847,-1.909407,-1.099816,-0.826508,...,2.060330,-0.276407,0.043520,1.048208,1.495172,-1.268675,0.073049,-1.511822,1.500624,-0.640142
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
s_DS_bkm_013_T,-0.473812,2.009128,0.696600,-0.911182,-0.377678,-0.716873,-1.001580,-0.127627,0.929186,-0.480855,...,0.276599,-0.957140,-0.501739,-1.510167,0.535347,-0.912230,1.088135,0.770876,-0.284230,-1.023949
TCGA-GN-A8LK,-0.208496,-0.657187,-1.991301,1.059301,0.776175,-0.817379,1.917014,-0.867422,0.401307,-0.798327,...,0.099852,0.987325,0.585151,1.185973,-0.998058,0.887481,0.633309,-0.572576,-1.235394,0.429780
TCGA-VS-A8EJ,-2.754368,0.892841,0.794190,0.099784,-1.127343,1.096536,-1.025483,-0.043790,-0.346405,-0.342197,...,0.498418,0.888108,1.523046,0.711301,-0.724520,0.469093,-0.487630,1.294133,-0.204178,0.230496
P-0021780-T01-IM6,0.841921,0.567212,-0.702866,0.371914,1.181742,0.974900,0.973902,-0.757085,0.658399,-1.156903,...,1.859428,-0.009184,1.437440,-0.065952,-0.984532,-1.046734,0.148071,-0.789273,0.245715,0.311866


In [48]:
# target test
# shuffle=False keeps the predictions in the row order of the test dataset
target_test_vae_dataset = CustomCombinedDataSetLabelled(test_target_data_merged, cl_augmented_df, test_target_inputs_vae, drug_fp)
target_dataloader_test_vae = DataLoader(target_test_vae_dataset, batch_size=model_config["drp_batch_size"], shuffle=False)

In [49]:
# Number of labelled test rows.
len(target_test_vae_dataset)

114

In [50]:
# Predicted probabilities and true labels for the test set, from the restored best checkpoint.
y_test_pred, test_y = inference_drp_model(nn_drp_trained, target_dataloader_test_vae)

In [51]:
# Collect test predictions and labels side by side as flat 1-D columns.
res_df = pd.DataFrame({
    "y_pred": y_test_pred.cpu().detach().numpy().reshape(-1),
    "y_true": test_y.cpu().detach().numpy().reshape(-1),
})

In [52]:
# Inspect the test predictions next to the true labels.
res_df

Unnamed: 0,y_pred,y_true
0,0.398731,0
1,0.380655,1
2,0.385387,0
3,0.371898,0
4,0.379821,0
...,...,...
109,0.409117,1
110,0.447311,0
111,0.418613,0
112,0.419422,0


In [53]:
from sklearn.metrics import roc_auc_score, average_precision_score

In [54]:
# ROC AUC of the predicted probabilities against the true test labels.
roc_auc_score(res_df["y_true"], res_df["y_pred"])

0.642361111111111

In [55]:
# Average precision (area under the precision-recall curve) on the same test predictions.
average_precision_score(res_df["y_true"], res_df["y_pred"])

0.6029306670230963

In [56]:
# Persist the test predictions in the checkpoint folder.
# NOTE(review): "val_corr" and "2A_ALL" are hard-coded in this file name, whereas the checkpoint
# name is built from model_config — confirm they match the active configuration.
res_df.to_csv(f"{folder_config['model_checkpoint_folder']}/prediction_patients_val_corr_2A_ALL_fold{fold}.csv")