Uses Morgan fingerprints, VAE repr for cell lines and patients, for ClinVar, GPD and Annovar based annotations.

This does drug specific inference, loads model trained for this dataset.

In [1]:
import sys

sys.path.append("../src/")

In [2]:
import numpy as np
import pandas as pd

import datetime
import logging
import os
import time
import torch

from torch import nn
from torch.nn import functional as F

from functools import cached_property

from torch.nn import Linear, ReLU, Sequential

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score

from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim


from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
    
)

from utils import get_kld_loss, get_zinb_loss

from model import (
    BaseDruidModel,
)

from seaborn import scatterplot

from sklearn.metrics import pairwise_distances

In [3]:
# To avoid randomness in DataLoaders - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
    
g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7f4f14c48870>

In [4]:
sample_id = 2

In [5]:
path_to_saved_model = "/data/ajayago/druid/paper/CelliScience/intermediate/DruID/tcga/mtl/Q2/"

In [6]:
drug_name = "5-FLUOROURACIL"

In [7]:
map_sample_drug_to_wandb_name = {
    ("CISPLATIN", 0): "lyric-totem-28",
    ("CISPLATIN", 1): "driven-sea-3",
    ("CISPLATIN", 2): "ethereal-glade-26",
    ("PACLITAXEL", 0): "gentle-firebrand-25",
    ("PACLITAXEL", 1): "elated-brook-4",
    ("PACLITAXEL", 2): "apricot-rain-12",
    ("5-FLUOROURACIL", 0): "firm-butterfly-26",
    ("5-FLUOROURACIL", 1): "stellar-cherry-21",
    ("5-FLUOROURACIL", 2): "ethereal-wildflower-10"
}

In [8]:
wandb_name = map_sample_drug_to_wandb_name[(drug_name, sample_id)]

### Model Definition

In [9]:
from model import (
    BaseDruidModel,
)

In [10]:
from ffnzinb import ffnzinb
from vae import vae

In [11]:
# from torch_geometric import data as DATA
# from torch_geometric.loader import DataLoader as dl

In [12]:
class CellLineEmbedder(nn.Module):
    @cached_property
    def device(self):
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    
    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        super(CellLineEmbedder, self).__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        # k_list = [1024, 128]
        k_list = [128, 64]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # The below modules are expected to be available in the scope where this module is instialized

        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1 = (
            self.vae_model1.cuda(device=self.device) if self.device.type == "cuda" else self.vae_model1
        )

        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2 = (
            self.vae_model2.cuda(device=self.device) if self.device.type == "cuda" else self.vae_model2
        )

        self.vae_model1_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1_raw_mutation = (
            self.vae_model1_raw_mutation.cuda(device=self.device)
            if self.device.type == "cuda"
            else self.vae_model1_raw_mutation
        )

        self.vae_model2_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2_raw_mutation = (
            self.vae_model2_raw_mutation.cuda(device=self.device)
            if self.device.type == "cuda"
            else self.vae_model2_raw_mutation
        )

    def __str__(self):
        return f"CellLineEmbedder"

    def load_model(self):
#         self.vae_model1.load_state_dict(
#             torch.load(
#                 f"{self.checkpoint_base_path}/unsupervised_vae_model_cell_line_domain.pt",
#                 map_location=str(self.device),
#             )
#         )
        
#         self.vae_model2.load_state_dict(
#             torch.load(
#                 f"{self.checkpoint_base_path}/unsupervised_vae_model_other_domain.pt",
#                 map_location=str(self.device),
#             )
#         )
        
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated.pt",
                map_location=str(self.device),
            )
        )
        
        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated.pt",
                map_location=str(self.device),
            )
        )
        
        
    def forward(self, x):
        # Get cell line representation from annotated encoder
        _, cell_line_emb, _, _ = self.vae_model1_raw_mutation(x)
        return cell_line_emb

