Runs inference with a pretrained DruID model, using Morgan fingerprints for drugs and VAE representations of cell lines and patients built from ClinVar-, GPD- and Annovar-based annotations.

### Load libraries

In [None]:
import sys

sys.path.append("../src/")

In [None]:
import numpy as np
import pandas as pd

import torch

from torch import nn
from torch.nn import functional as F

from functools import cached_property

from torch.nn import Linear, ReLU, Sequential

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim

from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
)

from utils import get_kld_loss, get_zinb_loss

from seaborn import scatterplot

from sklearn.metrics import pairwise_distances

In [None]:
# To avoid randomness in DataLoaders - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from the current torch seed.

    Meant to be passed as ``worker_init_fn`` to a ``DataLoader`` (see the
    PyTorch randomness notes linked above).

    Args:
        worker_id: Worker index supplied by the ``DataLoader``; not used.
    """
    # Imported locally: the notebook binds NumPy only as ``np`` and never
    # imports ``random``, so the bare module names were undefined here.
    import random

    import numpy

    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)


# Generator with a fixed seed, handed to the DataLoader further below.
g = torch.Generator()
g.manual_seed(0)

In [None]:
# Sample index; interpolated into the checkpoint and config file names and
# passed to the dataset constructor below.
sample_id = 0

### Model Definition

In [None]:
from vae import vae

In [None]:
class CellLineEmbedder(nn.Module):
    """Embed cell-line mutation profiles with a VAE.

    Four ``vae`` networks with the same architecture are created. Only the two
    ``*_raw_mutation`` networks receive weights in :meth:`load_model`, and
    :meth:`forward` uses ``vae_model1_raw_mutation`` exclusively.
    """

    @cached_property
    def device(self):
        """Target device: ``cuda:1`` when CUDA is available, otherwise CPU."""
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        """Create the VAE networks and place them on :attr:`device`.

        Args:
            checkpoint_base_path: Directory that holds the VAE checkpoints
                read by :meth:`load_model`.
        """
        super().__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        k_list = [128, 64]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # ``vae`` is expected to be available in the scope where this module is initialized.
        # ``.to(self.device)`` is used instead of ``.cuda()`` so the networks
        # land on the same device that ``device`` reports and that
        # ``load_model`` passes as ``map_location``; ``.cuda()`` would select
        # the default GPU instead.
        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real).to(
            self.device
        )
        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real).to(
            self.device
        )
        self.vae_model1_raw_mutation = vae(
            input_dim_vae, k_list, actf_list, is_real
        ).to(self.device)
        self.vae_model2_raw_mutation = vae(
            input_dim_vae, k_list, actf_list, is_real
        ).to(self.device)

    def __str__(self):
        return "CellLineEmbedder"

    def load_model(self):
        """Load pretrained weights into the two ``*_raw_mutation`` VAEs.

        The checkpoint file names depend on the module-level ``sample_id``.
        """
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )

        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )

    def forward(self, x):
        """Return the cell-line embedding for ``x``.

        The embedding is the second of the four values returned by
        ``vae_model1_raw_mutation``.
        """
        # Get cell line representation from annotated encoder
        _, cell_line_emb, _, _ = self.vae_model1_raw_mutation(x)
        return cell_line_emb

In [None]:
class PatientEmbedder(nn.Module):
    """Embed patient mutation profiles with a VAE.

    Four ``vae`` networks with the same architecture are created. Only the two
    ``*_raw_mutation`` networks receive weights in :meth:`load_model`, and
    :meth:`forward` uses ``vae_model2_raw_mutation`` exclusively.
    """

    @cached_property
    def device(self):
        """Target device: ``cuda:1`` when CUDA is available, otherwise CPU."""
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        """Create the VAE networks and place them on :attr:`device`.

        Args:
            checkpoint_base_path: Directory that holds the VAE checkpoints
                read by :meth:`load_model`.
        """
        super().__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        k_list = [128, 64]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # ``vae`` is expected to be available in the scope where this module is initialized.
        # ``.to(self.device)`` is used instead of ``.cuda()`` so the networks
        # land on the same device that ``device`` reports and that
        # ``load_model`` passes as ``map_location``; ``.cuda()`` would select
        # the default GPU instead.
        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real).to(
            self.device
        )
        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real).to(
            self.device
        )
        self.vae_model1_raw_mutation = vae(
            input_dim_vae, k_list, actf_list, is_real
        ).to(self.device)
        self.vae_model2_raw_mutation = vae(
            input_dim_vae, k_list, actf_list, is_real
        ).to(self.device)

    def __str__(self):
        return "PatientEmbedder"

    def load_model(self):
        """Load pretrained weights into the two ``*_raw_mutation`` VAEs.

        The checkpoint file names depend on the module-level ``sample_id``.
        """
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )

        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )

    def forward(self, x):
        """Return the patient embedding for ``x``.

        The embedding is the second of the four values returned by
        ``vae_model2_raw_mutation``.
        """
        # Get patient representation from annotated encoder
        _, patient_emb, _, _ = self.vae_model2_raw_mutation(x)
        return patient_emb

In [None]:
# Drugs handled by this notebook (a single-drug run would use ``[drug_name]``).
drug_names = ["CISPLATIN", "PACLITAXEL"]
# Sorted, de-duplicated drug names as a NumPy array.
uniq_drug_names = np.unique(np.asarray(drug_names))
# Position of every drug inside ``uniq_drug_names``.
drug_names_to_idx_map = {
    name: position for position, name in enumerate(uniq_drug_names)
}

In [None]:
# Drug fingerprint table; the first CSV column becomes the index, and rows are
# later looked up by drug name.
drug_fp = pd.read_csv("../data/processed/drug_morgan_fingerprints.csv", index_col=0)
# Show the table in the cell output.
drug_fp

### Creating the datasets

In [None]:
# Test-split TCGA dataset for the selected sample; its ``tcga_response`` and
# ``clinvar_gpd_annovar_annotated`` tables are read by the cells below.
tcga_dataset_test = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
# Show the response table in the cell output.
tcga_dataset_test.tcga_response

In [None]:
# Assemble one flat row per entry of the TCGA response table:
# [annotated mutation features..., drug fingerprint..., response label].
tcga_test_features = []
tcga_test_y = []
for idx, row in tcga_dataset_test.tcga_response.iterrows():
    patient_annotations = tcga_dataset_test.clinvar_gpd_annovar_annotated.loc[
        row["submitter_id"]
    ].values
    fingerprint = drug_fp.loc[row["drug_name"]].values
    # ``row_inp`` keeps its name: a later cell inspects the last row built.
    row_inp = [*patient_annotations, *fingerprint, row["response"]]
    tcga_test_features.append(row_inp)
    tcga_test_y.append(row["response"])

In [None]:
# Length of the last feature row built above (features plus the trailing label).
len(row_inp)

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
# load basic set of params
from dotmap import DotMap
import yaml
import wandb
# The per-sample YAML config is wrapped in a DotMap so that its entries can be
# read as attributes (``args.seed``, ``args.batch_size``, ...).
with open(f'./config/config_tcga_sample{sample_id}.yml', 'r') as f:
    args = DotMap(yaml.safe_load(f))
print(args)
# Seed NumPy's global RNG from the config.
np.random.seed(args.seed)

In [None]:
# Seed from the config; applied to torch further below.
seed = args.seed

In [None]:
# finetuned params
lr_main_optim = 1e-5
lr_cl_optim = 1e-4
lr_patient_optim = 1e-3
lr_drug_optim = 1e-4
args.lr_main_optim = lr_main_optim
args.lr_cl_optim = lr_cl_optim
args.lr_patient_optim = lr_patient_optim
args.lr_drug_optim = lr_drug_optim
args.device = 1
args.epochs = 50
batch_size = args.batch_size

In [None]:
# Show the final set of arguments in the cell output.
args

In [None]:
class CustomDataset(Dataset):
    """Dataset over pre-assembled rows whose last element is the label."""

    def __init__(self, train_features):
        """Store the rows; each row is ``[feature values..., label]``."""
        self.train_features = train_features

    def __len__(self):
        """Return the number of stored rows."""
        return len(self.train_features)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for row ``idx``.

        The features (everything but the last element) are converted to a
        ``torch.Tensor``; the label is returned as stored.
        """
        sample = self.train_features[idx]
        features, label = sample[:-1], sample[-1]
        return torch.Tensor(features), label

