In [None]:
import sys
# Clone the TCRP code from https://github.com/idekerlab/TCRP into ../benchmarks/TCRP/ first; its code/ directory is added to sys.path below.
sys.path.append("../benchmarks/TCRP/code/")
sys.path.append("../src/")

In [None]:
import time
import argparse
import numpy as np
import random
import torch
import torch.nn.functional as F
import torch.optim as optim
import os
import glob
from torch.autograd import Variable
import sys
import torch.nn as nn
import pickle
import copy
from data_loading import *
from tcrp_utils import *
from score import *
from inner_loop import InnerLoop
from mlp import mlp
from meta_learner_cv import *

In [None]:
import numpy as np
import pandas as pd

import datetime
import logging
import os
import time
import matplotlib.pyplot as plt

from torch import nn
from torch.nn import functional as F

from functools import cached_property
from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score

from datasets_drug_filtered import (
    CellLineDataset,
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
    AggCategoricalAnnotatedRad51DatasetFilteredByDrug,
    TcgaDataset,
)

from utils import get_kld_loss, get_zinb_loss

from model import (
    BaseDruidModel,
    
)

from seaborn import scatterplot, boxplot

from sklearn.metrics import pairwise_distances

# cld = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(is_train=False,
#     filter_for="tcga")
# r51d = AggCategoricalAnnotatedRad51DatasetFilteredByDrug(is_train=False,
#     filter_for="rad51")
# td = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(
#     is_train=False,
#     filter_for="tcga"
#                                        )

In [None]:
# Sample index forwarded to every dataset constructor and embedded in the
# checkpoint directory name (presumably selects a pre-generated split).
sample_id: int = 0

In [None]:
# Primary disease of every DepMap cell line, indexed by depmap_id; used to
# group training cell lines into tissue tasks for TCRP.
cell_line_info = (
    pd.read_csv("../data/raw/sample_info.csv")
    .rename(columns={"DepMap_ID": "depmap_id"})
    .set_index("depmap_id")[["primary_disease"]]
)
print(cell_line_info["primary_disease"].unique().shape)
cell_line_info

In [None]:
# Meta-learning settings, passed positionally to MetaLearner inside
# train_and_persist_model.
K = 1
num_trials = 50  # independent MetaLearner runs per drug
meta_batch_size = 10
inner_batch_size = 10
num_updates = 10
num_inner_updates = 1

# Seed every RNG the training code may draw from.
torch.manual_seed(19)
np.random.seed(19)
random.seed(19)

# Values the upstream TCRP script reads from args.layer, args.hidden,
# args.meta_lr, args.inner_lr and args.tissue_num.
layer = 1
hidden = 20
meta_lr = 0.001
inner_lr = 0.001
tissue_num = 12

# Checkpoint directory for this sample. The root defaults to the original
# cluster location and can be redirected with the TCRP_CHECKPOINT_ROOT
# environment variable; the per-sample sub-directory name is unchanged.
MODEL_STATE_PATH = (
    os.environ.get(
        "TCRP_CHECKPOINT_ROOT",
        "/data/ajayago/druid/paper_intermediate/model_checkpoints/benchmarks",
    ).rstrip("/")
    + f"/TCRP_with_tcga_raw_mutations_filtered_drugs_sample{sample_id}/"
)

In [None]:
def _order_rows_by_tissue(feature_df, tissue_of_row):
    """Reorder ``feature_df`` so that rows of the same tissue are contiguous.

    Args:
        feature_df: Feature rows, one per training cell line.
        tissue_of_row: Series aligned positionally (row-for-row) with
            ``feature_df`` holding each row's primary disease.

    Returns:
        ``(ordered_df, tissue_index_list)``. Tissues appear in
        ``value_counts`` order (largest first). ``tissue_index_list[k]`` holds
        the positional row indices of the k-th tissue block in ``ordered_df``.
        Both are produced in the same pass, so the index lists always describe
        the real row order. Rows whose primary disease is missing are dropped
        (``value_counts`` skips NaN).
    """
    tissue_values = tissue_of_row.to_numpy()
    blocks, tissue_index_list = [], []
    start = 0
    for disease_type in tissue_of_row.value_counts().index:
        block = feature_df[tissue_values == disease_type]
        blocks.append(block)
        tissue_index_list.append(list(range(start, start + len(block))))
        start += len(block)
    if not blocks:
        return feature_df.iloc[0:0], []
    return pd.concat(blocks), tissue_index_list


