Uses Morgan fingerprints, VAE repr for cell lines and patients, for ClinVar, GPD and Annovar based annotations, to perform inference from a pretrained DruID model.

### Load libraries

In [1]:
import sys

sys.path.append("../src/")

In [2]:
import numpy as np
import pandas as pd

import torch

from torch import nn
from torch.nn import functional as F

from functools import cached_property

from torch.nn import Linear, ReLU, Sequential

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim

from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
)

from utils import get_kld_loss, get_zinb_loss

from seaborn import scatterplot

from sklearn.metrics import pairwise_distances

In [3]:
# To avoid randomness in DataLoaders - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
    
g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7f29867cf530>

In [4]:
sample_id = 0

In [5]:
# wand_run_2_load = "super-thunder-9" # name of the wandb run to load (pretrained model)
# drug_name = "CISPLATIN" # drug name to filter by

### Model Definition

In [6]:
from vae import vae

In [7]:
class CellLineEmbedder(nn.Module):
    @cached_property
    def device(self):
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    
    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        super(CellLineEmbedder, self).__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        k_list = [128, 64]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # The below modules are expected to be available in the scope where this module is instialized

        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1 = (
            self.vae_model1.cuda() if self.device.type == "cuda" else self.vae_model1
        )

        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2 = (
            self.vae_model2.cuda() if self.device.type == "cuda" else self.vae_model2
        )

        self.vae_model1_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1_raw_mutation = (
            self.vae_model1_raw_mutation.cuda()
            if self.device.type == "cuda"
            else self.vae_model1_raw_mutation
        )

        self.vae_model2_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2_raw_mutation = (
            self.vae_model2_raw_mutation.cuda()
            if self.device.type == "cuda"
            else self.vae_model2_raw_mutation
        )

    def __str__(self):
        return f"CellLineEmbedder"

    def load_model(self):
        
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )
        
        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )
        
        
    def forward(self, x):
        # Get cell line representation from annotated encoder
        _, cell_line_emb, _, _ = self.vae_model1_raw_mutation(x)
        return cell_line_emb

In [8]:
class PatientEmbedder(nn.Module):
    @cached_property
    def device(self):
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    
    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        super(PatientEmbedder, self).__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 6 * 4
        k_list = [128, 64]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # The below modules are expected to be available in the scope where this module is instialized

        self.vae_model1 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1 = (
            self.vae_model1.cuda() if self.device.type == "cuda" else self.vae_model1
        )

        self.vae_model2 = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2 = (
            self.vae_model2.cuda() if self.device.type == "cuda" else self.vae_model2
        )

        self.vae_model1_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model1_raw_mutation = (
            self.vae_model1_raw_mutation.cuda()
            if self.device.type == "cuda"
            else self.vae_model1_raw_mutation
        )

        self.vae_model2_raw_mutation = vae(input_dim_vae, k_list, actf_list, is_real)
        self.vae_model2_raw_mutation = (
            self.vae_model2_raw_mutation.cuda()
            if self.device.type == "cuda"
            else self.vae_model2_raw_mutation
        )

    def __str__(self):
        return f"PatientEmbedder"

    def load_model(self):
        
        self.vae_model1_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_cell_line_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )
        
        self.vae_model2_raw_mutation.load_state_dict(
            torch.load(
                f"{self.checkpoint_base_path}/druid_with_tcga_filtered_drug_sample{sample_id}_unsupervised_vae_model_other_domain_clinvar_gpd_annovar_annotated_v4.pt",
                map_location=str(self.device),
            )
        )
        
        
    def forward(self, x):
        # Get patient representation from annotated encoder
        _, patient_emb, _, _ = self.vae_model2_raw_mutation(x)
        return patient_emb

