Uses Morgan fingerprints for drugs and pretrained VAE representations of copy-number variation (CNV) profiles for cell lines and patients.

In [None]:
import sys

sys.path.append("../src/")

In [None]:
import numpy as np
import pandas as pd

import datetime
import logging
import os
import time
import torch

from torch import nn
from torch.nn import functional as F

from functools import cached_property

from torch.nn import Linear, ReLU, Sequential

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim


from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
    
)

from utils import get_kld_loss, get_zinb_loss

from seaborn import scatterplot

from sklearn.metrics import pairwise_distances

In [None]:
# To avoid randomness in DataLoaders - https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs inside a DataLoader worker.

    Intended as a DataLoader ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()`` (reduced modulo 2**32, NumPy's seed range), so each
    worker gets a deterministic, torch-controlled seed.

    Args:
        worker_id: index of the worker; unused, required by the
            ``worker_init_fn`` signature.
    """
    # `random` is not imported at notebook level, and the notebook imports NumPy
    # as `np` (not `numpy`), so the original body raised NameError.
    import random

    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Generator shared by the DataLoaders below so their shuffling order is reproducible.
g = torch.Generator()
g.manual_seed(0)

In [None]:
# Sample index: passed to every dataset constructor and used to pick the config
# file and checkpoint directory further down.
sample_id: int = 0

### Model Definition

In [None]:
from model import (
    BaseDruidModel,
)

In [None]:
from ffnzinb import ffnzinb
from vae import vae

In [None]:
class CellLineEmbedder(nn.Module):
    """Embed cell-line CNV profiles with a pretrained VAE.

    Two VAEs are built and ``load_model`` loads a checkpoint into each, but
    ``forward`` only uses ``vae_model1_raw_mutation`` (loaded from the
    cell-line-domain checkpoint). ``vae_model2_raw_mutation`` is kept so the
    module's parameters and ``state_dict`` layout stay unchanged.
    """

    @cached_property
    def device(self):
        """Device the VAEs live on: ``cuda:1`` when CUDA is available, else CPU."""
        # NOTE(review): the GPU index is hard-coded; "cuda:1" does not exist on a
        # single-GPU host — confirm it matches the machine this runs on.
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        """
        Args:
            checkpoint_base_path: directory holding the ``*_cnv.pt`` VAE
                checkpoints read by ``load_model``.
        """
        super().__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 3  # width of the CNV feature vector fed to the VAE
        k_list = [128, 16]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # `vae` must be in scope where this class is instantiated (the notebook
        # imports it from the project's `vae` module).
        self.vae_model1_raw_mutation = self._place(
            vae(input_dim_vae, k_list, actf_list, is_real)
        )
        self.vae_model2_raw_mutation = self._place(
            vae(input_dim_vae, k_list, actf_list, is_real)
        )

    def _place(self, module):
        """Move *module* onto ``self.device`` if it is a CUDA device, else return it as-is."""
        if self.device.type == "cuda":
            return module.cuda(device=self.device)
        return module

    def __str__(self):
        return "CellLineEmbedder"

    def load_model(self):
        """Load weights for both VAEs from ``checkpoint_base_path``."""
        checkpoints = (
            (self.vae_model1_raw_mutation, "unsupervised_vae_model_cell_line_domain_cnv.pt"),
            (self.vae_model2_raw_mutation, "unsupervised_vae_model_other_domain_cnv.pt"),
        )
        for module, filename in checkpoints:
            module.load_state_dict(
                torch.load(
                    f"{self.checkpoint_base_path}/{filename}",
                    map_location=str(self.device),
                )
            )

    def forward(self, x):
        """Return the cell-line embedding of CNV batch *x*.

        The embedding is the second of the four values returned by
        ``vae_model1_raw_mutation`` (presumably the latent code — confirm against
        ``vae.forward``).
        """
        _, cell_line_emb, _, _ = self.vae_model1_raw_mutation(x)
        return cell_line_emb

In [None]:
class PatientEmbedder(nn.Module):
    """Embed patient (TCGA) CNV profiles with a pretrained VAE.

    Two VAEs are built and ``load_model`` loads a checkpoint into each, but
    ``forward`` only uses ``vae_model2_raw_mutation`` (loaded from the
    other-domain checkpoint). ``vae_model1_raw_mutation`` is kept so the
    module's parameters and ``state_dict`` layout stay unchanged.
    """

    @cached_property
    def device(self):
        """Device the VAEs live on: ``cuda:1`` when CUDA is available, else CPU."""
        # NOTE(review): the GPU index is hard-coded; "cuda:1" does not exist on a
        # single-GPU host — confirm it matches the machine this runs on.
        return torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        checkpoint_base_path="../data/model_checkpoints",
    ):
        """
        Args:
            checkpoint_base_path: directory holding the ``*_cnv.pt`` VAE
                checkpoints read by ``load_model``.
        """
        super().__init__()
        self.checkpoint_base_path = checkpoint_base_path

        input_dim_vae = 324 * 3  # width of the CNV feature vector fed to the VAE
        k_list = [128, 16]
        actf_list = ["tanh", "tanh"]
        is_real = True

        # `vae` must be in scope where this class is instantiated (the notebook
        # imports it from the project's `vae` module).
        self.vae_model1_raw_mutation = self._place(
            vae(input_dim_vae, k_list, actf_list, is_real)
        )
        self.vae_model2_raw_mutation = self._place(
            vae(input_dim_vae, k_list, actf_list, is_real)
        )

    def _place(self, module):
        """Move *module* onto ``self.device`` if it is a CUDA device, else return it as-is."""
        if self.device.type == "cuda":
            return module.cuda(device=self.device)
        return module

    def __str__(self):
        return "PatientEmbedder"

    def load_model(self):
        """Load weights for both VAEs from ``checkpoint_base_path``."""
        checkpoints = (
            (self.vae_model1_raw_mutation, "unsupervised_vae_model_cell_line_domain_cnv.pt"),
            (self.vae_model2_raw_mutation, "unsupervised_vae_model_other_domain_cnv.pt"),
        )
        for module, filename in checkpoints:
            module.load_state_dict(
                torch.load(
                    f"{self.checkpoint_base_path}/{filename}",
                    map_location=str(self.device),
                )
            )

    def forward(self, x):
        """Return the patient embedding of CNV batch *x*.

        The embedding is the second of the four values returned by
        ``vae_model2_raw_mutation`` (presumably the latent code — confirm against
        ``vae.forward``).
        """
        _, patient_emb, _, _ = self.vae_model2_raw_mutation(x)
        return patient_emb

In [None]:
# Drugs modelled in this notebook; each is mapped to its position in sorted order.
drug_names = [
    'DOCETAXEL',
    'GEMCITABINE',
    'CISPLATIN',
    'PACLITAXEL',
    '5-FLUOROURACIL',
    'CYCLOPHOSPHAMIDE',
]
uniq_drug_names = np.unique(drug_names)  # sorted, de-duplicated
drug_names_to_idx_map = {name: position for position, name in enumerate(uniq_drug_names)}


In [None]:
# Display the sorted, de-duplicated drug names.
uniq_drug_names

In [None]:
# Morgan fingerprints indexed by drug name (looked up below via drug_fp.loc[drug_name]).
drug_fp = pd.read_csv("../data/processed/drug_morgan_fingerprints.csv", index_col=0)
drug_fp

### Creating the datasets

In [None]:
# Cell-line training split for this sample_id (filter_for="tcga" presumably restricts it
# to the TCGA drugs — confirm). y_df rows carry depmap_id, drug_name and auc.
cl_dataset_train = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
cl_dataset_train.y_df

In [None]:
# One row per (cell line, drug) pair: CNV profile, then drug fingerprint, then the
# `auc` response as the last element (CustomDataset splits it off as the target).
cl_train_features = []
cl_train_y = []
for idx, row in cl_dataset_train.y_df.iterrows():
    row_inp = [
        *cl_dataset_train.cnv.loc[row["depmap_id"]].values,
        *drug_fp.loc[row["drug_name"]].values,
        row["auc"],
    ]
    cl_train_features.append(row_inp)
    cl_train_y.append(row["auc"])

In [None]:
# Width of the last feature row built above (features + label).
len(row_inp)

In [None]:
# Cell-line test split for this sample_id; same layout as the training split.
cl_dataset_test = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
cl_dataset_test.y_df

In [None]:
# Test rows, built exactly like the training rows: CNV profile, drug fingerprint,
# then the `auc` response as the last element.
cl_test_features = []
cl_test_y = []
for idx, row in cl_dataset_test.y_df.iterrows():
    row_inp = [
        *cl_dataset_test.cnv.loc[row["depmap_id"]].values,
        *drug_fp.loc[row["drug_name"]].values,
        row["auc"],
    ]
    cl_test_features.append(row_inp)
    cl_test_y.append(row["auc"])

In [None]:
len(row_inp) # expected 324*3 CNV features + 2048 fingerprint bits + 1 AUDRC label (the model splits rows at column 324*3)

In [None]:
# TCGA patient training split for this sample_id. tcga_response rows carry
# submitter_id, drug_name and a 0/1 response label.
tcga_dataset_train = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=True, filter_for="tcga", sample_id=sample_id)
tcga_dataset_train.tcga_response

In [None]:
# One row per (patient, drug) pair: CNV profile, then drug fingerprint, then the
# `response` label as the last element.
tcga_train_features = []
tcga_train_y = []
for idx, row in tcga_dataset_train.tcga_response.iterrows():
    row_inp = [
        *tcga_dataset_train.cnv.loc[row["submitter_id"]].values,
        *drug_fp.loc[row["drug_name"]].values,
        row["response"],
    ]
    tcga_train_features.append(row_inp)
    tcga_train_y.append(row["response"])

In [None]:
# Width of the last TCGA training row (features + label).
len(row_inp)

In [None]:
# TCGA patient test split for this sample_id; same layout as the training split.
tcga_dataset_test = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(is_train=False, filter_for="tcga", sample_id=sample_id)
tcga_dataset_test.tcga_response

In [None]:
# Test rows in tcga_response order (the metrics cells rely on this order):
# CNV profile, drug fingerprint, then the `response` label last.
tcga_test_features = []
tcga_test_y = []
for idx, row in tcga_dataset_test.tcga_response.iterrows():
    row_inp = [
        *tcga_dataset_test.cnv.loc[row["submitter_id"]].values,
        *drug_fp.loc[row["drug_name"]].values,
        row["response"],
    ]
    tcga_test_features.append(row_inp)
    tcga_test_y.append(row["response"])

In [None]:
# Width of the last TCGA test row (features + label).
len(row_inp)

In [None]:
class DruID(nn.Module):
    '''
    Two-task model trained jointly: cell line-drug AUDRC prediction (regression)
    and patient-drug RECIST prediction (binary classification, logit output).

    Every input row is [CNV features (CNV_DIM = 324 * 3) | drug fingerprint (2048)].
    Drugs are embedded by a feed-forward network shared by both tasks. Cell lines
    and patients are embedded by pretrained VAE embedders whose weights are loaded
    from this sample's checkpoint directory. The concatenated (CNV embedding, drug
    embedding) vector goes to a task-specific prediction head.
    '''

    # Number of leading CNV columns in every input row; the rest is the fingerprint.
    CNV_DIM = 324 * 3

    def __init__(self, single=False, device=None):
        """
        Args:
            single: unused; kept for backward compatibility.
            device: device that ``forward`` moves inputs to. Defaults to
                ``cuda:{args.device}`` (notebook-level ``args``) when CUDA is
                available, else CPU.
        """
        super().__init__()
        self.drug_embedder = self.fnn(2048, 64, 16, 8)

        # Both embedders read their VAE checkpoints from the same per-sample directory.
        checkpoint_dir = (
            '/data/ajayago/druid/paper_intermediate/model_checkpoints/'
            f'2B_druid_with_tcga_filtered_drug_sample{sample_id}/'
        )
        self.cell_line_embedder = CellLineEmbedder(checkpoint_base_path=checkpoint_dir)
        self.cell_line_embedder.load_model()
        self.patient_embedder = PatientEmbedder(checkpoint_base_path=checkpoint_dir)
        self.patient_embedder.load_model()

        # Heads take the concatenated (CNV embedding, 8-dim drug embedding) vector.
        # NOTE(review): the input width 16 assumes the VAE embedding is 8-dim —
        # confirm against the VAE's latent size.
        self.recist_predictor = nn.Sequential(self.fnn(16, 64, 16, 1))  # logit output
        self.audrc_predictor = nn.Sequential(self.fnn(16, 64, 16, 1))

        # Per-task parameter groups (embedder + head).
        self.AUDRC_specific = nn.ModuleDict({'embedder': self.cell_line_embedder,
                                             'predictor': self.audrc_predictor})
        self.RECIST_specific = nn.ModuleDict({'embedder': self.patient_embedder,
                                              'predictor': self.recist_predictor})

        self.name = 'DrugTRS - train AUDRC and RECIST together '
        if device is None:
            device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def fnn(self, In, hidden1, hidden2, out):
        """Build a 3-layer MLP: In -> hidden1 -> hidden2 -> out, with ReLU between layers."""
        return nn.Sequential(nn.Linear(In, hidden1), nn.ReLU(),
                             nn.Linear(hidden1, hidden2), nn.ReLU(),
                             nn.Linear(hidden2, out))

    def forward(self, x1, x2):
        """Predict RECIST logits for patient rows and AUDRC for cell-line rows.

        Args:
            x1: patient-drug rows, tensor of shape (batch, CNV_DIM + 2048).
            x2: cell line-drug rows, same layout.

        Returns:
            (recist_prediction, audrc_prediction), each of shape (batch, 1).
        """
        cnv_dim = self.CNV_DIM

        # Split each row into its CNV profile and drug fingerprint.
        patient_drug_input = x1[:, cnv_dim:].to(self.device, torch.float32)
        cl_drug_input = x2[:, cnv_dim:].to(self.device, torch.float32)
        patient_cnv_input = x1[:, :cnv_dim].to(self.device, torch.float32)
        cl_cnv_input = x2[:, :cnv_dim].to(self.device, torch.float32)

        # Drug embeddings come from the shared embedder.
        patient_drug_emb = self.drug_embedder(patient_drug_input)
        cl_drug_emb = self.drug_embedder(cl_drug_input)

        # CNV embeddings come from the domain-specific VAE embedders.
        patient_cnv_emb = self.patient_embedder(patient_cnv_input)
        cl_cnv_emb = self.cell_line_embedder(cl_cnv_input)

        # Concatenate and pass through the task heads.
        patient_drug_cat_emb = torch.cat((patient_cnv_emb, patient_drug_emb), dim=1)
        cl_drug_cat_emb = torch.cat((cl_cnv_emb, cl_drug_emb), dim=1)

        recist_prediction = self.recist_predictor(patient_drug_cat_emb)
        audrc_prediction = self.audrc_predictor(cl_drug_cat_emb)

        return recist_prediction, audrc_prediction


In [None]:
from dotmap import DotMap
import yaml
import wandb
# Per-sample run configuration, exposed with attribute access (args.seed, args.model, ...).
config_path = f'../notebook/config/config_tcga_sample{sample_id}.yml'
with open(config_path, 'r') as config_file:
    raw_config = yaml.safe_load(config_file)
args = DotMap(raw_config)
print(args)
np.random.seed(args.seed)

In [None]:
seed = args.seed  # global seed from the config; passed to torch.manual_seed before building the model

In [None]:
# Learning rate for each optimiser's parameter group; each is also recorded on `args`.
args.lr_main_optim = lr_main_optim = 1e-5          # prediction heads
args.lr_cl_optim = lr_cl_optim = 1e-6              # cell-line embedder
args.lr_patient_optim = lr_patient_optim = 1e-4    # patient embedder
args.lr_drug_optim = lr_drug_optim = 1e-4          # drug embedder
# Run overrides applied on top of the YAML config.
args.epochs = 500
args.device = 1  # GPU index; used as f'cuda:{args.device}' when CUDA is available

In [None]:
# Display the final run configuration (YAML values plus the overrides above).
args

In [None]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

tasks = ['RECIST_prediction', 'AUDRC_prediction']  # prediction tasks; model consumes in this order; important

# model: resolve the class named in the config from the notebook namespace instead of
# eval()-ing a config string. The result is the same for a valid class name, but the YAML
# file can no longer run arbitrary code (an unknown name now raises KeyError).
model = globals()[args.model]()
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device...')
model = model.to(device)
# Parameters owned by one task (embedder + head), and the drug embedder shared by both.
specific_submodels = {
    'RECIST_prediction': model.RECIST_specific,
    'AUDRC_prediction': model.AUDRC_specific,
}
common_submodel = model.drug_embedder

# optimization related: one Adam per parameter group so each gets its own learning rate
batch_size = args.batch_size
optimizer_main = optim.Adam(
    list(model.audrc_predictor.parameters()) + list(model.recist_predictor.parameters()),
    lr=lr_main_optim,
)
optimizer_drug = optim.Adam(model.drug_embedder.parameters(), lr=lr_drug_optim)
optimizer_cl = optim.Adam(model.cell_line_embedder.parameters(), lr=lr_cl_optim)
optimizer_patient = optim.Adam(model.patient_embedder.parameters(), lr=lr_patient_optim)
# RECIST: binary label scored from logits; AUDRC: regression target.
criteria = {
    'RECIST_prediction': nn.BCEWithLogitsLoss(),
    'AUDRC_prediction': nn.MSELoss(),
}


In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-built rows whose last element is the label.

    Item ``idx`` is ``(features, target)``: ``features`` is a ``torch.Tensor``
    (default float dtype) holding every element of the row except the last, and
    ``target`` is the row's last element, returned unchanged.
    """

    def __init__(self, train_features):
        # Sequence of rows, each a sequence of numbers with the label last.
        self.train_features = train_features

    def __len__(self):
        return len(self.train_features)

    def __getitem__(self, idx):
        row = self.train_features[idx]
        features, target = row[:-1], row[-1]
        return torch.Tensor(features), target

In [None]:
# Wrap the cell-line training rows (label last) as a Dataset.
cl_training_data = CustomDataset(cl_train_features)

In [None]:
# Shuffled cell-line training batches; the shared generator `g` fixes the order.
cl_train_dataloader = DataLoader(
    cl_training_data,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Display the cell-line training DataLoader.
cl_train_dataloader

In [None]:
# Cell-line test rows in their original order (no shuffling).
cl_test_data = CustomDataset(cl_test_features)
cl_test_dataloader = DataLoader(
    cl_test_data,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Wrap the TCGA training rows (label last) as a Dataset.
tcga_training_data = CustomDataset(tcga_train_features)

In [None]:
# Shuffled TCGA training batches; shares generator `g` with the other loaders.
tcga_train_dataloader = DataLoader(
    tcga_training_data,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Display the TCGA training DataLoader.
tcga_train_dataloader

In [None]:
# TCGA test rows kept in tcga_response order (no shuffling); predictions are later
# written back into a copy of that table positionally.
tcga_test_data = CustomDataset(tcga_test_features)
tcga_test_dataloader = DataLoader(
    tcga_test_data,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Training DataLoader per task. Insertion order (RECIST, then AUDRC) matches the
# order of the model's outputs.
train_loaders = dict(
    RECIST_prediction=tcga_train_dataloader,
    AUDRC_prediction=cl_train_dataloader,
)

In [None]:
# Test DataLoader per task, in the same task order as train_loaders.
test_loaders = dict(
    RECIST_prediction=tcga_test_dataloader,
    AUDRC_prediction=cl_test_dataloader,
)

In [None]:
# Dataset names contributing to each task (one dataset per task here).
datasets = dict(
    RECIST_prediction=['tcga'],
    AUDRC_prediction=['ccle'],
)

In [None]:
# Each task maps to a single DataLoader, but this bookkeeping expects a *list* of
# loaders per task. Normalise to lists. Iterating a DataLoader directly walks its
# batches, so len(...) measured an (inputs, targets) batch (always 2) instead of the
# number of batches. Each such pass also consumed the shared shuffle generator `g`.
task_loader_lists = {
    task: list(loaders) if isinstance(loaders, (list, tuple)) else [loaders]
    for task, loaders in train_loaders.items()
}

# maximum number of iterations considering all datasets (batches per epoch x epochs)
# NOTE(review): the training loop consumes *every* batch of each loader per
# iteration (full-batch), so this is len(loader) x args.epochs full passes. If
# args.epochs passes are intended, use args.epochs here instead — confirm.
min_iterations = {task: [len(loader) * args.epochs for loader in loaders]
                  for task, loaders in task_loader_lists.items()}
max_iterations = max(max(iters) for iters in min_iterations.values())

print(f'# of iterations to run = {max_iterations}')
# Task weights: inverse of args.inv_preference, normalised to sum to 1.
inv_preference = np.array(args.inv_preference)
preference = 1.0 / inv_preference
preference /= preference.sum()
# Uniform weights over the datasets inside each task.
intra_preference = {task: np.ones(len(datasets[task])) / len(datasets[task]) for task in tasks}
if args.moo == 'EPO':
    # NOTE(review): `EPO` is not imported anywhere in this notebook view, and zip()
    # pairs the two tasks with only the first two solvers (epo_ds is dropped).
    epo_ = EPO(3, inv_preference)
    epo_dt = EPO(3, np.array([1., 1., 1.]), eps=0.3)
    epo_dr = EPO(2, np.array([1., 1.]))
    epo_ds = EPO(2, np.array([1., 1.]))
    epos = {task: ep for task, ep in zip(tasks, [epo_dt, epo_dr, epo_ds])}

epochs_completed = {task: {ds: 0 for ds in task_datasets}
                    for task, task_datasets in datasets.items()}
train_iter_loaders = {task: [iter(loader) for loader in loaders]
                      for task, loaders in task_loader_lists.items()}

In [None]:
# Training
# Each iteration concatenates *every* batch of each task's training DataLoader, runs
# one forward pass over all of it (full-batch), combines the per-task losses according
# to args.moo, then steps the optimisers.
for i in range(max_iterations):
    all_inputs = {task: [] for task in train_loaders}
    all_targets = {task: [] for task in train_loaders}
    # Full pass over each task's loader (not one batch per iteration).
    for task, iter_loaders in train_loaders.items():
        for inp, y in iter_loaders:
            all_inputs[task].append(inp)
            all_targets[task].append(y)
    # DruID.forward takes (patient rows, cell-line rows) and returns (recist, audrc).
    all_outputs = model(torch.cat(all_inputs["RECIST_prediction"]), torch.cat(all_inputs["AUDRC_prediction"]))
    # One loss per task, wrapped in a 1-element list (one dataset per task). Output
    # index `tid` relies on all_targets being ordered RECIST, AUDRC like the outputs.
    all_losses = {task: [criteria[task](all_outputs[tid], torch.cat(all_targets[task]).unsqueeze(1).to(device, torch.float32))
                         ]
                  for tid, (task, task_targets) in enumerate(all_targets.items())}
    model.zero_grad()
    optimizer_main.zero_grad()
    optimizer_cl.zero_grad()
    optimizer_patient.zero_grad()
    optimizer_drug.zero_grad()
    # Weights actually applied this iteration; recorded only (not read again in this view).
    intra_coefs = {task: {ds: 0 for ds in datasets[task]} for task in tasks}
    inter_coefs = {task: 0 for task in tasks}
    if args.moo in ['LS', 'CS', 'ST']:
        # Scale each loss by its intra-task and task preference weights.
        all_rel_losses = {task: [ds_loss * intra_preference[task][did] * preference[tid]
                                 for did, ds_loss in enumerate(task_losses)]
                          for tid, (task, task_losses) in enumerate(all_losses.items())}
#         print(all_rel_losses)
        if args.moo == 'LS':
            # Linear scalarisation: minimise the sum of the weighted losses.
            total_loss = sum([sum(task_rel_losses)
                              for task_rel_losses in all_rel_losses.values()])

            for tid, task in enumerate(tasks):
                inter_coefs[task] = preference[tid]
                for did, ds in enumerate(datasets[task]):
                    intra_coefs[task][ds] = intra_preference[task][did]
        elif args.moo == 'CS':
            # Minimise only the largest weighted loss; it is printed every iteration.
            total_loss = max([max(task_rel_losses)
                              for task_rel_losses in all_rel_losses.values()])
#             total_loss = max(all_rel_losses["RECIST_prediction"]) # for single task
            print(total_loss)
            max_tid, max_rel_loss = None, -1
            for tid, (task, task_rel_losses) in enumerate(all_rel_losses.items()):
                max_did = max(range(len(task_rel_losses)), key=lambda lid: task_rel_losses[lid])
                intra_coefs[task][datasets[task][max_did]] = intra_preference[task][max_did]
                if task_rel_losses[max_did] > max_rel_loss:
                    max_rel_loss = task_rel_losses[max_did]
                    max_tid = tid
            inter_coefs[tasks[max_tid]] = preference[max_tid]
        else:
            # 'ST': train only the task with the largest preference weight, dividing
            # that weight back out of its loss.
            st_id = np.argmax(preference)
            total_loss = sum(all_rel_losses[tasks[st_id]]) / preference[st_id]

            inter_coefs[tasks[st_id]] = 1
            for task in tasks:
                for did, ds in enumerate(datasets[task]):
                    intra_coefs[task][ds] = intra_preference[task][did]
        total_loss.backward()
    elif args.moo == 'EPO':
        # NOTE(review): `EPO`/`epos` must be defined for this branch. Also
        # `param.grad.clone()` fails for parameters that received no gradient —
        # e.g. CellLineEmbedder's vae_model2_raw_mutation, which its forward never uses.
        n_tasks = len(tasks)
        shared_grads = [[] for _ in range(n_tasks)]  # NOTE: DO NOT USE [[]] * n_tasks.
        apparent_losses = np.zeros(n_tasks)
        for tid, (task, task_losses) in enumerate(all_losses.items()):
            n_ds = len(task_losses)  # number of datasets in the task
            specific_submodel = specific_submodels[task]
            specific_submodel_grads = [[] for _ in range(n_ds)]
            common_submodel_grads = [[] for _ in range(n_ds)]
            # Collect per-dataset gradients of the task-specific and shared parameters.
            for did, ds_loss in enumerate(task_losses):
                ds_loss.backward()
                for param in specific_submodel.parameters():
                    specific_submodel_grads[did].append(param.grad.clone())
                for param in common_submodel.parameters():
                    common_submodel_grads[did].append(param.grad.clone())
                specific_submodel.zero_grad()
                common_submodel.zero_grad()
            # Gram matrix of the per-dataset gradients (specific + shared parameters).
            GG = torch.zeros(n_ds, n_ds)
            for grads in [specific_submodel_grads, common_submodel_grads]:
                for j in range(n_ds):
                    for k in range(j, n_ds):
                        Gj_dot_Gk = sum([gj.flatten().dot(gk.flatten())
                                         for gj, gk in zip(grads[j], grads[k])]).cpu()
                        GG[j, k] += Gj_dot_Gk
                        GG[k, j] += Gj_dot_Gk

            l = np.array([ds_loss.item() for ds_loss in task_losses], dtype=np.double)
            print('intra-task epo:', task)
            beta = epos[task].get_beta(l, GG.numpy().astype(np.double))
            # Combine the per-dataset gradients with the intra-task weights beta.
            for pid, param in enumerate(specific_submodel.parameters()):
                param.grad = sum([beta[j] * specific_submodel_grads[j][pid]
                                  for j in range(n_ds)])

            for pid, _ in enumerate(common_submodel.parameters()):
                shared_grads[tid].append(sum([beta[j] * common_submodel_grads[j][pid]
                                              for j in range(n_ds)]))
            apparent_losses[tid] = l.dot(beta)

            for did, ds in enumerate(datasets[task]):
                intra_coefs[task][ds] = beta[did]

        # Inter-task step: weight the tasks' shared-parameter gradients.
        GG = torch.zeros(n_tasks, n_tasks)
        for j in range(n_tasks):
            for k in range(j, n_tasks):
                GG[j, k] = sum([gj.flatten().dot(gk.flatten())
                                for gj, gk in zip(shared_grads[j], shared_grads[k])])
                GG[k, j] = GG[j, k]
        print('inter-task epo')
        beta = epo_.get_beta(apparent_losses, GG.numpy().astype(np.double))
        for pid, param in enumerate(common_submodel.parameters()):
            param.grad = sum([beta[j] * shared_grads[j][pid]
                              for j in range(n_tasks)])

        for tid, task in enumerate(tasks):
            for param in specific_submodels[task].parameters():
                param.grad *= beta[tid]

        for tid, task in enumerate(tasks):
            inter_coefs[task] = beta[tid]
    else:
        raise NotImplementedError('Choose an moo method')

    # Heads update every iteration; the embedders and the drug embedder start
    # updating only once i > 10.
    optimizer_main.step()
    if i > 10:
        optimizer_cl.step()
        optimizer_patient.step()
        optimizer_drug.step()
    log_losses = {task: {ds: loss.item() for ds, loss in zip(datasets[task], all_losses[task])}
                  for task in tasks}
    

In [None]:
# Print the last training iteration's RECIST outputs (all_outputs[0]), one per line.
for pred in all_outputs[0]:
    print(pred[0].detach().cpu().item())

In [None]:
# Display the model architecture.
model

### Prediction

In [None]:
# Score the TCGA test split with the RECIST head only; this mirrors the patient branch
# of DruID.forward. Scores are raw logits (the head was trained with BCEWithLogitsLoss).
model.eval()
y_preds = []
cnv_dim = 324 * 3  # leading CNV columns; the rest of each row is the drug fingerprint
with torch.no_grad():  # inference only: don't build autograd graphs
    for inp, _ in tcga_test_dataloader:
        # Send both halves of the row to the model's device.
        patient_drug_input = inp[:, cnv_dim:].to(model.device, torch.float32)
        patient_mut_input = inp[:, :cnv_dim].to(model.device, torch.float32)

        patient_drug_emb = model.drug_embedder(patient_drug_input)
        patient_mut_emb = model.patient_embedder(patient_mut_input)

        # Concatenate and pass through the RECIST head.
        patient_drug_cat_emb = torch.cat((patient_mut_emb, patient_drug_emb), dim=1)
        recist_prediction = model.recist_predictor(patient_drug_cat_emb)
        y_preds.extend(list(recist_prediction.flatten().cpu().numpy()))

### Metrics

In [None]:
from scipy import stats
from numpy import argmax
from sklearn.metrics import roc_curve

In [None]:
# Number of test predictions; should equal the number of rows in tcga_response.
len(y_preds)

In [None]:
# Labels, plus a copy of the same table whose `response` column holds the model
# scores. The test loader is unshuffled, so predictions line up with rows positionally.
y_true = tcga_dataset_test.tcga_response
y_pred = y_true.copy().assign(response=y_preds)
y_pred

In [None]:

# Per-patient drug ranking: the top 10 drugs by predicted score, highest first, each
# rendered as "<score> (<rank>)". Missing (patient, drug) pairs are scored 0.
y_pred_pivotted = y_pred.pivot_table("response", "submitter_id", "drug_name").fillna(0)
dict_idx_drug = dict(enumerate(y_pred_pivotted.columns))
dict_id_drug = {}

for patient_id, predictions in y_pred_pivotted.iterrows():
    scores = predictions.values
    ranked = np.argsort(scores)[::-1][:10]
    dict_id_drug[patient_id] = {
        dict_idx_drug[drug_pos]: f"{scores[drug_pos]} ({rank})"
        for rank, drug_pos in enumerate(ranked, start=1)
    }

predictions_display_tcga = pd.DataFrame.from_dict(dict_id_drug)

# Drop rows with a missing prediction, then rows with a missing label. y_pred and
# y_true share one index (y_pred was copied from y_true), so the same masks keep
# them aligned.
na_mask = y_pred.response.isna()
if na_mask.sum():
    print(
        f"Found {na_mask.sum()} rows with invalid response values"
    )
    keep = ~na_mask
    y_pred = y_pred[keep]
    y_true = y_true.loc[keep.values]

na_mask = y_true.response.isna()
y_true, y_pred = y_true[~na_mask], y_pred[~na_mask]
print(y_pred.shape)

# response_x = prediction (from y_pred), response_y = label (from y_true).
y_combined = y_pred.merge(y_true, on=["submitter_id", "drug_name"])

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score, f1_score, accuracy_score, precision_score, recall_score

drugs_with_enough_support = ["CISPLATIN", "PACLITAXEL", "5-FLUOROURACIL"]

for drug_name in drugs_with_enough_support:
    try:
        # Labels and scores for this drug. y_true and y_pred hold the same rows in the
        # same order (y_pred is a copy of the response table, filtered with the same
        # masks), so the two slices line up element-wise.
        drug_true = y_true[y_true.drug_name == drug_name].response.values
        drug_score = y_pred[y_pred.drug_name == drug_name].response.values
        # Scores split by true label, taken from the merged table
        # (response_x = prediction, response_y = label).
        drug_rows = y_combined[y_combined.drug_name == drug_name]
        neg_scores = drug_rows[drug_rows.response_y == 0].response_x.values
        pos_scores = drug_rows[drug_rows.response_y == 1].response_x.values

        roc = roc_auc_score(drug_true, drug_score, average="micro")
        aupr = average_precision_score(drug_true, drug_score, average="micro")
        # Choosing the right threshold for F1, accuracy and precision calculation from ref: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
        fpr, tpr, thresholds = roc_curve(drug_true, drug_score)
        best_thresh = thresholds[argmax(tpr - fpr)]  # maximises Youden's J = TPR - FPR
        # roc_curve counts a sample as positive when score >= threshold. Use the same
        # rule here; `>` mislabelled samples scoring exactly at the chosen threshold.
        drug_label = (drug_score >= best_thresh).astype(int)

        f1 = f1_score(drug_true, drug_label)
        acc_score = accuracy_score(drug_true, drug_label)
        prec_score = precision_score(drug_true, drug_label)
        rec_score = recall_score(drug_true, drug_label)
        spearman_stats = stats.spearmanr(drug_true, drug_score)
        # One-sided test: are label-0 scores stochastically greater than label-1 scores?
        mw_stats = stats.mannwhitneyu(neg_scores, pos_scores, alternative="greater")
        # U / (n0 * n1) normalises the statistic to [0, 1].
        denominator = len(neg_scores) * len(pos_scores)

        print(f"AUROC for {drug_name}: {roc}")
        print(f"AUPR for {drug_name}: {aupr}")
        print(f"F1 for {drug_name}: {f1}")
        print(f"Accuracy Score for {drug_name}: {acc_score}")
        print(f"Precision Score for {drug_name}: {prec_score}")
        print(f"Recall Score for {drug_name}: {rec_score}")
        print(
            f"Spearman for {drug_name}: {round(spearman_stats.correlation, 4)} (p-val: {round(spearman_stats.pvalue, 4)})"
        )
        print(
            f"Mann-Whitney for {drug_name}: {round(mw_stats.statistic/denominator, 4)} (p-val: {round(mw_stats.pvalue, 4)})"
        )
    except Exception as e:
        # Best-effort per-drug report. A drug with only one class in the test split,
        # for example, makes the ranking metrics raise; report it and move on.
        print(f"Error processing {drug_name} - {e}")