def train_and_persist_model(drug_name):
    """Meta-train TCRP for one drug on cell lines and save a checkpoint.

    Training cell lines are grouped by primary disease, and the groups are
    handed to MetaLearner as ``tissue_index_list``. ``num_trials`` independent
    MetaLearner runs are performed and their mean correlations are printed.
    The ``best_model`` returned by the final trial is saved to
    ``MODEL_STATE_PATH/<drug_name with '/' replaced by '_'>.pth``.

    Args:
        drug_name: Drug whose cell-line responses are the target; the
            regression target is ``1 - auc``.
    """
    train_cell_line_dataset = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(
        is_train=True, filter_for="tcga", sample_id=sample_id
    )
    train_label_df = train_cell_line_dataset.y_df[
        train_cell_line_dataset.y_df.drug_name == drug_name
    ].set_index("depmap_id")
    print(train_label_df.shape)

    train_feature_df = train_cell_line_dataset.raw_mutations.loc[
        train_label_df.index.get_level_values(0)
    ]
    filtered_cell_line_info = cell_line_info.loc[
        train_feature_df.index.get_level_values(0)
    ]
    print(filtered_cell_line_info.shape)

    # Rows and tissue index lists must follow the same tissue ordering, otherwise
    # MetaLearner would sample "tissues" that mix unrelated cell lines.
    ordered_train_feature_df, tissue_index_list = _order_rows_by_tissue(
        train_feature_df, filtered_cell_line_info.primary_disease
    )
    print(ordered_train_feature_df.shape)
    print(sum(len(indices) for indices in tissue_index_list))
    ordered_train_label_df = train_label_df.loc[
        ordered_train_feature_df.index.get_level_values(0)
    ]
    print(ordered_train_label_df.shape)

    test_cell_line_dataset = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(
        is_train=False, filter_for="tcga", sample_id=sample_id
    )
    test_label_df = test_cell_line_dataset.y_df[
        test_cell_line_dataset.y_df.drug_name == drug_name
    ].set_index("depmap_id")
    print(test_label_df.shape)

    test_feature_df = test_cell_line_dataset.raw_mutations.loc[
        test_label_df.index.get_level_values(0)
    ].copy()
    print(test_feature_df.shape)

    # Targets are 1 - auc, shaped as column vectors.
    train_dataset = dataset(
        ordered_train_feature_df.to_numpy(),
        (1 - ordered_train_label_df.auc).values.reshape(-1, 1),
    )
    test_dataset = dataset(
        test_feature_df.to_numpy(), (1 - test_label_df.auc).values.reshape(-1, 1)
    )

    best_train_loss_test_corr_list = []
    best_train_corr_test_corr_list = []
    best_train_corr_test_scorr_list = []
    best_train_scorr_test_scorr_list = []

    for _ in range(num_trials):
        meta_learner = MetaLearner(
            train_dataset,
            test_dataset,
            K,
            meta_lr,
            inner_lr,
            layer,
            hidden,
            tissue_num,
            meta_batch_size,
            inner_batch_size,
            num_updates,
            num_inner_updates,
            tissue_index_list,
            num_trials,
        )
        (
            best_train_loss_test_corr,
            best_train_corr_test_corr,
            best_train_corr_test_scorr,
            best_train_scorr_test_scorr,
            best_model,
        ) = meta_learner.train()
        best_train_loss_test_corr_list.append(best_train_loss_test_corr)
        best_train_corr_test_corr_list.append(best_train_corr_test_corr)
        best_train_corr_test_scorr_list.append(best_train_corr_test_scorr)
        best_train_scorr_test_scorr_list.append(best_train_scorr_test_scorr)

    a = np.asarray(best_train_loss_test_corr_list).mean()
    b = np.asarray(best_train_corr_test_corr_list).mean()
    c = np.asarray(best_train_corr_test_scorr_list).mean()
    d = np.asarray(best_train_scorr_test_scorr_list).mean()

    print(
        f"{drug_name} best_train_loss_test_corr:",
        float("%.3f" % a),
        "best_train_corr_test_corr",
        float("%.3f" % b),
        "best_train_corr_test_scorr",
        float("%.3f" % c),
        "best_train_scorr_test_scorr",
        float("%.3f" % d),
    )

    # Only the final trial's best_model is persisted; earlier trials contribute
    # to the averaged metrics above but are not saved.
    os.makedirs(MODEL_STATE_PATH, exist_ok=True)
    torch.save(
        best_model.state_dict(), f"{MODEL_STATE_PATH}/{drug_name.replace('/', '_')}.pth"
    )