In [9]:
# drug_names = [drug_name]
drug_names = ['5-FLUOROURACIL', 'CYCLOPHOSPHAMIDE', 'DOCETAXEL', 'GEMCITABINE', 'CISPLATIN', 'PACLITAXEL']
uniq_drug_names = np.unique(np.array(drug_names))
drug_names_to_idx_map = dict(zip(uniq_drug_names, range(len(uniq_drug_names))))

In [10]:
drug_fp = pd.read_csv("../data/processed/drug_morgan_fingerprints.csv", index_col=0)
drug_fp

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
drug_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
JW-7-24-1,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
KIN001-260,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
NSC-87877,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
GNE-317,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
NAVITOCLAX,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
LGH447,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TRASTUZUMAB,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
WNT974,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
TRIFLURIDINE,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Creating the datasets

In [11]:
# cl_dataset_train = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
# cl_dataset_train.y_df

In [12]:
# cl_train_features = []
# cl_train_y = []
# for idx, row in cl_dataset_train.y_df.iterrows():
#     if row["drug_name"] == drug_name:
#         row_inp = []
#         row_inp.extend(cl_dataset_train.clinvar_gpd_annovar_annotated.loc[row["depmap_id"]].values)
#         row_inp.extend(drug_fp.loc[row["drug_name"]].values)
#         row_inp.append(row["auc"])
#         cl_train_y.append(row["auc"])
#         cl_train_features.append(row_inp)

In [13]:
# len(row_inp)

In [14]:
# cl_dataset_test = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
# cl_dataset_test.y_df

In [15]:
# cl_test_features = []
# cl_test_y = []
# for idx, row in cl_dataset_test.y_df.iterrows():
#     if row["drug_name"] == drug_name:
#         row_inp = []
#         row_inp.extend(cl_dataset_test.clinvar_gpd_annovar_annotated.loc[row["depmap_id"]].values)
#         row_inp.extend(drug_fp.loc[row["drug_name"]].values)
#         row_inp.append(row["auc"])
#         cl_test_y.append(row["auc"])
#         cl_test_features.append(row_inp)

In [16]:
# len(row_inp) # 324 gene mutations + 2048 len fingerprint + 1 AUDRC

In [17]:
# tcga_dataset_train = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
# tcga_dataset_train.tcga_response

In [18]:
# tcga_train_features = []
# tcga_train_y = []
# for idx, row in tcga_dataset_train.tcga_response.iterrows():
#     if row["drug_name"] == drug_name:
#         row_inp = []
#         row_inp.extend(tcga_dataset_train.clinvar_gpd_annovar_annotated.loc[row["submitter_id"]].values)
#         row_inp.extend(drug_fp.loc[row["drug_name"]].values)
#         row_inp.append(row["response"])
#         tcga_train_y.append(row["response"])
#         tcga_train_features.append(row_inp)

In [19]:
# len(row_inp)

In [20]:
tcga_dataset_test = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
tcga_dataset_test.tcga_response

Unnamed: 0,submitter_id,drug_name,response
3,TCGA-G2-A2EF,CISPLATIN,1
46,TCGA-FD-A5C1,CISPLATIN,1
47,TCGA-FD-A5C1,GEMCITABINE,1
57,TCGA-E5-A4U1,DOCETAXEL,0
58,TCGA-FD-A6TC,GEMCITABINE,1
...,...,...,...
594,TCGA-FI-A2D5,PACLITAXEL,0
604,TCGA-AJ-A3EK,PACLITAXEL,1
607,TCGA-AJ-A23N,PACLITAXEL,0
615,TCGA-EY-A3L3,PACLITAXEL,1


In [21]:
tcga_test_features = []
tcga_test_y = []
for idx, row in tcga_dataset_test.tcga_response.iterrows():
    row_inp = []
    row_inp.extend(tcga_dataset_test.clinvar_gpd_annovar_annotated.loc[row["submitter_id"]].values)
    row_inp.extend(drug_fp.loc[row["drug_name"]].values)
    row_inp.append(row["response"])
    tcga_test_y.append(row["response"])
    tcga_test_features.append(row_inp)

