In [3]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from typing           import List
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES,
    PATH_SUMMARY_DATASETS
)

torch.manual_seed(42)
sns.set_theme(style="white")

---

# Experiments on the `TabGraph` approach

In this notebook we are going to expirment the approach of 
- having the cell-line branch using tabular input (`Tab`)
- replacing the drug branch by a GNN (`Graph`)

## Pre-processing

In [4]:
# Reading the cell-line gene graphs.
with open(f'{PATH_TO_FEATURES}cl_graphs_as_dict.pkl', 'rb') as f:
    cl_graphs = pickle.load(f)

# Reading the drug response matrix.
with open(f'{PATH_TO_FEATURES}drugs_sparse_gdsc2.pkl', 'rb') as f: 
    drug_cl = pickle.load(f)  

In [5]:
print(f"Number of cell-lines/graphs: {len(list(cl_graphs.keys()))}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 983
Data(x=[858, 4], edge_index=[2, 83126])


In [6]:
print(f"Shape: {drug_cl.shape}")
print(f"Number of different drug id's   : {len(np.unique(drug_cl.DRUG_ID.values))}")
print(f"Number of different drug name's : {len(np.unique(drug_cl.DRUG_NAME.values))}")
drug_cl.head(10)

Shape: (135242, 5)
Number of different drug id's   : 198
Number of different drug name's : 192


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3489119,22RV1,1005,Cisplatin,GDSC2,3.622285
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3533420,22RV1,1007,Docetaxel,GDSC2,-3.835431
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435
3589233,22RV1,1012,Vorinostat,GDSC2,0.846758
3606526,22RV1,1013,Nilotinib,GDSC2,3.93594
3627474,22RV1,1017,Olaparib,GDSC2,5.238895


In [7]:
drug_cl[['DRUG_ID', 'DRUG_NAME']].groupby(['DRUG_ID']).nunique().sort_values(['DRUG_NAME'], ascending=False).head(10)

Unnamed: 0_level_0,DRUG_NAME
DRUG_ID,Unnamed: 1_level_1
1003,1
1739,1
1799,1
1802,1
1804,1
1806,1
1807,1
1808,1
1809,1
1810,1


In [8]:
drug_cl[['DRUG_ID', 'DRUG_NAME']].groupby(['DRUG_NAME']).nunique().sort_values(['DRUG_ID'], ascending=False).head(10)
# TODO: only yake one DRUG_NAME, such that number of DRUG_NAME's equals number of DRUG_ID's and that there is a 1-to-1 mapping

Unnamed: 0_level_0,DRUG_ID
DRUG_NAME,Unnamed: 1_level_1
Dactinomycin,2
Fulvestrant,2
Oxaliplatin,2
Docetaxel,2
Ulixertinib,2
Uprosertib,2
PD173074,1
Olaparib,1
Osimertinib,1
P22077,1


- NOTE: We found that 6 `DRUG_NAME`'s occur for two `DRUG_ID`'s. 

The cell-line gene dataset is basically ready to go. It "only" needs to be transformed to a pytorch `Data` class. The graph per cell-line will be used as input to the GNN cell-line branch of the bi-modal model. However, the drug datasets is the drug response matrix. It doesn't contain the drug features. In the following subsection we will obtain the [SMILES fingerprints](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-00445-4) for each unique `DRUG_ID`. This will later be used as the input to the drug branch of the bi-modal model.

### Transform drugs to SMILES fingerprints

In [9]:
with open(f'{PATH_TO_SAVED_DRUG_FEATURES}drug_name_fingerprints_dataframe.pkl', 'rb') as f:
    drug_name_smiles = pickle.load(f)
# drug_name_smiles.set_index(['drug_name'], inplace=True)
print(drug_name_smiles.shape)
# TODO: Note that yet there are some DRUG_NAME's which have >1 DRUG_ID
drug_name_smiles.head(5)

(367, 257)


Unnamed: 0,drug_name,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
0,(5Z)-7-Oxozeaenol,1,0,0,1,1,0,0,0,0,...,0,0,0,0,0,1,1,0,1,0
1,5-Fluorouracil,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
2,A-443654,0,1,0,0,0,0,0,0,0,...,0,1,0,0,0,0,0,0,0,1
3,A-770041,1,1,0,0,0,1,0,0,0,...,0,0,0,1,0,1,0,0,0,0
4,A-83-01,0,0,0,1,1,0,0,0,0,...,0,1,0,0,0,0,0,0,0,0


In [10]:
all_non_uniqs = drug_cl[['DRUG_ID', 'DRUG_NAME']].groupby(['DRUG_NAME']).nunique().sort_values(['DRUG_ID'], ascending=False)
all_non_uniqs = all_non_uniqs[all_non_uniqs.DRUG_ID>1].reset_index()

smiles = []
for smile in list(drug_name_smiles.drug_name.values):
    if smile in list(all_non_uniqs.DRUG_NAME.values): 
        smiles.append(smile)
        print(f"{smile} is there and has not 1-to-1 mapping")

Docetaxel is there and has not 1-to-1 mapping
Ulixertinib is there and has not 1-to-1 mapping
Uprosertib is there and has not 1-to-1 mapping


In [11]:
drug_cl[drug_cl.DRUG_NAME=='Docetaxel']

Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3533420,22RV1,1007,Docetaxel,GDSC2,-3.835431
5096727,22RV1,1819,Docetaxel,GDSC2,-2.048442
3532924,23132-87,1007,Docetaxel,GDSC2,-5.663205
5096517,23132-87,1819,Docetaxel,GDSC2,-4.758792
3518116,42-MG-BA,1007,Docetaxel,GDSC2,-5.157606
...,...,...,...,...,...
3534191,YT,1007,Docetaxel,GDSC2,-4.989452
3515605,ZR-75-30,1007,Docetaxel,GDSC2,-1.275373
5096160,ZR-75-30,1819,Docetaxel,GDSC2,3.934073
3537528,huH-1,1007,Docetaxel,GDSC2,-2.747687


We will remove these for now. However, this needs to be tackled in the future.

In [12]:
drug_cl_v2 = drug_cl[~drug_cl.DRUG_NAME.isin(smiles)]
print(drug_cl_v2.shape)

(130826, 5)


In [13]:
drug_name_smiles_v2 = drug_name_smiles[~drug_name_smiles.drug_name.isin(smiles)]
print(drug_name_smiles_v2.shape)

(364, 257)


These are now the drug response matrix and the SMILES matrix without the non-1-to-1 mapped `DRUG_NAME` and `DRUG_ID`'s.

In [14]:
drug_name_smiles_v3 = pd.merge(left=drug_name_smiles_v2,
                                right=drug_cl_v2[['DRUG_ID', 'DRUG_NAME']],
                                how='left',
                                left_on=['drug_name'],
                                right_on=['DRUG_NAME'])
drug_name_smiles_v3.drop_duplicates(inplace=True)
drug_name_smiles_v3.drop(['DRUG_NAME'], axis=1, inplace=True)
drug_name_smiles_v3 = drug_name_smiles_v3[~drug_name_smiles_v3.DRUG_ID.isna()]
drug_name_smiles_v3['DRUG_ID'] = drug_name_smiles_v3.DRUG_ID.astype(np.int64)
drug_name_smiles_v3.rename(columns={'drug_name': 'DRUG_NAME'}, inplace=True)
drug_name_smiles_v3.insert(1, 'DRUG_ID', drug_name_smiles_v3.pop('DRUG_ID'))
print(drug_name_smiles_v3.shape)
drug_name_smiles_v3.head(5)

(152, 258)


Unnamed: 0,DRUG_NAME,DRUG_ID,0,1,2,3,4,5,6,7,...,246,247,248,249,250,251,252,253,254,255
1,5-Fluorouracil,1073,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,ABT737,1910,1,1,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,AGI-5198,1913,0,1,1,0,1,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,AGI-6780,1634,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,AMG-319,2045,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


## Obtain graphs for the drugs

Instead of using the SMILES fingerprint we will use a graph of molecules and bonds for each drugs. We then further use a GNN to find an embedding for each drug.

__Note__: Compared to the cell-line drugs, in the encoding of the drugs as graphs, _not_ every graph will have the same topology.

In [15]:
fingerprints = drug_name_smiles_v3.loc[:, ~drug_name_smiles_v3.columns.isin(['DRUG_NAME'])]
print(fingerprints.shape)
fingerprints.head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [14]:
# Select only the drugs in the drug response matrix which also have a SMILES string.
drug_cl_v3 = drug_cl_v2[drug_cl_v2.DRUG_ID.isin(list(fingerprints.DRUG_ID.values))]
print(drug_cl_v3.shape)
print(f"Number of unique cell-lines : {len(list(np.unique(drug_cl_v3.CELL_LINE_NAME.values)))}")
drug_cl_v3.head(5)

(101370, 5)
Number of unique cell-lines : 809


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435


In [15]:
# Select only the cell-line graphs for which the cell-line is also in the drug-response matrix.
cl_graphs_v2 = copy.deepcopy(cl_graphs)
not_in = []
cls = list(np.unique(drug_cl_v3.CELL_LINE_NAME.values))
for cl in list(cl_graphs.keys()):
    if cl not in cls: 
        not_in.append(cl)
        cl_graphs_v2.pop(cl, None)

print(f"Number of cell-lines before    : {len(list(cl_graphs.keys()))}")
print(f"Number of not-found cell-lines : {len(not_in)}")
print(f"Number of cell-lines after     : {len(list(cl_graphs_v2.keys()))}")

Number of cell-lines before    : 983
Number of not-found cell-lines : 177
Number of cell-lines after     : 806


Since the cell-line graph dataset has less unique cell-lines then the drug-response matrix we need to remove the cell-line from the drug-response matrix which dont have a graph.

In [16]:
cls_with_no_graph = set(np.unique(drug_cl_v3.CELL_LINE_NAME.values)).difference(set(cl_graphs_v2.keys()))
drug_cl_v4 = drug_cl_v3[~drug_cl_v3.CELL_LINE_NAME.isin(cls_with_no_graph)]
print(drug_cl_v4.shape)
print(f"Number of unique cell-lines : {len(list(np.unique(drug_cl_v4.CELL_LINE_NAME.values)))}")
drug_cl_v4.head(5)


(100972, 5)
Number of unique cell-lines : 806


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435


This is now the final drug dataset for the drug branch of the bi-modal network.

---

## Dataset summary

We got three datasets now: 
- Drug response matrix
- fingerprint matrix
- cell-line graph dictionary

In [17]:
# Drug response matrix holding the ln(IC50) values for each cell-line drug tuple.
drug_response_matrix = copy.deepcopy(drug_cl_v4)
print(drug_response_matrix.shape)
drug_response_matrix.sort_values(['DRUG_ID']).head(5)

(100972, 5)


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3441397,NCI-H1975,1003,Camptothecin,GDSC2,-1.579555
3424961,CL-34,1003,Camptothecin,GDSC2,-2.68195
3443434,LNZTA3WT4,1003,Camptothecin,GDSC2,-1.564607
3440494,A101D,1003,Camptothecin,GDSC2,-2.864103


In [18]:
# The drug matrix holding the fingerprints for the drugs in the drug response matrix.
print(fingerprints.shape)
fingerprints.sort_values(['DRUG_ID']).head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
21273,1003,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
94088,1004,1,0,0,0,0,0,0,0,1,...,0,1,0,1,0,1,0,0,0,0
25052,1006,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
39836,1010,1,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
61065,1011,0,1,0,0,0,0,0,0,0,...,1,0,0,0,1,1,0,0,0,1


In [19]:
# Drugs as dictionary.
fingerprints_dict = fingerprints.set_index('DRUG_ID').T.to_dict('list')
print(len(fingerprints_dict.keys()))
print(len(fingerprints_dict[1003]))

152
256


In [20]:
# The cell-line graphs holding for each cell-line in the drug-response-matrix the corresponding 
# graph with the cell-line level gene features.
cell_line_graphs = copy.deepcopy(cl_graphs_v2)
print(f"Number of cell-lines/graphs: {len(list(cell_line_graphs.keys()))}")
print(cell_line_graphs['22RV1'])

Number of cell-lines/graphs: 806
Data(x=[858, 4], edge_index=[2, 83126])


---

## Create PyTorch dataset

dataset should be 

| Tuple Datapoint | Cell Branch Input | Drug Branch Input | Target | Example Dataset |
| --------------- | ----------------- | ----------------- | ------ | --------------- |
| `(cl_1, drug_1)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_1\text{drug}_1}$ | Train |
| `(cl_1, ...   )`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_1\text{drug}_{...}}$ | Train | 
| `(cl_1, drug_m)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_1\text{drug}_m}$ | Test | 
| `(cl_i, drug_1)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_i\text{drug}_1}$ | Test |
| `(cl_i, ...   )`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_i\text{drug}_{...}}$ | Train | 
| `(cl_i, drug_m)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_i\text{drug}_m}$ | Train |
| `(cl_n, drug_1)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_n\text{drug}_1}$ | Test | 
| `(cl_n, ...   )`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_n\text{drug}_{...}}$ | Test |
| `(cl_n, drug_m)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_n\text{drug}_m}$ | Train | 