In [None]:
# Drugs that already have a checkpoint: the stem of every "<drug>.pth" file in
# MODEL_STATE_PATH. Other entries (e.g. the predictions/ directory written by
# TcrpModel.forward) are ignored. A missing directory means nothing has been
# trained yet, which is the case on a fresh run.
trained_drugs = (
    sorted(
        file_name[: -len(".pth")]
        for file_name in os.listdir(MODEL_STATE_PATH)
        if file_name.endswith(".pth")
    )
    if os.path.isdir(MODEL_STATE_PATH)
    else []
)
len(trained_drugs)

In [None]:
# Training split of the cell-line dataset, loaded here to count its drugs.
ccle_dataset = AggCategoricalAnnotatedCellLineDatasetFilteredByDrug(
    is_train=True,
    filter_for="tcga",
    sample_id=sample_id,
)
ccle_dataset.y_df["drug_name"].nunique(dropna=False)

In [None]:
# Meta-train and checkpoint one TCRP model per drug of interest.
for drug_name in (
    "CISPLATIN",
    "PACLITAXEL",
    "5-FLUOROURACIL",
    "CYCLOPHOSPHAMIDE",
    "DOCETAXEL",
    "GEMCITABINE",
):
    train_and_persist_model(drug_name)

In [None]:
# Refresh the list of drugs with a checkpoint: the stem of every "<drug>.pth"
# file in MODEL_STATE_PATH. Non-checkpoint entries (e.g. the predictions/
# directory written by TcrpModel.forward) are ignored, and a missing directory
# yields an empty list.
trained_drugs = (
    sorted(
        file_name[: -len(".pth")]
        for file_name in os.listdir(MODEL_STATE_PATH)
        if file_name.endswith(".pth")
    )
    if os.path.isdir(MODEL_STATE_PATH)
    else []
)
len(trained_drugs)

In [None]:
# %%notify
# import dask

# from tqdm import tqdm
# from dask.distributed import Client


# client = Client()
# client.cluster.scale(10)


# futures = []
# # for drug_name in drug_list:
# for drug_name in ["CISPLATIN", "PACLITAXEL", "GEMCITABINE", "DOXORUBICIN", "OLAPARIB"]:
#     if drug_name.replace("/", "_") not in trained_drugs:
#         future = client.submit(train_and_persist_model, drug_name)
#         futures.append(future)

# results = client.gather(futures, errors="skip")
# client.shutdown()

In [None]:
import numpy as np
import pandas as pd

import datetime
import logging
import os
import time
# import torch
import random

from torch import nn
from torch.nn import functional as F

from functools import cached_property
from itertools import cycle

from torch.nn import Linear, ReLU, Sequential
from torch.utils.data import DataLoader, TensorDataset

from sklearn.metrics import average_precision_score, ndcg_score, roc_auc_score