In [22]:
len(row_inp)

9825

In [23]:
from torch.utils.data import DataLoader, Dataset

In [24]:
# load basic set of params
from dotmap import DotMap
import yaml
import wandb
with open(f'./config/config_tcga_sample{sample_id}.yml', 'r') as f:
    args = DotMap(yaml.safe_load(f))
print(args)
np.random.seed(args.seed)

DotMap(model='DruID', batch_size=256, epochs=100, moo='CS', inv_preference=[0.1, 1], seed=0, device=0, valid_interval=100)


In [25]:
seed = args.seed

In [26]:
# finetuned params
lr_main_optim = 1e-5
lr_cl_optim = 1e-4
lr_patient_optim = 1e-3
lr_drug_optim = 1e-4
args.lr_main_optim = lr_main_optim
args.lr_cl_optim = lr_cl_optim
args.lr_patient_optim = lr_patient_optim
args.lr_drug_optim = lr_drug_optim
args.device = 1
args.epochs = 50
batch_size = args.batch_size

In [27]:
args

DotMap(model='DruID', batch_size=256, epochs=50, moo='CS', inv_preference=[0.1, 1], seed=0, device=1, valid_interval=100, lr_main_optim=1e-05, lr_cl_optim=0.0001, lr_patient_optim=0.001, lr_drug_optim=0.0001, _ipython_display_=DotMap(), _repr_mimebundle_=DotMap())

In [28]:
class CustomDataset(Dataset):
    def __init__(self, train_features):
        self.train_features = train_features

    def __len__(self):
        return len(self.train_features)

    def __getitem__(self, idx):
        return torch.Tensor(self.train_features[idx][:-1]), self.train_features[idx][-1]

In [29]:
# cl_training_data = CustomDataset(cl_train_features)

In [30]:
# cl_train_dataloader = DataLoader(cl_training_data, batch_size=batch_size, shuffle=True, generator=g, worker_init_fn=seed_worker)

In [31]:
# cl_train_dataloader

In [32]:
# cl_test_data = CustomDataset(cl_test_features)
# cl_test_dataloader = DataLoader(cl_test_data, batch_size=batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)

In [33]:
# tcga_training_data = CustomDataset(tcga_train_features)

In [34]:
# tcga_train_dataloader = DataLoader(tcga_training_data, batch_size=batch_size, shuffle=True, generator=g, worker_init_fn=seed_worker)

In [35]:
# tcga_train_dataloader

In [36]:
tcga_test_data = CustomDataset(tcga_test_features)
tcga_test_dataloader = DataLoader(tcga_test_data, batch_size=batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)

In [37]:
len(tcga_test_features)

126

In [38]:
# train_loaders = {
#     "RECIST_prediction": tcga_train_dataloader,
#     "AUDRC_prediction": cl_train_dataloader
# }

In [39]:
test_loaders = {
    "RECIST_prediction": tcga_test_dataloader,
    # "AUDRC_prediction": cl_test_dataloader
}

In [40]:
datasets = {
                'RECIST_prediction': ['tcga'],
                'AUDRC_prediction': ['ccle'],
}

### DruID Model