In [21]:
drug_response_matrix.head(5)

Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935
3551362,22RV1,1010,Gefitinib,GDSC2,4.032555
3571270,22RV1,1011,Navitoclax,GDSC2,3.963435


In [22]:
type(cl_graphs_v2['TE-9'])

torch_geometric.data.data.Data

In [23]:
type(fingerprints.loc[fingerprints.DRUG_ID==1073].values[0, 1:])

numpy.ndarray

In [24]:
fingerprints.set_index('DRUG_ID').T.to_dict('list')[1073][:20]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [25]:
type(drug_response_matrix['LN_IC50'].iloc[1])

numpy.float64

In [70]:
from torch_geometric.data import Data

# class GraphTabDataset(Dataset): 
#     def __init__(self, cl_graphs, drugs, drug_response_matrix):
#         super().__init__()

#         # SMILES fingerprints of the drugs and cell-line graphs.
#         self.drugs = drugs
#         self.cell_line_graphs = cl_graphs

#         # Lookup datasets for the response values.
#         drug_response_matrix.reset_index(drop=True, inplace=True)
#         self.cell_lines = drug_response_matrix['CELL_LINE_NAME']
#         self.drug_ids = drug_response_matrix['DRUG_ID']
#         self.drug_names = drug_response_matrix['DRUG_NAME']
#         self.ic50s = drug_response_matrix['LN_IC50']