In [13]:
class PatientEmbedder(nn.Module):
    @cached_property
    def device(self):
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    
    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        super(PatientEmbedder, self).__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        k_list = [128, 64]
        # k_list = [1024, 128]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # The below modules are expected to be available in the scope where this module is instialized

        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1 = (
            self.vae_model1.cuda(device=self.device) if self.device.type == "cuda" else self.vae_model1
        )

        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2 = (
            self.vae_model2.cuda(device=self.device) if self.device.type == "cuda" else self.vae_model2
        )

        self.vae_model1_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1_raw_mutation = (
            self.vae_model1_raw_mutation.cuda(device=self.device)
            if self.device.type == "cuda"
            else self.vae_model1_raw_mutation
        )

        self.vae_model2_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2_raw_mutation = (
            self.vae_model2_raw_mutation.cuda(device=self.device)
            if self.device.type == "cuda"
            else self.vae_model2_raw_mutation
        )

    def __str__(self):
        return f"PatientEmbedder"

    def load_model(self):
#         self.vae_model1.load_state_dict(
#             torch.load(
#                 f"{self.checkpoint_base_path}/unsupervised_vae_model_cell_line_domain.pt",
#                 map_location=str(self.device),
#             )
#         )
        
#         self.vae_model2.load_state_dict(
#             torch.load(
#                 f"{self.checkpoint_base_path}/unsupervised_vae_model_other_domain.pt",
#                 map_location=str(self.device),
#             )
#         )
        
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated.pt",
                map_location=str(self.device),
            )
        )
        
        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated.pt",
                map_location=str(self.device),
            )
        )
        
        
    def forward(self, x):
        # Get patient representation from annotated encoder
        _, patient_emb, _, _ = self.vae_model2_raw_mutation(x)
        return patient_emb

In [14]:
# drug_names = ['DOCETAXEL', 'GEMCITABINE', 'CISPLATIN', 'PACLITAXEL', '5-FLUOROURACIL', 'CYCLOPHOSPHAMIDE']
drug_names = [drug_name]
uniq_drug_names = np.unique(np.array(drug_names))
drug_names_to_idx_map = dict(zip(uniq_drug_names, range(len(uniq_drug_names))))


In [15]:
uniq_drug_names

array(['5-FLUOROURACIL'], dtype='<U14')

In [16]:
drug_fp = pd.read_csv("../data/processed/drug_morgan_fingerprints.csv", index_col=0)
drug_fp

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
drug_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
JW-7-24-1,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
KIN001-260,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
NSC-87877,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
GNE-317,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
NAVITOCLAX,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
LGH447,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TRASTUZUMAB,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
WNT974,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TRIFLURIDINE,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Creating the datasets

In [17]:
cl_dataset_train = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
cl_dataset_train.y_df

Unnamed: 0,depmap_id,drug_name,auc
0,ACH-000001,5-FLUOROURACIL,0.949220
1,ACH-000002,5-FLUOROURACIL,0.947982
2,ACH-000007,5-FLUOROURACIL,0.854078
3,ACH-000008,5-FLUOROURACIL,0.979834
6,ACH-000020,5-FLUOROURACIL,0.990438
...,...,...,...
3301,ACH-001702,PACLITAXEL,0.687727
3302,ACH-001703,PACLITAXEL,0.422809
3303,ACH-001711,PACLITAXEL,0.599059
3304,ACH-001715,PACLITAXEL,0.799246


In [18]:
cl_train_features = []
cl_train_y = []
for idx, row in cl_dataset_train.y_df.iterrows():
    if row["drug_name"] == drug_name:
        row_inp = []
        row_inp.extend(cl_dataset_train.clinvar_gpd_annovar_annotated.loc[row["depmap_id"]].values)
        row_inp.extend(drug_fp.loc[row["drug_name"]].values)
        row_inp.append(row["auc"])
        cl_train_y.append(row["auc"])
        cl_train_features.append(row_inp)

In [19]:
len(row_inp)

9825

In [20]:
cl_dataset_test = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
cl_dataset_test.y_df