In [41]:
class DruID(nn.Module):
    '''
    Used for training 2 tasks - cell line-drug AUDRC prediction(regression) and patient-drug RECIST prediction(classification)
    300 dimensional input for drugs
    '''
    def __init__(self,single=False):
        super(DruID, self).__init__()
        # drug embedder network
        self.drug_embedder = self.fnn(2048, 128, 64, 32)
        # cell line embedder network
        self.cell_line_embedder = CellLineEmbedder(checkpoint_base_path=f'../data/model_checkpoints/')
        # self.cell_line_embedder.load_model() # load pretrained VAE model
        # patient embedder network
        self.patient_embedder = PatientEmbedder(checkpoint_base_path=f'../data/model_checkpoints/')
        # self.patient_embedder.load_model() # load pretrained VAE model
        # prediction heads
        self.recist_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), ) # takes as input concatenated representation of patient and drug
        self.audrc_predictor = nn.Sequential(self.fnn(64, 64, 16, 1), ) # takes as input concatenated representation of cell line and drug

        self.AUDRC_specific = nn.ModuleDict({'embedder': self.cell_line_embedder,
                                              'predictor': self.audrc_predictor})
        self.RECIST_specific = nn.ModuleDict({'embedder': self.patient_embedder,
                                                'predictor': self.recist_predictor})

        self.name = 'DruID'
        self.device = torch.device(f'cuda:1' if torch.cuda.is_available() else 'cpu')
        
    def fnn(self, In, hidden1, hidden2, out):
        return nn.Sequential(nn.Linear(In, hidden1), nn.ReLU(), 
                             nn.Linear(hidden1, hidden2), nn.ReLU(), 
                             nn.Linear(hidden2, out))

    def forward(self,x1,x2): # x1 is patient data, x2 is cell lines - each row is of the form [mutations 324*24*6, drug fingerprint of 2048 dim]
        # input is of dim (batch_size, 7776+2048)
        # drug input
        patient_drug_input = x1[:, 324*6*4:].to(self.device, torch.float32)
        cl_drug_input = x2[:, 324*6*4:].to(self.device, torch.float32)
    
        # mutation profile
        patient_mut_input = torch.Tensor(x1[:,:324*6*4]).to(self.device, torch.float32)
        cl_mut_input = torch.Tensor(x2[:,:324*6*4]).to(self.device, torch.float32)
        
        # drug embedding
        patient_drug_emb = self.drug_embedder(patient_drug_input)
        cl_drug_emb = self.drug_embedder(cl_drug_input)
        
        # mutation embedding
        patient_mut_emb = self.patient_embedder(patient_mut_input)
        cl_mut_emb = self.cell_line_embedder(cl_mut_input)
        
        # concat and pass through prediction heads
        patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)
        cl_drug_cat_emb = torch.cat((cl_mut_emb, cl_drug_emb), dim=1)
        
        # prediction heads
        recist_prediction = self.recist_predictor(patient_drug_cat_emb)
        audrc_prediction = self.audrc_predictor(cl_drug_cat_emb)
        
        return recist_prediction, audrc_prediction


In [42]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

tasks = ['RECIST_prediction', 'AUDRC_prediction']  # prediction tasks; model consumes in this order; important

# model related
model = eval(f'{args.model}()')
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device...')
model = model.to(device)

## Pretrained DruID model related:
# Uncomment below line if loading pretrained model weights for DruID (Eg: in TCGA, we load pretrained weights from IMAC-OV DruID model)
# model.load_state_dict(torch.load(f"/data/ajayago/druid/paper_intermediate_pretrained/model_checkpoints/druid_MTL_raw_mutations_sample{sample_id}/rad51_drug_fp_clinvar_gpd_annovar_annotated_drug_trs_{wand_run_2_load}_all_drugs.pth"))
# Uncomment below lines if freeze some layers like drug, cell line and patient embedder etc from pretrained model
# for param in model.drug_embedder.parameters():
#     param.requires_grad = False
# for param in model.cell_line_embedder.parameters():
#     param.requires_grad = False
# for param in model.patient_embedder.parameters():
#     param.requires_grad = False

specific_submodels = {
                      'RECIST_prediction': model.RECIST_specific,
                      'AUDRC_prediction': model.AUDRC_specific
                     }
common_submodel = model.drug_embedder

# optimization related

# optimizer_main = optim.SGD(list(model.audrc_predictor.parameters())+
#                             list(model.recist_predictor.parameters())
#                             , lr=lr_main_optim) # optimizer for prediction heads
# optimizer_drug = optim.Adam(model.drug_embedder.parameters(), lr = lr_drug_optim) # optimizer for drug embedder
# optimizer_cl = optim.Adam(model.cell_line_embedder.parameters(), lr=lr_cl_optim) # optimizer for cell line embedder
# optimizer_patient = optim.SGD(model.patient_embedder.parameters(), lr=lr_patient_optim) # optimizer for patient embedder