#     def __len__(self):
#         return len(self.ic50s)

#     def __getitem__(self, idx: int) -> tuple[Data, np.ndarray, float]:
#         """
#         Returns a tuple of cell-line, drug and the corresponding ln(IC50)
#         value for a given index.

#         Args:
#             idx (int): Index to specify the row in the drug response matrix.  
#         Returns:
#             Tuple[torch_geometric.data.data.Data, np.ndarray, np.float64]
#             Tuple of a cell-line graph, drug SMILES fingerprint and the 
#             corresponding ln(IC50) value.
#         """
#         return (self.cell_line_graphs[self.cell_lines.iloc[idx]], 
#                 self.drugs[self.drug_ids.iloc[idx]],
#                 self.ic50s.iloc[idx])

    # def print_dataset_summary(self):
    #     print(f"GraphTabDataset Summary")
    #     print(f"{25*'='}")
    #     print(f"N_obs        : {1}")
    #     print(f"N_cell_lines : {2}")
    #     print(f"N_drugs      : {3}")
    #     print(f"N_genes      : {4}")

class GraphTabDataset_prev(Dataset):
    """ This class encodes the dataset used for building the Graph-Tab model. 
        In this bi-modal model
        - the cell-line branch is encoded by a GNN
        - the drug branch is encoded by a NN
    """
    def __init__(
            self,
            dr_mat,
            drugs,
            cl_graphs,
        ):
        super().__init__()

        # Cell-line branch with gene features.
        #   key: CELL_LINE_NAME
        #   value: gene feature graph
        X_cl = copy.deepcopy(cl_graphs)
        
        # SMILES drug branch.
        #   column 1: DRUG_ID
        #   col 1..n: SMILES fingerprints
        X_dr = copy.deepcopy(drugs)

        # Targets for each cell-line-drug pair.
        y = dr_mat[['CELL_LINE_NAME', 'DRUG_ID', 'LN_IC50']]

        self.drug_id_name = dr_mat[['DRUG_ID', 'DRUG_NAME']]
        self.cell_lines = list(X_cl.keys())
        self.drugs = list(np.unique(X_dr.DRUG_ID.values))

        assert len(self.cell_lines) == len(list(np.unique(y.CELL_LINE_NAME.values))), \
            f"Graph has {len(self.cell_lines)} while drug response matrix has {len(list(np.unique(y.CELL_LINE_NAME.values)))}!"
        assert len(self.drugs) == len(list(np.unique(y.DRUG_ID.values))), \
            f"There are {len(self.cell_lines)} SMILES fingerprints while drug response matrix has {len(list(np.unique(y.CELL_LINE_NAME.values)))} drug id's!"            

        X_dr.set_index(['DRUG_ID'], inplace=True)
        y.set_index(['CELL_LINE_NAME', 'DRUG_ID'], inplace=True)

        # Save whole dataframes with index and column names.
        self.X_cl = X_cl 
        self.X_dr = X_dr 
        # TODO: save cl's and drugs. Same list as list of tuple of (CELL_LINE_NAME, DRUG_ID); or separate lists
        self.y = y

        self.N_obs = self.y.shape[0]
        self.N_cell_lines = len(self.cell_lines)
        self.N_drugs = len(self.drugs)
        self.N_genes = self.X_cl[next(iter(self.X_cl))].x.shape[0]        

        # Save only the numerical values.
        # self.X_cl_values = torch.tensor(self.X_cl.values())
        self.y_vals = torch.tensor(self.y.LN_IC50.values, dtype=torch.float64)

    def __len__(self): 
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.y_vals[idx]

    def __getitem__(self, cl_idx, drug_idx):
        """Returns response value as ln_IC50.
        TODO: one index, 
        return x_cell: graph, x_drug: fps, y: lnic50
        """
        return self.y.loc[cl_idx, drug_idx]

    def get_number_of_obs(self):
        return self.N_obs

    def get_number_of_cell_lines(self):
        return self.N_cell_lines

    def get_number_of_drugs(self):
        return self.N_drugs

    def get_number_of_genes(self):
        return self.N_genes
    
    def print_dataset_summary(self):
        print(f"GraphTabDataset Summary")
        print(f"{25*'='}")
        print(f"N_obs        : {self.N_obs}")
        print(f"N_cell_lines : {self.N_cell_lines}")
        print(f"N_drugs      : {self.N_drugs}")
        print(f"N_genes      : {self.N_genes}")

    def get_cell_line_graph(self, cl):
        """Get gene-interaction graph for a given cell-line."""
        assert cl in self.cell_lines, f"Cell-line {cl} has no saved graph!"
        return self.X_cl[cl]

    def get_smiles_fingerprint(self, drug_id):
        """Get SMILES fingerprint for a given drug"""
        assert drug_id in self.drugs, f"Drug {drug_id} has no SMILES fingerprint!"
        return self.X_dr.loc[drug_id]

    def get_drug_name(self, drug_id):
        assert drug_id in list(self.drug_id_name.DRUG_ID.values),\
            f"Drug id {drug_id} is not in the dataset and thus has no DRUG_NAME mapping!"
        return self.drug_id_name[self.drug_id_name.DRUG_ID==drug_id].DRUG_NAME