Unnamed: 0,depmap_id,drug_name,auc
0,ACH-000004,5-FLUOROURACIL,0.993615
1,ACH-000006,5-FLUOROURACIL,0.964677
2,ACH-000009,5-FLUOROURACIL,0.819141
3,ACH-000015,5-FLUOROURACIL,0.984787
4,ACH-000019,5-FLUOROURACIL,0.970469
...,...,...,...
823,ACH-001385,PACLITAXEL,0.707845
824,ACH-001401,PACLITAXEL,0.530202
825,ACH-001525,PACLITAXEL,0.972188
826,ACH-001542,PACLITAXEL,0.485053


In [21]:
cl_test_features = []
cl_test_y = []
for idx, row in cl_dataset_test.y_df.iterrows():
    if row["drug_name"] == drug_name:
        row_inp = []
        row_inp.extend(cl_dataset_test.clinvar_gpd_annovar_annotated.loc[row["depmap_id"]].values)
        row_inp.extend(drug_fp.loc[row["drug_name"]].values)
        row_inp.append(row["auc"])
        cl_test_y.append(row["auc"])
        cl_test_features.append(row_inp)

In [22]:
len(row_inp) # 324 gene mutations + 2048 len fingerprint + 1 AUDRC

9825

In [23]:
tcga_dataset_train = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
tcga_dataset_train.tcga_response

Unnamed: 0,submitter_id,drug_name,response
1,TCGA-G2-A2EJ,PACLITAXEL,0
2,TCGA-G2-A2EJ,CISPLATIN,0
3,TCGA-G2-A2EF,CISPLATIN,1
4,TCGA-G2-A2EK,CISPLATIN,0
5,TCGA-G2-A2EL,PACLITAXEL,0
...,...,...,...
621,TCGA-QS-A5YQ,CISPLATIN,1
622,TCGA-QS-A5YQ,PACLITAXEL,1
625,TCGA-2E-A9G8,PACLITAXEL,1
626,TCGA-BG-A0VZ,CISPLATIN,1


In [24]:
tcga_train_features = []
tcga_train_y = []
for idx, row in tcga_dataset_train.tcga_response.iterrows():
    if row["drug_name"] == drug_name:
        row_inp = []
        row_inp.extend(tcga_dataset_train.clinvar_gpd_annovar_annotated.loc[row["submitter_id"]].values)
        row_inp.extend(drug_fp.loc[row["drug_name"]].values)

        row_inp.append(row["response"])
        tcga_train_y.append(row["response"])
        tcga_train_features.append(row_inp)

In [25]:
len(row_inp)

9825

In [26]:
tcga_dataset_test = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
tcga_dataset_test.tcga_response

Unnamed: 0,submitter_id,drug_name,response
0,TCGA-G2-A2EC,CISPLATIN,1
7,TCGA-BT-A2LD,PACLITAXEL,0
11,TCGA-G2-A3IE,GEMCITABINE,1
12,TCGA-G2-A3IE,CISPLATIN,0
13,TCGA-GV-A3JW,GEMCITABINE,1
...,...,...,...
608,TCGA-AJ-A3I9,PACLITAXEL,0
619,TCGA-EY-A54A,PACLITAXEL,1
623,TCGA-E6-A8L9,PACLITAXEL,1
624,TCGA-QS-A8F1,PACLITAXEL,0


In [27]:
tcga_test_features = []
tcga_test_y = []
for idx, row in tcga_dataset_test.tcga_response.iterrows():
    if row["drug_name"] == drug_name:
        row_inp = []
        row_inp.extend(tcga_dataset_test.clinvar_gpd_annovar_annotated.loc[row["submitter_id"]].values)
        row_inp.extend(drug_fp.loc[row["drug_name"]].values)
        row_inp.append(row["response"])
        tcga_test_y.append(row["response"])
        tcga_test_features.append(row_inp)

In [28]:
len(row_inp)

9825

In [29]:
class DruID(nn.Module):
    '''
    Used for training 2 tasks - cell line-drug AUDRC prediction(regression) and patient-drug RECIST prediction(classification)
    300 dimensional input for drugs
    '''
    def __init__(self,single=False):
        super(DruID, self).__init__()
#         self.drug_embedder = GINConvNet(output_dim=8)
        self.drug_embedder = self.fnn(2048, 128, 64, 32) # original