# # Loss terms for prediction heads
# criteria = {
#             'RECIST_prediction': nn.BCEWithLogitsLoss(),
#             'AUDRC_prediction': nn.MSELoss(),
#            }


U: encoder 
Sequential(
  (enc-0): Linear(in_features=7776, out_features=128, bias=True)
  (act-0): Tanh()
  (enc-1): Linear(in_features=128, out_features=64, bias=True)
  (act-1): Tanh()
)
#
mu_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
sigma_layer: 
Linear(in_features=64, out_features=32, bias=True)
#
U: decoder 
Sequential(
  (-dec-0): Linear(in_features=32, out_features=64, bias=True)
  (-act-0): Tanh()
  (dec-0): Linear(in_features=64, out_features=128, bias=True)
  (act-0): Tanh()
  (dec-1): Linear(in_features=128, out_features=7776, bias=True)
  (act-1): Tanh()
)


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# # maximum number of iterations considering all datasets
# min_iterations = {task: [len(loader) * args.epochs for loader in loaders]
#                   for task, loaders in train_loaders.items()}
# max_iterations = max([max(iters) for iters in min_iterations.values()])

# print(f'# of iterations to run = {max_iterations}')
# inv_preference = np.array(args.inv_preference)
# preference = 1.0 / inv_preference
# preference /= preference.sum()
# intra_preference = {task: np.ones(len(datasets[task])) / len(datasets[task]) for task in tasks}

# epochs_completed = {task: {ds: 0 for ds in task_datasets}
#                     for task, task_datasets in datasets.items()}
# train_iter_loaders = {task: [iter(loader) for loader in loaders]
#                       for task, loaders in train_loaders.items()}

In [None]:
# # Training
# for i in range(max_iterations):
#     all_inputs = {task: [] for task in train_loaders}
#     all_targets = {task: [] for task in train_loaders}
#     for task, iter_loaders in train_loaders.items():
#         for inp, y in iter_loaders:
#             all_inputs[task].append(inp)
#             all_targets[task].append(y)
#     all_outputs = model(torch.cat(all_inputs["RECIST_prediction"]), torch.cat(all_inputs["AUDRC_prediction"]))
#     all_losses = {task: [criteria[task](all_outputs[tid], torch.cat(all_targets[task]).unsqueeze(1).to(device, torch.float32))
#                          ]
#                   for tid, (task, task_targets) in enumerate(all_targets.items())}
#     model.zero_grad()
#     optimizer_main.zero_grad()
#     optimizer_cl.zero_grad()
#     optimizer_patient.zero_grad()
#     optimizer_drug.zero_grad()
#     intra_coefs = {task: {ds: 0 for ds in datasets[task]} for task in tasks}
#     inter_coefs = {task: 0 for task in tasks}
#     if args.moo in ['CS']:
#         all_rel_losses = {task: [ds_loss * intra_preference[task][did] * preference[tid]
#                                  for did, ds_loss in enumerate(task_losses)]
#                           for tid, (task, task_losses) in enumerate(all_losses.items())}
#         if args.moo == 'CS':
#             total_loss = max([max(task_rel_losses)
#                               for task_rel_losses in all_rel_losses.values()])
#             print(total_loss)
#             max_tid, max_rel_loss = None, -1
#             for tid, (task, task_rel_losses) in enumerate(all_rel_losses.items()):
#                 max_did = max(range(len(task_rel_losses)), key=lambda lid: task_rel_losses[lid])
#                 intra_coefs[task][datasets[task][max_did]] = intra_preference[task][max_did]
#                 if task_rel_losses[max_did] > max_rel_loss:
#                     max_rel_loss = task_rel_losses[max_did]
#                     max_tid = tid
#             inter_coefs[tasks[max_tid]] = preference[max_tid]
#         else:
#             st_id = np.argmax(preference)
#             total_loss = sum(all_rel_losses[tasks[st_id]]) / preference[st_id]