In [None]:
# Wrap the assembled rows in a Dataset and iterate them in their original order
# (``shuffle=False``), so predictions can later be written back row by row.
tcga_test_data = CustomDataset(tcga_test_features)
tcga_test_dataloader = DataLoader(tcga_test_data, batch_size=batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)

In [None]:
# Number of test rows.
len(tcga_test_features)

In [None]:
# Test DataLoader per prediction task; only the RECIST task has one here.
test_loaders = {
    "RECIST_prediction": tcga_test_dataloader,
}

In [None]:
# Dataset names associated with each prediction task.
datasets = {
                'RECIST_prediction': ['tcga'],
                'AUDRC_prediction': ['ccle'],
}

### DruID Model

In [None]:
class DruID(nn.Module):
    """Two-task model sharing one drug embedder.

    Used for training 2 tasks - cell line-drug AUDRC prediction (regression)
    and patient-drug RECIST prediction (classification).

    Every input row is laid out as
    ``[mutation features (324*6*4 = 7776 values), drug fingerprint (2048 values)]``.
    """

    # Number of leading mutation-feature columns in every input row.
    _MUT_DIM = 324 * 6 * 4

    def __init__(self, single=False):
        """Build the drug embedder, both domain embedders and both heads.

        Args:
            single: Accepted for backward compatibility; not used.
        """
        super().__init__()
        # drug embedder network
        self.drug_embedder = self.fnn(2048, 128, 64, 32)
        # cell line embedder network
        self.cell_line_embedder = CellLineEmbedder(checkpoint_base_path='../data/model_checkpoints/')
        # self.cell_line_embedder.load_model() # load pretrained VAE model
        # patient embedder network
        self.patient_embedder = PatientEmbedder(checkpoint_base_path='../data/model_checkpoints/')
        # self.patient_embedder.load_model() # load pretrained VAE model
        # prediction heads
        self.recist_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), )  # takes as input concatenated representation of patient and drug
        self.audrc_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), )  # takes as input concatenated representation of cell line and drug

        # Task-specific groupings of the sub-modules defined above.
        self.AUDRC_specific = nn.ModuleDict({'embedder': self.cell_line_embedder,
                                              'predictor': self.audrc_predictor})
        self.RECIST_specific = nn.ModuleDict({'embedder': self.patient_embedder,
                                                'predictor': self.recist_predictor})

        self.name = 'DruID'
        self.device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

    def fnn(self, In, hidden1, hidden2, out):
        """Return a three-layer MLP ``In -> hidden1 -> hidden2 -> out`` with ReLU in between."""
        return nn.Sequential(nn.Linear(In, hidden1), nn.ReLU(),
                             nn.Linear(hidden1, hidden2), nn.ReLU(),
                             nn.Linear(hidden2, out))

    def forward(self, x1, x2):
        """Run both task branches.

        Args:
            x1: Patient batch, tensor of shape ``(batch_size, 7776 + 2048)``.
            x2: Cell-line batch with the same column layout.

        Returns:
            Tuple ``(recist_prediction, audrc_prediction)``, the raw outputs
            of the two prediction heads.
        """
        # Inputs follow the device the weights actually live on, so the model
        # keeps working after ``model.to(...)`` regardless of the hard-coded
        # ``self.device``.
        device = next(self.parameters()).device
        mut_dim = self._MUT_DIM

        # drug input
        patient_drug_input = x1[:, mut_dim:].to(device, torch.float32)
        cl_drug_input = x2[:, mut_dim:].to(device, torch.float32)

        # mutation profile
        patient_mut_input = x1[:, :mut_dim].to(device, torch.float32)
        cl_mut_input = x2[:, :mut_dim].to(device, torch.float32)

        # drug embedding
        patient_drug_emb = self.drug_embedder(patient_drug_input)
        cl_drug_emb = self.drug_embedder(cl_drug_input)

        # mutation embedding
        patient_mut_emb = self.patient_embedder(patient_mut_input)
        cl_mut_emb = self.cell_line_embedder(cl_mut_input)

        # concat and pass through prediction heads
        patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)
        cl_drug_cat_emb = torch.cat((cl_mut_emb, cl_drug_emb), dim=1)

        # prediction heads
        recist_prediction = self.recist_predictor(patient_drug_cat_emb)
        audrc_prediction = self.audrc_predictor(cl_drug_cat_emb)

        return recist_prediction, audrc_prediction