TypeError: Type Tuple cannot be instantiated; use tuple() instead

In [40]:
from typing import Tuple
from torch_geometric.data import Data, Dataset

class GraphTabDataset(Dataset): 
    def __init__(self, cl_graphs, drugs, drug_response_matrix):
        super().__init__()

        # SMILES fingerprints of the drugs and cell-line graphs.
        self.drugs = drugs
        self.cell_line_graphs = cl_graphs

        # Lookup datasets for the response values.
        drug_response_matrix.reset_index(drop=True, inplace=True)
        self.cell_lines = drug_response_matrix['CELL_LINE_NAME']
        self.drug_ids = drug_response_matrix['DRUG_ID']
        self.drug_names = drug_response_matrix['DRUG_NAME']
        self.ic50s = drug_response_matrix['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line, drug and the corresponding ln(IC50)
        value for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns:
            `Tuple[torch_geometric.data.data.Data, np.ndarray, np.float64]`:
            Tuple of a cell-line graph, drug SMILES fingerprint and the 
            corresponding ln(IC50) value.
        """
        return (self.cell_line_graphs[self.cell_lines.iloc[idx]], 
                self.drugs[self.drug_ids.iloc[idx]],
                self.ic50s.iloc[idx])

    def print_dataset_summary(self):
        print(f"GraphTabDataset Summary")
        print(f"{23*'='}")
        print(f"# observations : {len(self.ic50s)}")
        print(f"# cell-lines   : {len(np.unique(self.cell_lines))}")
        print(f"# drugs        : {len(np.unique(self.drug_names))}")
        print(f"# genes        : {self.cell_line_graphs[next(iter(self.cell_line_graphs))].x.shape[0]}")

In [41]:
graph_tab_dataset = GraphTabDataset(cl_graphs=cell_line_graphs,
                                    drugs=fingerprints_dict,
                                    drug_response_matrix=drug_response_matrix)

graph_tab_dataset.print_dataset_summary()                                    

GraphTabDataset Summary
# observations : 100972
# cell-lines   : 806
# drugs        : 152
# genes        : 858


In [42]:
random_seed = 12345

BATCH_SIZE = 1_000 # TODO: tune batch_size
LR = 0.001 # TODO: tune this or implement lr decay
TRAIN_RATIO = 0.8 
TEST_VAL_RATIO = 1-TRAIN_RATIO # How much of all data is for the test and validation set.
VAL_RATIO = 0.5 # How much of the of the test and validation set is only for validation.
NUM_EPOCHS = 100

In [43]:
torch.stack([torch.tensor([1,2]), torch.tensor([3,4]), torch.tensor([5,6])], 0)

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [70]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
# from torch_geometric.loader import DataLoader

def _collate_graph_tab(samples):
    """
    Collates a list of cell-line graphs, SMILES drug fingerprints and 
    ln(IC50) values to pytorch handable formats.

    Args: 
        samples (`List[Tuple[torch.Data, List[int], float]]`): List of tuples of 
            cell-line graphs, SMILES drug fingerprints and ln(IC50) values.
            Example:
            >>> [(Data(x=[858, 4], edge_index=[2, 83126]), [0, 0, 1, ..., 0], 4.792643),
                 (Data(x=[858, 4], edge_index=[2, 83126]), [1, 0, 1, ..., 1], 5.857639),
                 ...]
            
    Returns:
        `DataBatch`: All graphs in the batch as one big (disconnected) graph.
            Example: 
            >>> DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])
                # x=[858*BATCH_SIZE, 4], edge_index=[2, 83126*BATCH_SIZE], batch=[858*BATCH_SIZE]
        `torch.tensor(List[List[int]])`: tensor of list of drug fingerprints.
            Example:
            >>> torch.tensor([[0,0,1,...,0], 
                              [1,0,1,...,1],
                              ...])
        `torch.tensor(float)`: tensor of ln(IC50)'s.
            Example: 
            >>> torch.tensor(4.792643, 5.857639, ...)
    """
    cells, drugs, targets = map(list, zip(*samples))
    drugs = [torch.tensor(drug_fp, dtype=torch.float64) for drug_fp in drugs] # list of fingerprint tensors

    return Batch.from_data_list(cells), torch.stack(drugs, 0), torch.tensor(targets)
    # return "hal.o", torch.stack(drugs, 0), torch.tensor(targets)

def create_datasets(drm, cl_dict, drug_dict):
    print(f"Full     shape: {drm.shape}")
    train_set, test_val_set = train_test_split(drm, 
                                               test_size=TEST_VAL_RATIO, 
                                               random_state=random_seed,
                                               stratify=drm['CELL_LINE_NAME'])
    test_set, val_set = train_test_split(test_val_set,
                                         test_size=VAL_RATIO,
                                         random_state=random_seed,
                                         stratify=test_val_set['CELL_LINE_NAME'])
    print(f"train    shape: {train_set.shape}")
    print(f"test_val shape: {test_val_set.shape}")
    print(f"test     shape: {test_set.shape}")
    print(f"val      shape: {val_set.shape}")

    train_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                    drugs=drug_dict,
                                    drug_response_matrix=train_set)
    test_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                   drugs=drug_dict,
                                   drug_response_matrix=test_set)
    val_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                  drugs=drug_dict,
                                  drug_response_matrix=val_set)

    print("\ntrain_dataset:")
    train_dataset.print_dataset_summary()
    print("\n\ntest_dataset:")
    test_dataset.print_dataset_summary()
    print("\n\nval_dataset:")
    val_dataset.print_dataset_summary()

    # TODO: try out different `num_workers`.
    train_loader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=_collate_graph_tab)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             collate_fn=_collate_graph_tab)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            collate_fn=_collate_graph_tab)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drug_response_matrix, 
                                                        cl_graphs_v2,
                                                        fingerprints_dict)

Full     shape: (100972, 5)
train    shape: (80777, 5)
test_val shape: (20195, 5)
test     shape: (10097, 5)
val      shape: (10098, 5)

train_dataset:
GraphTabDataset Summary
# observations : 80777
# cell-lines   : 806
# drugs        : 152
# genes        : 858


test_dataset:
GraphTabDataset Summary
# observations : 10097
# cell-lines   : 806
# drugs        : 151
# genes        : 858


val_dataset:
GraphTabDataset Summary
# observations : 10098
# cell-lines   : 806
# drugs        : 152
# genes        : 858


In [72]:
for step, data in enumerate(train_loader):
    cell, drugs, targets = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    # print(drugs)
    # print(targets)
    print()

    if step == 3:
        break

Step 1:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])

Step 2:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])

Step 3:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])

Step 4:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])



In [73]:
# for batch_cell_graph, batch_drugs, batch_ic50s in train_loader
print("Number of batches per dataset:")
print(f"  train : {len(train_loader)}")
print(f"  test  : {len(test_loader)}")
print(f"  val   : {len(val_loader)}")

Number of batches per dataset:
  train : 81
  test  : 11
  val   : 11


By using `Batch.from_data_list(cells)` we batched all graphs in a batch into a single giant graph.

__TODO__: 
- [ ] We can also use the class `torch_geometric.data.DataLoader` directly to do this.
- example notebook: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=G-DwBYkquRUN

- PyTorch.data doc: https://pytorch.org/docs/stable/data.html

In [74]:
for data in train_loader: 
    cell, drug, ic = data
    print(drug)
    print(drug[0][0])
    break

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [1., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 1.]], dtype=torch.float64)
tensor(0., dtype=torch.float64)


In [78]:
Batch.to_data_list(cell)[:10]

[Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126])]

## Model development

In [None]:
%load_ext autoreload
%autoreload
from v3_GCN import GraphTab_v1
from my_utils.model_helpers import train_and_test_model

torch.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = GraphTab_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr

model, train_losses, test_losses = train_and_test_model(modeling_dataset=drug_response_matrix,
                                                        model=model,
                                                        criterion=loss_func,
                                                        optimizer=optimizer,
                                                        num_epochs=10,
                                                        device=device,
                                                        train_loader=train_loader,
                                                        test_loader=test_loader)

In [79]:
from tqdm import tqdm
import torch.nn.functional as F

class BuildModel():
    def __init__(self, model, criterion, optimizer, num_epochs, train_loader, test_loader, val_loader):
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader

        self.num_epochs = num_epochs

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def train(self): 
        train_losses, test_losses = [], []

        for epoch in range(self.num_epochs):
            model.train()
            print("=====Epoch {}".format(epoch))
            print("Training...")
            for i, data in enumerate(tqdm(self.train_loader, desc='Iteration')):
                cell, drug, targets = data
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                print(f"targets.size : {targets.size}")
                print(f"targets      : {targets[:10]}")

                # Models predictions of the ic50s for a batch of cell-lines and drugs
                preds = self.model(cell, drug.float())
                print(100*"=")
                print(targets)
                print(targets.view(-1, 1))
                losses = self.criterion(preds, targets.view(-1, 1).float()) # =train_loss

                # TODO: maybe put this on top of the loop.
                self.optimizer.zero_grad()
                losses.backward()
                self.optimizer.step()

            train_losses.add(losses)

            # TODO: Testing
            print("Testing...")


    def validate(self, model, loader, device): 
        model.eval()

        y_true, y_pred = [], []
        total_loss = 0
        with torch.no_grad():
            for data in tqdm(loader, desc='Iteration'):
                cell, drug, targets = data
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                preds = model(cell, drug)
                total_loss += F.mse_loss(preds, targets.view(-1, 1).float(), reduction='sum')
                y_true.append(targets.view(-1, 1))
                y_pred.append(preds)




In [80]:
%load_ext autoreload
%autoreload
from v3_GCN import GraphTab_v1
from my_utils.model_helpers import train_and_test_model

torch.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = GraphTab_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr

build_model = BuildModel(model=model,
                         criterion=loss_func,
                         optimizer=optimizer,
                         num_epochs=10,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         val_loader=val_loader)

build_model.train()

device: cpu
self.drug_nn: Sequential(
  (0): Linear(in_features=256, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ReLU()
)
self.cell_emb: Sequential(
  (0): GCNConv(4, 256)
  (1): ReLU(inplace=True)
  (2): GCNConv(256, 256)
  (3): ReLU(inplace=True)
  (4): <function global_mean_pool at 0x139209d80>
  (5): Linear(in_features=256, out_features=128, bias=True)
  (6): ReLU()
  (7): Linear(in_features=128, out_features=128, bias=True)
  (8): ReLU()
)
=====Epoch 0
Training...


Iteration:   0%|          | 0/81 [00:00<?, ?it/s]

targets.size : <built-in method size of Tensor object at 0x128c85080>
targets      : tensor([ 3.5307, -3.2560,  1.4383,  0.8531,  5.9387,  3.9552, -0.7463,  2.1936,
         2.3953,  2.6451], dtype=torch.float64)
tensor([[0., 1., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 1.],
        [1., 1., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])
drug_emb.shape: torch.Size([1000, 128])
drug_emb: tensor([[0.0000, 0.0000, 0.0000,  ..., 0.1406, 0.0000, 0.1114],
        [0.0257, 0.0363, 0.0000,  ..., 0.1095, 0.0000, 0.0000],
        [0.0835, 0.0000, 0.0000,  ..., 0.0000, 0.0171, 0.1947],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0844, 0.0000, 0.0419],
        [0.0000, 0.0000, 0.0575,  ..., 0.0405, 0.0816, 0.0192],
        [0.0000, 0.0496, 0.0000,  ..., 0.0668, 0.0524, 0.0778]],
       grad_fn=<ReluBackward0>)


Iteration:   0%|          | 0/81 [00:05<?, ?it/s]


RuntimeError: index 858000 is out of bounds for dimension 0 with size 858000

In [299]:
for step, data in enumerate(train_loader):
    cell, _, _ = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    print()

    if step == 1:
        break

Step 1:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])

Step 2:
Number of graphs in the current batch: 1000
DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])



In [304]:
# 858000
cell.x[858000]

IndexError: index 858000 is out of bounds for dimension 0 with size 858000

In [305]:
cell.x[858000-1]

tensor([ 4.2288, -1.0000,  2.0000,  0.0000], dtype=torch.float64)