In [1]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from typing           import List
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES,
    PATH_SUMMARY_DATASETS
)

# Fix torch's global RNG so torch-side randomness is repeatable across runs.
torch.manual_seed(seed=42)
# Plain white seaborn theme for every figure in this notebook.
sns.set_theme(context="notebook", style="white")

---

# Experiments on the `TabTab` approach

In this notebook we experiment with the approach of
- feeding the cell-line branch with tabular input (`Tab`), and
- feeding the drug branch with tabular input (`Tab`).

## Root datasets

- The final datasets have been created in `15_summary_datasets.ipynb`

In [2]:
# Reading the cell-line gene table.
with open(f'{PATH_SUMMARY_DATASETS}cell_line_gene_matrix.pkl', 'rb') as f:
    cl_gene_mat = pickle.load(f)

# Reading the drug SMILES fingerprint table.
with open(f'{PATH_SUMMARY_DATASETS}drug_smiles_fingerprints_matrix.pkl', 'rb') as f:
    drug_mat = pickle.load(f)

# Reading the drug response matrix.
with open(f'{PATH_SUMMARY_DATASETS}drug_response_matrix__gdsc2.pkl', 'rb') as f: 
    drm = pickle.load(f)  

In [10]:
# Quick look at the cell-line table: title, shape, number of distinct
# cell-lines and the first rows.
print("Cell-line gene matrix\n" + "=" * 21)
print(cl_gene_mat.shape)
print(f"unique cell-lines: {cl_gene_mat['CELL_LINE_NAME'].nunique(dropna=False)}")
cl_gene_mat.head(3)

Cell-line gene matrix
(806, 3433)
unique cell-lines: 806


Unnamed: 0,CELL_LINE_NAME,FBXL12_gexpr,PIN1_gexpr,PAK4_gexpr,GNA15_gexpr,ARPP19_gexpr,EAPP_gexpr,MOK_gexpr,MTHFD2_gexpr,TIPARP_gexpr,...,PDHX_mut,DFFB_mut,FOSL1_mut,ETS1_mut,EBNA1BP2_mut,MYL9_mut,MLLT11_mut,PFKL_mut,FGFR4_mut,SDHB_mut
0,22RV1,7.023759,6.067534,4.31875,3.261427,6.297582,8.313991,5.514912,10.594112,5.222366,...,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
1,23132-87,6.714387,5.695096,4.536146,3.295886,7.021037,8.50008,4.862145,10.609245,6.528668,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,42-MG-BA,7.752402,5.475753,4.033714,3.176525,7.279671,8.013367,4.957332,11.266705,7.445954,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [12]:
# Quick look at the drug fingerprint table: title, shape, number of distinct
# drugs and the first rows.
print("Drug SMILES fingerprint matrix\n" + "=" * 30)
print(drug_mat.shape)
print(f"unique drugs: {drug_mat['DRUG_ID'].nunique(dropna=False)}")
drug_mat.head(3)

Drug SMILES fingerprint matrix
(152, 257)
unique drugs: 152


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0


In [14]:
# Quick look at the drug response matrix: title, shape, distinct cell-lines
# and drugs, and the first rows.
print("Drug response matrix\n" + "=" * 20)
print(drm.shape)
print(f"unique cell-lines: {drm['CELL_LINE_NAME'].nunique(dropna=False)}")
print(f"unique drugs     : {drm['DRUG_ID'].nunique(dropna=False)}")
drm.head(3)

Drug response matrix
(100972, 5)
unique cell-lines: 806
unique drugs     : 152


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


## Create `PyTorch` dataset

In [None]:
from torch_geometric.data import Dataset as PyGDataset

class TabTabDataset(PyGDataset):
    """Drug-response dataset pairing tabular cell-line and drug features.

    Each item corresponds to one row of the drug response matrix ``drm``:
    the feature row of its cell-line in ``cl_mat``, the fingerprint row of
    its drug in ``drug_mat`` and its ``LN_IC50`` target.

    ``cl_mat`` / ``drug_mat`` may either carry their keys as a
    ``CELL_LINE_NAME`` / ``DRUG_ID`` column on a positional index (as the
    summary matrices displayed in this notebook do) or already be indexed
    by those keys. The key column is never returned as a feature.
    """

    def __init__(self, cl_mat: pd.DataFrame, drug_mat: pd.DataFrame,
                 drm: pd.DataFrame):
        super().__init__()
        # Index the feature matrices by their keys so ``.loc`` looks rows up
        # by cell-line name / drug id instead of by the positional index the
        # tables were saved with (which would raise or hit the wrong row).
        self.cl_mat = self._index_by(cl_mat, 'CELL_LINE_NAME')
        self.drug_mat = self._index_by(drug_mat, 'DRUG_ID')

        # Work on a copy rather than ``reset_index(inplace=True)`` so the
        # caller's drug response frame is left untouched.
        drm = drm.reset_index(drop=True)
        self.cls = drm['CELL_LINE_NAME']
        self.drug_ids = drm['DRUG_ID']
        self.drug_names = drm['DRUG_NAME']
        self.ic50s = drm['LN_IC50']

        # Fail fast here instead of deep inside a DataLoader iteration.
        self._check_keys(self.cls, self.cl_mat.index, 'cell-line')
        self._check_keys(self.drug_ids, self.drug_mat.index, 'drug')

    @staticmethod
    def _index_by(mat: pd.DataFrame, key: str) -> pd.DataFrame:
        """Return ``mat`` indexed by column ``key``; unchanged if there is no such column."""
        return mat.set_index(key) if key in mat.columns else mat

    @staticmethod
    def _check_keys(keys: pd.Series, available: pd.Index, what: str) -> None:
        """Raise ``KeyError`` if any of ``keys`` is missing from ``available``."""
        missing = keys[~keys.isin(available)].unique()
        if len(missing):
            raise KeyError(
                f"{len(missing)} {what} key(s) of the drug response matrix "
                f"have no feature row, e.g. {list(missing[:5])}"
            )

    def __len__(self):
        """Number of (cell-line, drug) pairs, i.e. rows of the response matrix."""
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line-gene features, drug SMILES fingerprints
        and the corresponding ln(IC50) value for a given index.

        Args:
            idx (`int`): Position of the row in the drug response matrix.
        Returns:
            `Tuple[np.ndarray, np.ndarray, np.float64]`: Tuple of cell-line
                gene feature values, drug SMILES fingerprints and the
                corresponding ln(IC50) target value.
        """
        cl_features = self.cl_mat.loc[self.cls.iloc[idx]].to_numpy()
        drug_features = self.drug_mat.loc[self.drug_ids.iloc[idx]].to_numpy()
        return cl_features, drug_features, self.ic50s.iloc[idx]

    