#         self.drug_embedder = self.ffn(2048, 8)
        self.cell_line_embedder = CellLineEmbedder(checkpoint_base_path=f'/data/ajayago/druid/paper_intermediate/model_checkpoints/2B_druid_with_tcga_filtered_drug_sample{sample_id}/')
        # self.cell_line_embedder.load_model()
        self.patient_embedder = PatientEmbedder(checkpoint_base_path=f'/data/ajayago/druid/paper_intermediate/model_checkpoints/2B_druid_with_tcga_filtered_drug_sample{sample_id}/')
        # self.patient_embedder.load_model()

        self.recist_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), ) # original
        self.audrc_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), ) # original

        self.AUDRC_specific = nn.ModuleDict({'embedder': self.cell_line_embedder,
                                              'predictor': self.audrc_predictor})
        self.RECIST_specific = nn.ModuleDict({'embedder': self.patient_embedder,
                                                'predictor': self.recist_predictor})

        self.name = 'DrugTRS - train AUDRC and RECIST together '
        drug_names = ['DOCETAXEL', 'GEMCITABINE', 'CISPLATIN', 'PACLITAXEL', '5-FLUOROURACIL', 'CYCLOPHOSPHAMIDE']
        uniq_drug_names = np.unique(np.array(drug_names))
        self.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
        
        
    def fnn(self, In, hidden1, hidden2, out):
        return nn.Sequential(nn.Linear(In, hidden1), nn.ReLU(), #nn.BatchNorm1d(hidden1),
                             nn.Linear(hidden1, hidden2), nn.ReLU(), #nn.BatchNorm1d(hidden2),
                             nn.Linear(hidden2, out))

    def forward(self,x1,x2): # x1 is Rad51, x2 is cell lines - each row is of the form [mutation 324, drug fp]
        # input is of dim (batch_size, 325)
        # drug input
        patient_drug_input = x1[:, 324*6*4:].to(self.device, torch.float32)
        cl_drug_input = x2[:, 324*6*4:].to(self.device, torch.float32)
    
        # mutation profile
        patient_mut_input = torch.Tensor(x1[:,:324*6*4]).to(self.device, torch.float32)
        cl_mut_input = torch.Tensor(x2[:,:324*6*4]).to(self.device, torch.float32)
        
        # drug embedding

        patient_drug_emb = self.drug_embedder(patient_drug_input)
        cl_drug_emb = self.drug_embedder(cl_drug_input)
        
        # mutation embedding
        patient_mut_emb = self.patient_embedder(patient_mut_input)
        cl_mut_emb = self.cell_line_embedder(cl_mut_input)
        
        # concat and pass through prediction heads
        patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)
        cl_drug_cat_emb = torch.cat((cl_mut_emb, cl_drug_emb), dim=1)
        
        recist_prediction = self.recist_predictor(patient_drug_cat_emb)
        audrc_prediction = self.audrc_predictor(cl_drug_cat_emb)
        
        return recist_prediction, audrc_prediction


In [30]:
from dotmap import DotMap
import yaml
import wandb
with open(f'../notebook/config/config_tcga_sample0.yml', 'r') as f: # placeholder use sample 0 - not used in the model anyway
    args = DotMap(yaml.safe_load(f))
print(args)
np.random.seed(args.seed)

DotMap(model='DruID', batch_size=256, epochs=100, moo='CS', inv_preference=[0.1, 1], seed=0, device=0, valid_interval=100)


In [31]:
seed = args.seed

In [32]:
# lr_main_optim = 1e-5
# lr_cl_optim = 1e-4
# lr_patient_optim = 1e-3
# lr_drug_optim = 1e-4
# args.lr_main_optim = lr_main_optim
# args.lr_cl_optim = lr_cl_optim
# args.lr_patient_optim = lr_patient_optim
# args.lr_drug_optim = lr_drug_optim
# args.epochs = 500