#             inter_coefs[tasks[st_id]] = 1
#             for task in tasks:
#                 for did, ds in enumerate(datasets[task]):
#                     intra_coefs[task][ds] = intra_preference[task][did]
#         total_loss.backward()
#     else:
#         raise NotImplementedError('Choose an moo method')

#     optimizer_main.step()
#     if i > 10:
#         optimizer_cl.step()
#         optimizer_patient.step()
#         optimizer_drug.step()
#     log_losses = {task: {ds: loss.item() for ds, loss in zip(datasets[task], all_losses[task])}
#                   for task in tasks}

### Inference

In [None]:
# comment if training
model.load_state_dict(torch.load(f"../data/model_checkpoints/DruID.pth"))

NameError: name 'model' is not defined

In [None]:
model

DruID(
  (drug_embedder): Sequential(
    (0): Linear(in_features=2048, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
  )
  (cell_line_embedder): CellLineEmbedder(
    (vae_model1): vae(
      (mu_layer): Linear(in_features=64, out_features=32, bias=True)
      (sigma_layer): Linear(in_features=64, out_features=32, bias=True)
      (encoder): Sequential(
        (enc-0): Linear(in_features=7776, out_features=128, bias=True)
        (act-0): Tanh()
        (enc-1): Linear(in_features=128, out_features=64, bias=True)
        (act-1): Tanh()
      )
      (decoder): Sequential(
        (-dec-0): Linear(in_features=32, out_features=64, bias=True)
        (-act-0): Tanh()
        (dec-0): Linear(in_features=64, out_features=128, bias=True)
        (act-0): Tanh()
        (dec-1): Linear(in_features=128, out_features=7776, bias=True)
        (act-1): Tanh()


In [None]:
model.eval()
y_preds = []
for idx, (inp, y) in enumerate(tcga_test_dataloader):
    # drug input
    patient_drug_input = inp[:, 324*6*4:].to(device, torch.float32)
    # mutation profile
    patient_mut_input = inp[:, :324*6*4].to(device, torch.float32)

    # drug embedding
    patient_drug_emb = model.drug_embedder(patient_drug_input)

    # mutation embedding
    patient_mut_emb = model.patient_embedder(patient_mut_input)

    # concat and pass through prediction heads
    patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)

    recist_prediction = model.recist_predictor(patient_drug_cat_emb)
    y_preds.extend(list(recist_prediction.flatten().detach().cpu().numpy()))

### Metrics

In [None]:
from scipy import stats
from numpy import argmax
from sklearn.metrics import roc_curve

In [None]:
len(y_preds)

126

In [None]:
y_true = tcga_dataset_test.tcga_response[tcga_dataset_test.tcga_response.drug_name == "CISPLATIN"]
y_pred = tcga_dataset_test.tcga_response.copy()
y_pred["response"] = y_preds
y_pred

Unnamed: 0,submitter_id,drug_name,response
3,TCGA-G2-A2EF,CISPLATIN,0.129749
46,TCGA-FD-A5C1,CISPLATIN,0.149772
47,TCGA-FD-A5C1,GEMCITABINE,-0.431342
57,TCGA-E5-A4U1,DOCETAXEL,0.682843
58,TCGA-FD-A6TC,GEMCITABINE,-0.430205
...,...,...,...
594,TCGA-FI-A2D5,PACLITAXEL,-0.244540
604,TCGA-AJ-A3EK,PACLITAXEL,-0.409905
607,TCGA-AJ-A23N,PACLITAXEL,0.659928
615,TCGA-EY-A3L3,PACLITAXEL,0.663610


In [None]:

y_pred_pivotted = y_pred.pivot_table(
                "response", "submitter_id", "drug_name"
            )
y_pred_pivotted = y_pred_pivotted.fillna(0) # in case there are NaNs
dict_idx_drug = pd.DataFrame(y_pred_pivotted.columns).to_dict()["drug_name"]
dict_id_drug = {}

for patient_id, predictions in y_pred_pivotted.iterrows():

    cur_pred_scores = predictions.values
    cur_recom_drug_idx = np.argsort(cur_pred_scores)[:-11:-1]
    #
    dict_recom_drug = {}
    for idx, cur_idx in enumerate(cur_recom_drug_idx):
        dict_recom_drug[
            dict_idx_drug[cur_idx]
        ] = f"{cur_pred_scores[cur_idx]} ({idx+1})"
    #
    dict_id_drug[patient_id] = dict_recom_drug

predictions_display_tcga = pd.DataFrame.from_dict(dict_id_drug)

na_mask = y_pred.response.isna()
if na_mask.sum():
    print(
        f"Found {na_mask.sum()} rows with invalid response values"
    )
    y_pred = y_pred[~na_mask]
    y_true = y_true.loc[~(na_mask.values)]
na_mask = y_true.response.isna()
y_true = y_true[~na_mask]
y_pred = y_pred[~na_mask]
print(y_pred.shape)
y_pred.head()
y_combined = y_pred.merge(y_true, on=["submitter_id", "drug_name"])

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score, f1_score, accuracy_score, precision_score, recall_score

drugs_with_enough_support = ["CISPLATIN"]

for drug_name in drugs_with_enough_support:
    try:
        roc = roc_auc_score(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
            average="micro",
        )
        aupr = average_precision_score(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
            average="micro",
        )
        # Choosing the right threshold for F1, accuracy and precision calculation from ref: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
        fpr, tpr, thresholds = roc_curve(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
        )
        J = tpr - fpr
        ix = argmax(J)
        best_thresh = thresholds[ix]
        
        f1 = f1_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        acc_score = accuracy_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        prec_score = precision_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        rec_score = recall_score(
            y_true[y_true.drug_name == drug_name].response.values,
            (y_pred[y_pred.drug_name == drug_name].response.values > best_thresh).astype(int),
        )
        spearman_stats = stats.spearmanr(
            y_true[y_true.drug_name == drug_name].response.values,
            y_pred[y_pred.drug_name == drug_name].response.values,
        )
        mw_stats = stats.mannwhitneyu(
            y_combined[
                (y_combined.drug_name == drug_name) & (y_combined.response_y == 0)
            ].response_x.values,
            y_combined[
                (y_combined.drug_name == drug_name) & (y_combined.response_y == 1)
            ].response_x.values,
            alternative="greater",
        )
        denominator = (
            y_combined[
                (y_combined.drug_name == drug_name) & (y_combined.response_y == 0)
            ].shape[0]
            * y_combined[
                (y_combined.drug_name == drug_name) & (y_combined.response_y == 1)
            ].shape[0]
        )
        print(f"AUROC for {drug_name}: {roc}")
        print(f"AUPR for {drug_name}: {aupr}")
        print(f"F1 for {drug_name}: {f1}")
        print(f"Accuracy Score for {drug_name}: {acc_score}")
        print(f"Precision Score for {drug_name}: {prec_score}")
        print(f"Recall Score for {drug_name}: {rec_score}")
        print(
            f"Spearman for {drug_name}: {round(spearman_stats.correlation, 4)} (p-val: {round(spearman_stats.pvalue, 4)})"
        )
        print(
            f"Mann-Whitney for {drug_name}: {round(mw_stats.statistic/denominator, 4)} (p-val: {round(mw_stats.pvalue, 4)})"
        )
    except Exception as e:
        print(f"Error processing {drug_name} - {e}")



  y_pred = y_pred[~na_mask]


IndexingError: Unalignable boolean Series provided as indexer (index of the boolean Series and of the indexed object do not match).