from datasets_drug_filtered import (
    AggCategoricalAnnotatedCellLineDatasetFilteredByDrug,
    AggCategoricalAnnotatedTcgaDatasetFilteredByDrug,
    AggCategoricalAnnotatedRad51DatasetFilteredByDrug,
    CellLineDataset,
    TcgaDataset,
    Rad51Dataset,
    RANDOM_STATE,
)
from metric import NdcgMetric

# from utils import get_kld_loss, get_zinb_loss
from testbed import EvaluationTestbed
from model import BaseDruidModel

In [None]:
# for drug_name in drug_list:
#     if not os.path.exists(f"{MODEL_STATE_PATH}/{drug_name.replace('/', '_')}.pth"):
#         print(drug_name)

In [None]:
# Held-out TCGA patients (is_train=False): response labels in tcga_response and
# features in raw_mutations.
tcga_dataset = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(
    sample_id=sample_id,
    filter_for="tcga",
    is_train=False,
)
tcga_dataset

In [None]:
# Number of TCGA response records per drug.
tcga_dataset.tcga_response["drug_name"].value_counts()

In [None]:
import shutil

learning_rate = 0.001

# Few-shot adaptation to TCGA. For each listed drug, take up to two patients
# with response == 1 and two with response == 0, then apply ONE Adam step with
# a BCE-with-logits loss on top of the meta-trained weights. Every listed drug
# is processed.
#
# The meta-trained weights are copied to MODEL_STATE_PATH/pretrained/ the first
# time a drug is seen, and fine-tuning always starts from that copy. Re-running
# this cell therefore does not compound updates. The fine-tuned weights are
# written to MODEL_STATE_PATH/<drug>.pth, which is the file TcrpModel loads.
# NOTE(review): if an earlier run already overwrote <drug>.pth in place, the
# backup will hold fine-tuned weights; delete it and re-train in that case.
# NOTE(review): these few-shot patients come from the same TCGA split that is
# evaluated below and are not excluded there; consider removing them.
pretrained_dir = f"{MODEL_STATE_PATH}/pretrained"

# for drug_name in ["CISPLATIN", "PACLITAXEL", "GEMCITABINE", "DOXORUBICIN", "OLAPARIB"]:
for drug_name in [
    "CISPLATIN",
    "PACLITAXEL",
    "5-FLUOROURACIL",
    "CYCLOPHOSPHAMIDE",
    "DOCETAXEL",
    "GEMCITABINE",
]:
    file_stem = drug_name.replace("/", "_")
    checkpoint_path = f"{MODEL_STATE_PATH}/{file_stem}.pth"
    pretrained_path = f"{pretrained_dir}/{file_stem}.pth"

    if not os.path.exists(pretrained_path):
        if not os.path.exists(checkpoint_path):
            print(f"Model for {drug_name} not found")
            continue
        # First time this drug is fine-tuned: keep the meta-trained weights.
        os.makedirs(pretrained_dir, exist_ok=True)
        shutil.copyfile(checkpoint_path, pretrained_path)

    filtered_df = tcga_dataset.tcga_response[
        tcga_dataset.tcga_response.drug_name == drug_name
    ].copy()
    patients_with_pos_response = list(
        filtered_df[filtered_df.response == 1].submitter_id.values[:2]
    )
    patients_with_neg_response = list(
        filtered_df[filtered_df.response == 0].submitter_id.values[:2]
    )
    if (len(patients_with_pos_response) == 0) or (len(patients_with_neg_response) == 0):
        print(f"Skipping training for {drug_name} - {patients_with_pos_response}{patients_with_neg_response}")
        continue

    curr_model = mlp(324, 1, 20)
    curr_model.load_state_dict(
        torch.load(pretrained_path, map_location=torch.device("cpu"))
    )
    curr_model.train()  # nn.Module default, made explicit

    required_patients = patients_with_neg_response + patients_with_pos_response
    few_shot_df = filtered_df[filtered_df.submitter_id.isin(required_patients)]
    train_tcga_x = torch.tensor(
        tcga_dataset.raw_mutations.loc[few_shot_df.submitter_id.values].to_numpy(),
        device="cpu",
        dtype=torch.float,
    )
    train_tcga_y = torch.tensor(
        few_shot_df.response.to_numpy(), device="cpu", dtype=torch.float
    )

    fs_criterion = torch.nn.BCEWithLogitsLoss()
    fs_optim = torch.optim.Adam(curr_model.parameters(), lr=learning_rate)
    fs_optim.zero_grad()
    # The model returns a pair; only the first element (the prediction) is used.
    pred_train_tcga_y, _ = curr_model(train_tcga_x)
    fs_loss = fs_criterion(pred_train_tcga_y.flatten(), train_tcga_y)
    fs_loss.backward()
    fs_optim.step()

    torch.save(curr_model.state_dict(), checkpoint_path)

