In [72]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from torch.utils.data import Dataset, DataLoader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES
)

# Seed PyTorch's random number generator with a fixed value.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)

# Use seaborn's "white" style for the plots of this notebook.
sns.set_theme(style="white")

---

# Experiments on the `GraphTab` approach

In this notebook we are going to experiment with the approach of 
- replacing the cell-line branch by a GNN and 
- having the drug branch using tabular input.

In [17]:
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.

# Cell-line gene graphs (looked up by cell-line name further down).
cl_graphs_file = f'{PATH_TO_FEATURES}cl_graphs_as_dict.pkl'
with open(cl_graphs_file, mode='rb') as graphs_handle:
    cl_graphs = pickle.load(graphs_handle)

# Drug response matrix.
drug_response_file = f'{PATH_TO_FEATURES}drugs_sparse_gdsc2.pkl'
with open(drug_response_file, mode='rb') as response_handle:
    drug_cl = pickle.load(response_handle)

In [18]:
# Report how many cell-line graphs were loaded and show one example entry.
n_cell_line_graphs = len(cl_graphs.keys())
print(f"Number of cell-lines/graphs: {n_cell_line_graphs}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 983
Data(x=[4, 858], edge_index=[2, 83126])


In [21]:
# Size of the drug response matrix, followed by the number of distinct drug ids and drug names.
print(f"Shape: {drug_cl.shape}")

n_drug_ids = len(np.unique(drug_cl['DRUG_ID'].values))
print(f"Number of different drug id's   : {n_drug_ids}")

n_drug_names = len(np.unique(drug_cl['DRUG_NAME'].values))
print(f"Number of different drug name's : {n_drug_names}")

drug_cl.head(10)

Shape: (135242, 5)
Number of different drug id's   : 198
Number of different drug name's : 192


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3489119,22RV1,1005,Cisplatin,GDSC2,3.622285
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3533420,22RV1,1007,Docetaxel,GDSC2,-3.835431
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435
3589233,22RV1,1012,Vorinostat,GDSC2,0.846758
3606526,22RV1,1013,Nilotinib,GDSC2,3.93594
3627474,22RV1,1017,Olaparib,GDSC2,5.238895


In [32]:
# For every DRUG_ID, count the distinct DRUG_NAMEs and list the largest counts first.
names_per_drug_id = (
    drug_cl[['DRUG_ID', 'DRUG_NAME']]
    .groupby(['DRUG_ID'])
    .nunique()
)
names_per_drug_id.sort_values(['DRUG_NAME'], ascending=False).head(10)

Unnamed: 0_level_0,DRUG_NAME
DRUG_ID,Unnamed: 1_level_1
1003,1
1739,1
1799,1
1802,1
1804,1
1806,1
1807,1
1808,1
1809,1
1810,1


In [30]:
# For every DRUG_NAME, count the distinct DRUG_IDs and list the largest counts first.
# TODO: only take one DRUG_NAME, such that the number of DRUG_NAMEs equals the number of
#       DRUG_IDs and there is a 1-to-1 mapping.
ids_per_drug_name = (
    drug_cl[['DRUG_ID', 'DRUG_NAME']]
    .groupby(['DRUG_NAME'])
    .nunique()
)
ids_per_drug_name.sort_values(['DRUG_ID'], ascending=False).head(10)

Unnamed: 0_level_0,DRUG_ID
DRUG_NAME,Unnamed: 1_level_1
Dactinomycin,2
Fulvestrant,2
Oxaliplatin,2
Docetaxel,2
Ulixertinib,2
Uprosertib,2
PD173074,1
Olaparib,1
Osimertinib,1
P22077,1


- NOTE: We found that 6 `DRUG_NAME`s each occur with two different `DRUG_ID`s. 

The cell-line gene dataset is basically ready to go. It "only" needs to be transformed into a PyTorch `Data` class. The graph of each cell-line will be used as input to the GNN cell-line branch of the bi-modal model. The drug dataset, however, is so far only the drug response matrix; it does not contain the drug features. In the following subsection we will obtain the [SMILES fingerprints](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-00445-4) for each unique `DRUG_ID`. These will later be used as the input to the drug branch of the bi-modal model.

### Transform drugs to SMILES fingerprints

In [52]:
# Load the fingerprint table; it carries a `drug_name` column next to the fingerprint columns.
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
fingerprint_file = f'{PATH_TO_SAVED_DRUG_FEATURES}drug_name_fingerprints_dataframe.pkl'
with open(fingerprint_file, mode='rb') as fingerprint_handle:
    drug_name_smiles = pickle.load(fingerprint_handle)

# TODO: Note that there are still some DRUG_NAMEs which have more than one DRUG_ID.
print(drug_name_smiles.shape)
drug_name_smiles.head()

(367, 257)


Unnamed: 0,drug_name,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
0,(5Z)-7-Oxozeaenol,1,0,0,1,1,0,0,0,0,...,0,0,0,0,0,1,1,0,1,0
1,5-Fluorouracil,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2,A-443654,0,1,0,0,0,0,0,0,0,...,0,1,0,0,0,0,0,0,0,1
3,A-770041,1,1,0,0,0,1,0,0,0,...,0,0,0,1,0,1,0,0,0,0
4,A-83-01,0,0,0,1,1,0,0,0,0,...,0,1,0,0,0,0,0,0,0,0


In [57]:
# Count the distinct DRUG_IDs per DRUG_NAME and keep only the names with more than one id.
all_non_uniqs = (
    drug_cl[['DRUG_ID', 'DRUG_NAME']]
    .groupby(['DRUG_NAME'])
    .nunique()
    .sort_values(['DRUG_ID'], ascending=False)
)
all_non_uniqs = all_non_uniqs[all_non_uniqs.DRUG_ID > 1].reset_index()

# Build the lookup once instead of rebuilding a list on every loop iteration.
non_unique_names = set(all_non_uniqs.DRUG_NAME.values)

# `smiles` collects the drug names (not SMILES strings) of the fingerprint table that map to
# more than one DRUG_ID; later cells use it to filter both tables.
smiles = []
for drug_name in drug_name_smiles.drug_name.values:
    if drug_name in non_unique_names:
        smiles.append(drug_name)
        print(f"{drug_name} is there and has not 1-to-1 mapping")

Docetaxel is there and has not 1-to-1 mapping
Ulixertinib is there and has not 1-to-1 mapping
Uprosertib is there and has not 1-to-1 mapping


In [43]:
# Show every response row of Docetaxel, one of the drug names reported above as not 1-to-1.
is_docetaxel = drug_cl['DRUG_NAME'] == 'Docetaxel'
drug_cl[is_docetaxel]

Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3533420,22RV1,1007,Docetaxel,GDSC2,-3.835431
5096727,22RV1,1819,Docetaxel,GDSC2,-2.048442
3532924,23132-87,1007,Docetaxel,GDSC2,-5.663205
5096517,23132-87,1819,Docetaxel,GDSC2,-4.758792
3518116,42-MG-BA,1007,Docetaxel,GDSC2,-5.157606
...,...,...,...,...,...
3534191,YT,1007,Docetaxel,GDSC2,-4.989452
3515605,ZR-75-30,1007,Docetaxel,GDSC2,-1.275373
5096160,ZR-75-30,1819,Docetaxel,GDSC2,3.934073
3537528,huH-1,1007,Docetaxel,GDSC2,-2.747687


We will remove the drug names found above (those that map to more than one `DRUG_ID`) for now. However, this needs to be tackled properly in the future.

In [47]:
# Drop every response row whose drug name was collected in `smiles`.
has_ambiguous_name = drug_cl['DRUG_NAME'].isin(smiles)
drug_cl_v2 = drug_cl[~has_ambiguous_name]
print(drug_cl_v2.shape)

(130826, 5)


In [59]:
# Drop every fingerprint row whose drug name was collected in `smiles`.
is_ambiguous_fingerprint = drug_name_smiles['drug_name'].isin(smiles)
drug_name_smiles_v2 = drug_name_smiles[~is_ambiguous_fingerprint]
print(drug_name_smiles_v2.shape)

(364, 257)


These are now the drug response matrix and the SMILES matrix without the non-1-to-1 mapped `DRUG_NAME` and `DRUG_ID`'s.

In [69]:
# Attach the DRUG_ID to every fingerprint row by joining on the drug name.
#
# The steps are chained and each returns a new frame, so no column is ever assigned onto a
# filtered slice (the pattern that raises pandas' SettingWithCopyWarning).
#
# NOTE(review): the right-hand side holds one row per response measurement, so the merge fans
# out before `drop_duplicates` collapses it again. De-duplicating the (DRUG_ID, DRUG_NAME)
# pairs first would be cheaper, but it would change the row labels of the result.
drug_name_smiles_v3 = (
    drug_name_smiles_v2
    .merge(drug_cl_v2[['DRUG_ID', 'DRUG_NAME']],
           how='left',
           left_on=['drug_name'],
           right_on=['DRUG_NAME'])
    .drop_duplicates()
    .drop(columns=['DRUG_NAME'])
    # Fingerprints whose drug name has no match in the response matrix carry no DRUG_ID.
    .dropna(subset=['DRUG_ID'])
    # Once the missing ids are gone the column can be stored as integers.
    .astype({'DRUG_ID': np.int64})
    .rename(columns={'drug_name': 'DRUG_NAME'})
)

# Move DRUG_ID next to DRUG_NAME, in front of the fingerprint columns.
drug_name_smiles_v3.insert(1, 'DRUG_ID', drug_name_smiles_v3.pop('DRUG_ID'))

print(drug_name_smiles_v3.shape)
drug_name_smiles_v3.head(5)

(152, 258)


Unnamed: 0,DRUG_NAME,DRUG_ID,0,1,2,3,4,5,6,7,...,246,247,248,249,250,251,252,253,254,255
1,5-Fluorouracil,1073,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,ABT737,1910,1,1,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,AGI-5198,1913,0,1,1,0,1,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,AGI-6780,1634,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,AMG-319,2045,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [71]:
# Keep every column except DRUG_NAME: the DRUG_ID plus the fingerprint columns.
is_kept_column = drug_name_smiles_v3.columns != 'DRUG_NAME'
fingerprints = drug_name_smiles_v3.loc[:, is_kept_column]
print(fingerprints.shape)
fingerprints.head()

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


This is now the final drug dataset for the drug branch of the bi-modal network.

---

## Create PyTorch dataset

We now have three datasets.

In [75]:
# Drug response matrix holding the ln(IC50) values for each cell-line drug tuple.
drug_response_matrix = copy.deepcopy(drug_cl_v2)
print(drug_response_matrix.shape)
drug_response_matrix.head(5)

(130826, 5)


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3489119,22RV1,1005,Cisplatin,GDSC2,3.622285
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555


In [74]:
# The drug matrix holding the fingerprints for the drugs in the drug response matrix.
print(fingerprints.shape)
fingerprints.head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [76]:
# Cell-line graphs: number of graphs available and one example entry.
n_cell_line_graphs = len(cl_graphs.keys())
print(f"Number of cell-lines/graphs: {n_cell_line_graphs}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 983
Data(x=[4, 858], edge_index=[2, 83126])