In [None]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

tasks = ['RECIST_prediction', 'AUDRC_prediction']  # prediction tasks; model consumes in this order; important

# model related
# The class named by ``args.model`` is resolved through an explicit registry
# instead of ``eval(f'{args.model}()')``, so a config value can never be
# executed as arbitrary code. Register further model classes here if needed.
model_registry = {'DruID': DruID}
model_name = args.model
if not isinstance(model_name, str) or model_name not in model_registry:
    raise ValueError(
        f'Unknown model {model_name!r}; expected one of {sorted(model_registry)}'
    )
model = model_registry[model_name]()
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device...')
model = model.to(device)

## Pretrained DruID model related:
# Uncomment below line if loading pretrained model weights for DruID (Eg: in TCGA, we load pretrained weights from IMAC-OV DruID model)
# model.load_state_dict(torch.load(f"/data/ajayago/druid/paper_intermediate_pretrained/model_checkpoints/druid_MTL_raw_mutations_sample{sample_id}/rad51_drug_fp_clinvar_gpd_annovar_annotated_drug_trs_{wand_run_2_load}_all_drugs.pth"))
# Uncomment below lines if freeze some layers like drug, cell line and patient embedder etc from pretrained model
# for param in model.drug_embedder.parameters():
#     param.requires_grad = False
# for param in model.cell_line_embedder.parameters():
#     param.requires_grad = False
# for param in model.patient_embedder.parameters():
#     param.requires_grad = False