In [None]:
# Display the checkpoint directory used for this sample.
MODEL_STATE_PATH

In [None]:
# patient_dataset = AggCategoricalAnnotatedCellLineDataset(
#     is_train=None,
#     only_cat_one_drugs=False,
#     scale_y=False,
#     use_k_best_worst=None,
# )
# patient_dataset

In [None]:
# patient_dataset.y_df.drug_name.unique().shape

In [None]:
# Checkpoint stems that TcrpModel.__init__ will try to load.
trained_drugs

In [None]:
class TcrpModel(BaseDruidModel):
    """One meta-trained TCRP ``mlp`` per drug behind the BaseDruidModel interface.

    A model is loaded from ``MODEL_STATE_PATH/<name>.pth`` for every name in the
    notebook-level ``trained_drugs`` list. Those names are checkpoint file
    stems, i.e. drug names with '/' replaced by '_'. ``forward`` scores each
    (entity, drug) row of a dataset with the matching drug's model, or NaN when
    that drug has no model.
    """

    @cached_property
    def device(self):
        """Inference device; all models and inputs live on the CPU."""
        return torch.device("cpu")

    def __init__(self,):
        super(TcrpModel, self).__init__()
        # Plain dict (not nn.ModuleDict): the per-drug nets are inference-only
        # and are not registered as sub-modules of this wrapper.
        self.drug_name_to_model_map = {}
        for drug_name in trained_drugs:
            checkpoint_path = f"{MODEL_STATE_PATH}/{drug_name.replace('/', '_')}.pth"
            if not os.path.exists(checkpoint_path):
                print(f"Model for {drug_name} not found")
                continue
            # Must match the training architecture. 324 is presumably the number
            # of raw-mutation features and (1, 20) the layer/hidden settings — confirm.
            drug_model = mlp(324, 1, 20)
            drug_model.load_state_dict(
                torch.load(checkpoint_path, map_location=str(self.device))
            )
            # Inference mode so dropout / batch-norm (if mlp has any) are deterministic.
            drug_model.eval()
            self.drug_name_to_model_map[drug_name] = drug_model

    def __str__(self):
        return "TCRP model"

    def _lookup_model(self, drug_name):
        """Return the model for ``drug_name``, or None if there is none.

        Map keys are checkpoint file stems, where '/' was saved as '_'. A drug
        name containing '/' is therefore also looked up in its sanitized form.
        """
        if drug_name in self.drug_name_to_model_map:
            return self.drug_name_to_model_map[drug_name]
        return self.drug_name_to_model_map.get(str(drug_name).replace("/", "_"))

    def forward(self, dataset):
        """Predict one value per dataset row; NaN where the drug has no model.

        Results are cached in ``MODEL_STATE_PATH/predictions/<DatasetClass>.csv``.
        A cached file is reused only when its (entity, drug_name) columns match
        the dataset row-for-row, so the returned array is aligned with the
        dataset. Otherwise predictions are recomputed and the cache rewritten.
        NOTE: the cache key is only the dataset class; delete the file after
        re-training or fine-tuning the per-drug models.

        Returns:
            1-D float numpy array of length ``len(dataset)``.
        """
        prediction_dir = f"{MODEL_STATE_PATH}/predictions"
        prediction_file_name = f"{prediction_dir}/{dataset.__class__.__name__}.csv"
        column_name = (
            "pred_auc" if isinstance(dataset, CellLineDataset) else "pred_response"
        )
        x_df = pd.concat(list(dataset[: len(dataset)].values()), axis=1)
        key_columns = [dataset.entity_identifier_name, "drug_name"]

        if os.path.exists(prediction_file_name):
            print(f"{prediction_file_name} exists")
            cached = pd.read_csv(prediction_file_name)
            if set(key_columns + [column_name]) <= set(cached.columns) and np.array_equal(
                cached[key_columns].astype(str).to_numpy(),
                x_df[key_columns].astype(str).to_numpy(),
            ):
                return cached[column_name].to_numpy()
            print(f"{prediction_file_name} does not match the dataset rows; recomputing")

        results = []
        with torch.no_grad():
            for _, row in x_df.iterrows():
                drug_model = self._lookup_model(row["drug_name"])
                if drug_model is None:
                    results.append(np.nan)
                    continue
                features = torch.tensor(
                    dataset.raw_mutations.loc[[row[dataset.entity_identifier_name]]].to_numpy(),
                    device=self.device,
                    dtype=torch.float,
                )
                # The model returns a pair; element 0 is the prediction.
                results.append(drug_model(features)[0].item())

        predictions = np.array(results, dtype=float).flatten()

        y_pred = x_df.copy()
        y_pred[column_name] = predictions
        os.makedirs(prediction_dir, exist_ok=True)
        print(f"Saving DF of shape {y_pred.shape} to {prediction_file_name}")
        y_pred.to_csv(prediction_file_name)

        return predictions

    def postprocess(self, dataset, np_out):
        """Convert ``forward`` output into prediction / ground-truth frames.

        - TcgaDataset / Rad51Dataset: returns ``(y_pred, y_true)``. ``y_pred``
          has columns [entity id, drug_name, response], and ``response`` is the
          raw model output.
        - CellLineDataset: returns entity x drug pivot tables ``(pred, true)``
          of ``auc``. Predictions are mapped back with ``1 - output``.
        - Any other dataset: returns only the predicted pivot table.
        """
        assert len(np_out) == len(dataset)

        y_true = pd.concat(list(dataset[: len(dataset)].values()), axis=1)

        y_pred = y_true.copy()

        if isinstance(
            dataset,
            (
                TcgaDataset,
                Rad51Dataset,
            ),
        ):
            y_pred["response"] = np_out.squeeze()
            y_pred = y_pred[
                [dataset.entity_identifier_name, "drug_name", "response"]
            ].copy()
            return y_pred, y_true

        y_pred["auc"] = np_out.squeeze()

        if isinstance(dataset, (CellLineDataset)):

            y_true_pivotted = y_true.pivot_table(
                "auc", dataset.entity_identifier_name, "drug_name"
            )
            # NOTE(review): a missing prediction becomes auc = 1 - 1 = 0 here,
            # i.e. it looks maximally sensitive — confirm this is intended.
            y_pred["auc"] = 1 - y_pred.auc.fillna(1)
            y_pred_pivotted = y_pred.pivot_table(
                "auc", dataset.entity_identifier_name, "drug_name"
            )
            print(y_pred_pivotted.shape)
            return y_pred_pivotted, y_true_pivotted

        return y_pred.pivot_table("auc", dataset.entity_identifier_name, "drug_name")