lr_main_optim = 1e-6
lr_cl_optim = 1e-4
lr_patient_optim = 1e-3
lr_drug_optim = 1e-4
args.lr_main_optim = lr_main_optim
args.lr_cl_optim = lr_cl_optim
args.lr_patient_optim = lr_patient_optim
args.lr_drug_optim = lr_drug_optim
args.epochs = 500
args.device = 1

In [33]:
args

DotMap(model='DruID', batch_size=256, epochs=500, moo='CS', inv_preference=[0.1, 1], seed=0, device=1, valid_interval=100, lr_main_optim=1e-06, lr_cl_optim=0.0001, lr_patient_optim=0.001, lr_drug_optim=0.0001, _ipython_display_=DotMap(), _repr_mimebundle_=DotMap())

In [34]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

tasks = ['RECIST_prediction', 'AUDRC_prediction']  # prediction tasks; model consumes in this order; important

# model
model = eval(f'{args.model}()')
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device...')
model = model.to(device)
specific_submodels = {
                      'RECIST_prediction': model.RECIST_specific,
                      'AUDRC_prediction': model.AUDRC_specific
                     }
common_submodel = model.drug_embedder

# optimization related
batch_size = args.batch_size
# optimizer_main = optim.Adam(model.parameters(), lr = lr_main_optim)
optimizer_main = optim.Adam(list(model.audrc_predictor.parameters())+
                            list(model.recist_predictor.parameters())
                            , lr=lr_main_optim)  # , lr=1e-2)=
optimizer_drug = optim.Adam(model.drug_embedder.parameters(), lr = lr_drug_optim)
optimizer_cl = optim.Adam(model.cell_line_embedder.parameters(), lr=lr_cl_optim)
optimizer_patient = optim.Adam(model.patient_embedder.parameters(), lr=lr_patient_optim)
criteria = {
            'RECIST_prediction': nn.BCEWithLogitsLoss(),
            'AUDRC_prediction': nn.MSELoss(),
           }


U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=128, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=128, out_features=64, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
sigma_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=32, out_features=64, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=64, out_features=128, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=128, out_features=7776, bias=True)
  (act-1): Tanh()
)
U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=128, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=128, out_features=64, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
sigma_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=32, out_features=64, bias=T

In [35]:
from torch.utils.data import DataLoader, Dataset

In [36]:
class CustomDataset(Dataset):
    def __init__(self, train_features):
        self.train_features = train_features

    def __len__(self):
        return len(self.train_features)

    def __getitem__(self, idx):
        return torch.Tensor(self.train_features[idx][:-1]), self.train_features[idx][-1]

In [37]:
cl_training_data = CustomDataset(cl_train_features)

In [38]:
cl_train_dataloader = DataLoader(cl_training_data, batch_size=batch_size, shuffle=True, generator=g, worker_init_fn=seed_worker)

In [39]:
cl_train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f4edd316f40>

In [40]:
cl_test_data = CustomDataset(cl_test_features)
cl_test_dataloader = DataLoader(cl_test_data, batch_size=batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)

In [41]:
tcga_training_data = CustomDataset(tcga_train_features)

In [42]:
tcga_train_dataloader = DataLoader(tcga_training_data, batch_size=batch_size, shuffle=True, generator=g, worker_init_fn=seed_worker)

In [43]:
tcga_train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f4edd331a30>

In [44]:
tcga_test_data = CustomDataset(tcga_test_features)
tcga_test_dataloader = DataLoader(tcga_test_data, batch_size=batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)

In [45]:
train_loaders = {
    "RECIST_prediction": tcga_train_dataloader,
    "AUDRC_prediction": cl_train_dataloader
}

In [46]:
test_loaders = {
    "RECIST_prediction": tcga_test_dataloader,
    "AUDRC_prediction": cl_test_dataloader
}

In [47]:
datasets = {
                'RECIST_prediction': ['tcga'],
                'AUDRC_prediction': ['ccle'],
}

In [48]:
# maximum number of iterations considering all datasets
min_iterations = {task: [len(loader) * args.epochs for loader in loaders]
                  for task, loaders in train_loaders.items()}
max_iterations = max([max(iters) for iters in min_iterations.values()])