# Task-specific sub-models (embedder + predictor) and the shared drug embedder.
specific_submodels = {
                      'RECIST_prediction': model.RECIST_specific,
                      'AUDRC_prediction': model.AUDRC_specific
                     }
common_submodel = model.drug_embedder


### Inference

In [None]:
# comment if training
# ``map_location`` remaps the stored tensors onto the device selected above, so
# a checkpoint saved on another GPU also loads on this machine (or on CPU).
model.load_state_dict(torch.load("../data/model_checkpoints/DruID.pth", map_location=device))

In [None]:
# Show the module structure in the cell output.
model

In [None]:
model.eval()
y_preds = []
# Number of leading mutation-feature columns; the rest is the drug fingerprint.
mutation_dim = 324 * 6 * 4
# Pure inference: disabling autograd avoids building graphs that are never
# used and leaves the predicted values unchanged.
with torch.no_grad():
    for inp, y in tcga_test_dataloader:
        # drug input
        patient_drug_input = inp[:, mutation_dim:].to(device, torch.float32)
        # mutation profile
        patient_mut_input = inp[:, :mutation_dim].to(device, torch.float32)

        # drug embedding
        patient_drug_emb = model.drug_embedder(patient_drug_input)

        # mutation embedding
        patient_mut_emb = model.patient_embedder(patient_mut_input)

        # concat and pass through prediction heads
        patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)

        recist_prediction = model.recist_predictor(patient_drug_cat_emb)
        # One raw score per input row, appended in DataLoader order.
        y_preds.extend(recist_prediction.flatten().cpu().numpy())

### Metrics

In [None]:
from scipy import stats
from numpy import argmax
from sklearn.metrics import roc_curve

In [None]:
# Number of predictions collected; it has to equal the number of rows of the
# response table for the assignment in the next cell to succeed.
len(y_preds)

In [None]:
# Ground truth: only the CISPLATIN rows of the response table.
y_true = tcga_dataset_test.tcga_response[tcga_dataset_test.tcga_response.drug_name == "CISPLATIN"]
# Predictions: a copy of the full response table whose ``response`` column is
# overwritten with the model scores (same row order as the DataLoader).
y_pred = tcga_dataset_test.tcga_response.copy()
y_pred["response"] = y_preds
y_pred

In [None]:

# Patient x drug matrix of predicted scores (pivot_table's default aggregation
# applies when a patient/drug pair occurs more than once).
y_pred_pivotted = y_pred.pivot_table(
    values="response", index="submitter_id", columns="drug_name"
).fillna(0)  # in case there are NaNs
# Column position -> drug name.
dict_idx_drug = dict(enumerate(y_pred_pivotted.columns))
dict_id_drug = {}

for patient_id, predictions in y_pred_pivotted.iterrows():
    cur_pred_scores = predictions.values
    # Column positions of the (up to) ten highest scores, best first.
    cur_recom_drug_idx = np.argsort(cur_pred_scores)[::-1][:10]
    # Drug name -> "score (rank)" with ranks starting at 1.
    dict_id_drug[patient_id] = {
        dict_idx_drug[cur_idx]: f"{cur_pred_scores[cur_idx]} ({rank})"
        for rank, cur_idx in enumerate(cur_recom_drug_idx, start=1)
    }