model = TcrpModel()
model

In [None]:
# # All datasets
# res = EvaluationTestbed.run(
#     {
#         model: [
#             AggCategoricalAnnotatedCellLineDataset(
#                 is_train=False,
#                 only_cat_one_drugs=False,
#                 scale_y=False,
#                 use_k_best_worst=None,
#             ),
#             AggCategoricalAnnotatedPdxDataset(
#                 apply_train_test_filter=False,
#                 is_train=False,
#                 only_cat_one_drugs=False,
#                 include_all_cell_line_drugs=True,
#             ),
#             AggCategoricalAnnotatedTcgaDataset(
#                 apply_train_test_filter=False,
#                 is_train=False,
#                 only_cat_one_drugs=False,
#                 include_all_cell_line_drugs=True,
#             ),
#             AggCategoricalAnnotatedMooresDataset(
#                 is_train=False, only_cat_one_drugs=False, include_all_cell_line_drugs=True
#             ),
#             AggCategoricalAnnotatedRad51Dataset(
#                 is_train=False, only_cat_one_drugs=False, include_all_cell_line_drugs=True
#             ),
#         ],
#     },
# )
# pd.set_option("display.max_rows", 100)
# res_df = pd.DataFrame(res)
# res_df.set_index(["model", "dataset", "metric"], inplace=True)
# res_df