# with open('models/cl_ids.json', 'r') as f:
#     test_cl_ids = json.load(f)

print(f'# of iterations to run = {max_iterations}')
inv_preference = np.array(args.inv_preference)
preference = 1.0 / inv_preference
preference /= preference.sum()
intra_preference = {task: np.ones(len(datasets[task])) / len(datasets[task]) for task in tasks}
if args.moo == 'EPO':
    epo_ = EPO(3, inv_preference)
    epo_dt = EPO(3, np.array([1.,1.,1.]), eps=0.3)
    epo_dr = EPO(2, np.array([1., 1.]))
    epo_ds = EPO(2, np.array([1., 1.]))
    epos = {task: ep for task, ep in zip(tasks, [epo_dt, epo_dr, epo_ds])}

epochs_completed = {task: {ds: 0 for ds in task_datasets}
                    for task, task_datasets in datasets.items()}
train_iter_loaders = {task: [iter(loader) for loader in loaders]
                      for task, loaders in train_loaders.items()}

# of iterations to run = 1000


In [49]:
model

DruID(
  (drug_embedder): Sequential(
    (0): Linear(in_features=2048, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
  )
  (cell_line_embedder): CellLineEmbedder(
    (vae_model1): vae(
      (mu_layer): Linear(in_features=64, out_features=32, bias=True)
      (sigma_layer): Linear(in_features=64, out_features=32, bias=True)
      (encoder): Sequential(
        (enc-0): Linear(in_features=7776, out_features=128, bias=True)
        (act-0): Tanh()
        (enc-1): Linear(in_features=128, out_features=64, bias=True)
        (act-1): Tanh()
      )
      (decoder): Sequential(
        (-dec-0): Linear(in_features=32, out_features=64, bias=True)
        (-act-0): Tanh()
        (dec-0): Linear(in_features=64, out_features=128, bias=True)
        (act-0): Tanh()
        (dec-1): Linear(in_features=128, out_features=7776, bias=True)
        (act-1): Tanh()


### Prediction

In [50]:
# comment if training or load an existing model
model.load_state_dict(torch.load(f"{path_to_saved_model}/tcga_drug_fp_clinvar_gpd_annovar_annotated_{drug_name}_drug_trs_{wandb_name}_all_drugs.pth"))

<All keys matched successfully>

In [51]:
model.eval()
y_preds = []
y_true = []
for idx, (inp, y) in enumerate(tcga_test_dataloader):
    # drug input
    patient_drug_input = inp[:, 324*6*4:].to(device, torch.float32)
    # mutation profile
    patient_mut_input = inp[:, :324*6*4].to(device, torch.float32)

    # drug embedding
    patient_drug_emb = model.drug_embedder(patient_drug_input)

    # mutation embedding
    patient_mut_emb = model.patient_embedder(patient_mut_input)

    # concat and pass through prediction heads
    patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)

    recist_prediction = model.recist_predictor(patient_drug_cat_emb)
    y_preds.extend(list(recist_prediction.flatten().detach().cpu().numpy()))
    y_true.extend(y.detach().cpu().numpy())

### Metrics

In [52]:
from scipy import stats
from numpy import argmax
from sklearn.metrics import roc_curve

In [53]:
len(y_preds)

24

In [54]:
y_true = tcga_dataset_test.tcga_response[tcga_dataset_test.tcga_response.drug_name == drug_name].reset_index(drop=True)
y_pred = y_true.copy()
y_pred["response"] = y_preds
y_pred

Unnamed: 0,submitter_id,drug_name,response
0,TCGA-GM-A2DH,5-FLUOROURACIL,0.494099
1,TCGA-GM-A2DB,5-FLUOROURACIL,0.565825
2,TCGA-E9-A3X8,5-FLUOROURACIL,0.481947
3,TCGA-GM-A4E0,5-FLUOROURACIL,0.654374
4,TCGA-4H-AAAK,5-FLUOROURACIL,0.423122
5,TCGA-VS-A8EJ,5-FLUOROURACIL,0.628808
6,TCGA-AZ-6600,5-FLUOROURACIL,0.25235
7,TCGA-AA-3692,5-FLUOROURACIL,0.395591
8,TCGA-AA-3860,5-FLUOROURACIL,0.510079
9,TCGA-AA-3976,5-FLUOROURACIL,0.663701


In [55]:
y_true

Unnamed: 0,submitter_id,drug_name,response
0,TCGA-GM-A2DH,5-FLUOROURACIL,1
1,TCGA-GM-A2DB,5-FLUOROURACIL,1
2,TCGA-E9-A3X8,5-FLUOROURACIL,1
3,TCGA-GM-A4E0,5-FLUOROURACIL,1
4,TCGA-4H-AAAK,5-FLUOROURACIL,1
5,TCGA-VS-A8EJ,5-FLUOROURACIL,0
6,TCGA-AZ-6600,5-FLUOROURACIL,0
7,TCGA-AA-3692,5-FLUOROURACIL,0
8,TCGA-AA-3860,5-FLUOROURACIL,1
9,TCGA-AA-3976,5-FLUOROURACIL,1


In [56]:

y_pred_pivotted = y_pred.pivot_table(
                "response", "submitter_id", "drug_name"
            )
y_pred_pivotted = y_pred_pivotted.fillna(0) # in case there are NaNs
dict_idx_drug = pd.DataFrame(y_pred_pivotted.columns).to_dict()["drug_name"]
dict_id_drug = {}

for patient_id, predictions in y_pred_pivotted.iterrows():

    cur_pred_scores = predictions.values
    cur_recom_drug_idx = np.argsort(cur_pred_scores)[:-11:-1]
    #
    dict_recom_drug = {}
    for idx, cur_idx in enumerate(cur_recom_drug_idx):
        dict_recom_drug[
            dict_idx_drug[cur_idx]
        ] = f"{cur_pred_scores[cur_idx]} ({idx+1})"
    #
    dict_id_drug[patient_id] = dict_recom_drug

predictions_display_tcga = pd.DataFrame.from_dict(dict_id_drug)

na_mask = y_pred.response.isna()
if na_mask.sum():
    print(
        f"Found {na_mask.sum()} rows with invalid response values"
    )
    y_pred = y_pred[~na_mask]
    y_true = y_true.loc[~(na_mask.values)]
na_mask = y_true.response.isna()
y_true = y_true[~na_mask]
y_pred = y_pred[~na_mask]
print(y_pred.shape)
y_pred.head()
y_combined = y_pred.merge(y_true, on=["submitter_id", "drug_name"])

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score, f1_score, accuracy_score, precision_score, recall_score

drugs_with_enough_support = [drug_name]

for drug_name in drugs_with_enough_support:
    try:
        roc = roc_auc_score(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
            average="micro",
        )
        aupr = average_precision_score(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
            average="micro",
        )
        # Choosing the right threshold for F1, accuracy and precision calculation from ref: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
        fpr, tpr, thresholds = roc_curve(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
        )
        J = tpr - fpr
        ix = argmax(J)
        best_thresh = thresholds[ix]
        
        f1 = f1_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        acc_score = accuracy_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        prec_score = precision_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        rec_score = recall_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        print(f"AUROC for {drug_name}: {roc}")
        print(f"AUPR for {drug_name}: {aupr}")
        print(f"F1 for {drug_name}: {f1}")
        print(f"Accuracy Score for {drug_name}: {acc_score}")
        print(f"Precision Score for {drug_name}: {prec_score}")
        print(f"Recall Score for {drug_name}: {rec_score}")
    except Exception as e:
        print(f"Error processing {drug_name} - {e}")



(24, 3)
AUROC for 5-FLUOROURACIL: 0.7473684210526316
AUPR for 5-FLUOROURACIL: 0.9093652324698331
F1 for 5-FLUOROURACIL: 0.8235294117647058
Accuracy Score for 5-FLUOROURACIL: 0.75
Precision Score for 5-FLUOROURACIL: 0.9333333333333333
Recall Score for 5-FLUOROURACIL: 0.7368421052631579