# One column per patient, one row per recommended drug.
predictions_display_tcga = pd.DataFrame.from_dict(dict_id_drug)

# ``y_true`` holds only the CISPLATIN rows of the table ``y_pred`` was copied
# from, so the two frames share index labels but not necessarily their length.
# Rows are therefore matched through index labels; applying a boolean mask
# built on one frame to the other only works when both contain the same rows.
# NOTE(review): assumes the index labels of the response table are unique.
na_mask = y_pred.response.isna()
if na_mask.any():
    print(
        f"Found {na_mask.sum()} rows with invalid response values"
    )
    y_pred = y_pred[~na_mask]
    # Keep only ground-truth rows that still have a prediction.
    y_true = y_true[y_true.index.isin(y_pred.index)]
# Drop rows without a ground-truth response from both frames.
na_mask = y_true.response.isna()
y_pred = y_pred[~y_pred.index.isin(y_true.index[na_mask])]
y_true = y_true[~na_mask]
print(y_pred.shape)
# After the merge ``response_x`` is the prediction and ``response_y`` the truth.
y_combined = y_pred.merge(y_true, on=["submitter_id", "drug_name"])

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score, f1_score, accuracy_score, precision_score, recall_score

drugs_with_enough_support = ["CISPLATIN"]

for drug_name in drugs_with_enough_support:
    # Best effort per drug: a failure is reported and the loop moves on.
    try:
        # Ground truth and predicted scores for this drug.
        # NOTE(review): relies on ``y_true`` and ``y_pred`` listing this drug's
        # rows in the same order.
        drug_true = y_true[y_true.drug_name == drug_name].response.values
        drug_scores = y_pred[y_pred.drug_name == drug_name].response.values

        roc = roc_auc_score(drug_true, drug_scores, average="micro")
        aupr = average_precision_score(drug_true, drug_scores, average="micro")
        # Choosing the right threshold for F1, accuracy and precision calculation from ref: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
        fpr, tpr, thresholds = roc_curve(drug_true, drug_scores)
        J = tpr - fpr
        ix = argmax(J)
        best_thresh = thresholds[ix]

        # ``roc_curve`` reports fpr/tpr for predictions ``score >= threshold``,
        # so the same comparison is used here; a strict ``>`` would leave out
        # the sample sitting exactly on the chosen threshold.
        drug_labels = (drug_scores >= best_thresh).astype(int)

        f1 = f1_score(drug_true, drug_labels)
        acc_score = accuracy_score(drug_true, drug_labels)
        prec_score = precision_score(drug_true, drug_labels)
        rec_score = recall_score(drug_true, drug_labels)
        spearman_stats = stats.spearmanr(drug_true, drug_scores)

        # Predicted scores split by true response (0 vs 1) for the
        # Mann-Whitney U test.
        drug_combined = y_combined[y_combined.drug_name == drug_name]
        scores_response_0 = drug_combined[drug_combined.response_y == 0].response_x.values
        scores_response_1 = drug_combined[drug_combined.response_y == 1].response_x.values
        mw_stats = stats.mannwhitneyu(
            scores_response_0,
            scores_response_1,
            alternative="greater",
        )
        # Number of (response 0, response 1) pairs; normalises the U statistic.
        denominator = len(scores_response_0) * len(scores_response_1)

        print(f"AUROC for {drug_name}: {roc}")
        print(f"AUPR for {drug_name}: {aupr}")
        print(f"F1 for {drug_name}: {f1}")
        print(f"Accuracy Score for {drug_name}: {acc_score}")
        print(f"Precision Score for {drug_name}: {prec_score}")
        print(f"Recall Score for {drug_name}: {rec_score}")
        print(
            f"Spearman for {drug_name}: {round(spearman_stats.correlation, 4)} (p-val: {round(spearman_stats.pvalue, 4)})"
        )
        print(
            f"Mann-Whitney for {drug_name}: {round(mw_stats.statistic/denominator, 4)} (p-val: {round(mw_stats.pvalue, 4)})"
        )
    except Exception as e:
        print(f"Error processing {drug_name} - {e}")