## Drug Specific Analysis

In [None]:
from scipy import stats
from numpy import argmax
from sklearn.metrics import roc_curve

### TCGA

In [None]:
# Evaluate the models loaded by TcrpModel on held-out TCGA patients. Per-drug
# metrics are reported for drugs_with_enough_support, then pooled metrics for
# CISPLATIN, PACLITAXEL and 5-FU.
# NOTE(review): the few-shot fine-tuning cell draws up to four patients per drug
# from this same TCGA split; they are not excluded from the metrics below.
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    ndcg_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

patient_dataset = AggCategoricalAnnotatedTcgaDatasetFilteredByDrug(
    is_train=False, filter_for="tcga", sample_id=sample_id
)
patient_results = model(patient_dataset)
patient_pp_out = model.postprocess(patient_dataset, patient_results)
y_pred, y_true = patient_pp_out

# Top-10 drugs per patient by predicted score (missing predictions scored as 0).
y_pred_pivotted = y_pred.pivot_table(
    "response", "submitter_id", "drug_name"
).fillna(0)
dict_idx_drug = dict(enumerate(y_pred_pivotted.columns))
dict_id_drug = {}
for patient_id, predictions in y_pred_pivotted.iterrows():
    cur_pred_scores = predictions.values
    cur_recom_drug_idx = np.argsort(cur_pred_scores)[:-11:-1]
    dict_id_drug[patient_id] = {
        dict_idx_drug[cur_idx]: f"{cur_pred_scores[cur_idx]} ({rank})"
        for rank, cur_idx in enumerate(cur_recom_drug_idx, start=1)
    }
predictions_display_tcga = pd.DataFrame.from_dict(dict_id_drug)

# Drop rows with a missing prediction or a missing label. y_pred and y_true stay
# row-aligned because both are filtered with the same masks.
na_mask = y_pred.response.isna()
if na_mask.sum():
    print(
        f"[KaplanMeierFitterMetric] Found {na_mask.sum()} rows with invalid response values"
    )
    y_pred = y_pred[~na_mask]
    y_true = y_true.loc[~(na_mask.values)]
na_mask = y_true.response.isna()
y_true = y_true[~na_mask]
y_pred = y_pred[~na_mask]
print(y_pred.shape)
# response_x = predicted score, response_y = observed label.
y_combined = y_pred.merge(y_true, on=["submitter_id", "drug_name"])


def _binary_response_metrics(true_sub, pred_sub, combined_sub):
    """Compute ranking and threshold metrics for one slice of TCGA responses.

    Args:
        true_sub: Rows of ``y_true`` for the slice (``response`` = label).
        pred_sub: The row-aligned rows of ``y_pred`` (``response`` = score).
        combined_sub: Rows of ``y_combined`` for the same slice, used for the
            Mann-Whitney test.

    Returns:
        Dict with AUROC, AUPR, F1 / accuracy / precision / recall at the
        Youden-optimal threshold, the Spearman result, the Mann-Whitney result
        and its normaliser (n_label0 * n_label1).

    Errors from sklearn / scipy propagate, e.g. when a slice has a single class.
    """
    key_columns = ["submitter_id", "drug_name"]
    assert np.array_equal(
        true_sub[key_columns].to_numpy(), pred_sub[key_columns].to_numpy()
    ), "y_true and y_pred are not row-aligned"
    labels = true_sub.response.values
    scores = pred_sub.response.values

    roc = roc_auc_score(labels, scores, average="micro")
    aupr = average_precision_score(labels, scores, average="micro")
    # Threshold chosen by Youden's J, ref: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
    fpr, tpr, thresholds = roc_curve(labels, scores)
    best_thresh = thresholds[argmax(tpr - fpr)]
    # roc_curve counts score >= threshold as positive; use the same rule so the
    # hard labels reproduce the operating point that maximised J.
    hard_labels = (scores >= best_thresh).astype(int)

    label0_scores = combined_sub[combined_sub.response_y == 0].response_x.values
    label1_scores = combined_sub[combined_sub.response_y == 1].response_x.values
    # NOTE(review): x = label-0 scores with alternative="greater" tests whether
    # label-0 patients score HIGHER, whereas AUROC above treats higher scores
    # as label 1. The normalised U printed below is then about 1 - AUROC;
    # confirm the intended direction for this model.
    mw_stats = stats.mannwhitneyu(label0_scores, label1_scores, alternative="greater")

    return {
        "AUROC": roc,
        "AUPR": aupr,
        "F1": f1_score(labels, hard_labels),
        "Accuracy Score": accuracy_score(labels, hard_labels),
        "Precision Score": precision_score(labels, hard_labels),
        "Recall Score": recall_score(labels, hard_labels),
        "spearman": stats.spearmanr(labels, scores),
        "mann_whitney": mw_stats,
        "mw_denominator": len(label0_scores) * len(label1_scores),
    }


def _print_binary_response_metrics(metrics, label):
    """Print ``metrics``; ``label(name)`` renders the prefix of each line."""
    for name in ("AUROC", "AUPR", "F1", "Accuracy Score", "Precision Score", "Recall Score"):
        print(f"{label(name)}: {metrics[name]}")
    spearman_stats = metrics["spearman"]
    print(
        f"{label('Spearman')}: {round(spearman_stats.correlation, 4)} (p-val: {round(spearman_stats.pvalue, 4)})"
    )
    mw_stats = metrics["mann_whitney"]
    print(
        f"{label('Mann-Whitney')}: {round(mw_stats.statistic / metrics['mw_denominator'], 4)} (p-val: {round(mw_stats.pvalue, 4)})"
    )


drugs_with_enough_support = [
    "CISPLATIN",
    "PACLITAXEL",
    "5-FLUOROURACIL",
    "CYCLOPHOSPHAMIDE",
    "DOCETAXEL",
    "GEMCITABINE",
]

# for drug_name in y_true.drug_name.unique():
for drug_name in drugs_with_enough_support:
    try:
        drug_metrics = _binary_response_metrics(
            y_true[y_true.drug_name == drug_name],
            y_pred[y_pred.drug_name == drug_name],
            y_combined[y_combined.drug_name == drug_name],
        )
        _print_binary_response_metrics(
            drug_metrics, lambda name: f"{name} for {drug_name}"
        )
    except Exception as e:
        # Best effort: a drug with too few or one-class samples must not stop the report.
        print(f"Error processing {drug_name} - {e}")


drugs_with_enough_support2 = ["CISPLATIN", "PACLITAXEL", "5-FLUOROURACIL"]

print("For CISPLATIN, PACLITAXEL and 5-FU")
overall_metrics = _binary_response_metrics(
    y_true[y_true.drug_name.isin(drugs_with_enough_support2)],
    y_pred[y_pred.drug_name.isin(drugs_with_enough_support2)],
    y_combined[y_combined.drug_name.isin(drugs_with_enough_support2)],
)
_print_binary_response_metrics(overall_metrics, lambda name: f"Overall {name}")


