In [1]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from typing           import List
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES,
    PATH_SUMMARY_DATASETS
)

# Seed torch's global RNG (the one `DataLoader(shuffle=True)` draws from when
# no explicit generator is given) and set the seaborn theme for all plots.
torch.manual_seed(seed=42)
sns.set_theme(style="white")

---

# Experiments on the `TabTab` approach

In this notebook we experiment with an approach in which
- the cell-line branch uses tabular input (`Tab`), and
- the drug branch uses tabular input (`Tab`).

## Root datasets

- The final datasets have been created in `15_summary_datasets.ipynb`

In [153]:
def _load_summary_pickle(file_name):
    """Unpickle `file_name` from the `PATH_SUMMARY_DATASETS` directory.

    NOTE: `pickle.load` can execute arbitrary code; only use it on the
    project's own trusted files (created in `15_summary_datasets.ipynb`).
    """
    with open(f'{PATH_SUMMARY_DATASETS}{file_name}', 'rb') as handle:
        return pickle.load(handle)

# Cell-line gene table, drug SMILES fingerprint table and drug response matrix.
cl_gene_mat = _load_summary_pickle('cell_line_gene_matrix.pkl')
drug_mat = _load_summary_pickle('drug_smiles_fingerprints_matrix.pkl')
drm = _load_summary_pickle('drug_response_matrix__gdsc2.pkl')

In [154]:
print(f"Cell-line gene matrix\n{21*'='}")
# Set the index only once: after the first run `CELL_LINE_NAME` is no longer a
# column, so an unconditional `set_index` would raise KeyError on re-run.
if 'CELL_LINE_NAME' in cl_gene_mat.columns:
    cl_gene_mat = cl_gene_mat.set_index('CELL_LINE_NAME')
# Every gene must contribute exactly one column per feature suffix, and no
# other columns may be present. Counting over *all* columns (not `[1:]`) keeps
# this check correct regardless of whether the index has been set.
suffix_col_counts = {
    suffix: sum(col.endswith(suffix) for col in cl_gene_mat.columns)
    for suffix in ('_gexpr', '_cnvg', '_cnvp', '_mut')
}
assert len(set(suffix_col_counts.values())) == 1, suffix_col_counts
assert sum(suffix_col_counts.values()) == cl_gene_mat.shape[1], \
    "cell-line matrix has columns without a known feature suffix"
# `TabTabDataset` looks rows up with `.loc[name]`; duplicates would yield 2-D rows.
assert cl_gene_mat.index.is_unique, "duplicate cell-line names in the gene matrix"
print(cl_gene_mat.shape)
print(f"unique cell-lines: {len(cl_gene_mat.index.unique())}")
cl_gene_mat.head(3)

Cell-line gene matrix
(806, 3432)
unique cell-lines: 806


Unnamed: 0_level_0,FBXL12_gexpr,PIN1_gexpr,PAK4_gexpr,GNA15_gexpr,ARPP19_gexpr,EAPP_gexpr,MOK_gexpr,MTHFD2_gexpr,TIPARP_gexpr,CASP3_gexpr,...,PDHX_mut,DFFB_mut,FOSL1_mut,ETS1_mut,EBNA1BP2_mut,MYL9_mut,MLLT11_mut,PFKL_mut,FGFR4_mut,SDHB_mut
CELL_LINE_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
22RV1,7.023759,6.067534,4.31875,3.261427,6.297582,8.313991,5.514912,10.594112,5.222366,6.635925,...,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
23132-87,6.714387,5.695096,4.536146,3.295886,7.021037,8.50008,4.862145,10.609245,6.528668,7.238143,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
42-MG-BA,7.752402,5.475753,4.033714,3.176525,7.279671,8.013367,4.957332,11.266705,7.445954,6.312424,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [155]:
print(f"Drug SMILES fingerprint matrix\n{30*'='}")
# Set the index only once: after the first run `DRUG_ID` is no longer a
# column, so an unconditional `set_index` would raise KeyError on re-run.
if 'DRUG_ID' in drug_mat.columns:
    drug_mat = drug_mat.set_index('DRUG_ID')
# `TabTabDataset` looks rows up with `.loc[drug_id]`; duplicates would yield 2-D rows.
assert drug_mat.index.is_unique, "duplicate DRUG_IDs in the fingerprint matrix"
print(drug_mat.shape)
print(f"unique drugs: {len(drug_mat.index.unique())}")
drug_mat.head(3)

Drug SMILES fingerprint matrix
(152, 256)
unique drugs: 152


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
DRUG_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1073,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
1910,1,1,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1913,0,1,1,0,1,0,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0


In [84]:
print(f"Drug response matrix\n{20*'='}")
print(drm.shape)
# `nunique(dropna=False)` counts exactly what `len(Series.unique())` counts.
print(f"unique cell-lines: {drm['CELL_LINE_NAME'].nunique(dropna=False)}")
print(f"unique drugs     : {drm['DRUG_ID'].nunique(dropna=False)}")
drm.head(3)

Drug response matrix
(100972, 5)
unique cell-lines: 806
unique drugs     : 152


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


## Create `PyTorch` dataset

In [180]:
from torch_geometric.data import Dataset as PyGDataset

class TabTabDataset(PyGDataset):
    """Pairs tabular cell-line features and drug fingerprints with ln(IC50).

    Item ``i`` corresponds to row ``i`` of the drug response matrix: that row's
    ``CELL_LINE_NAME`` is looked up in ``cl_mat`` and its ``DRUG_ID`` in
    ``drug_mat``.

    NOTE(review): this subclasses the PyG ``Dataset`` but overrides
    ``__len__``/``__getitem__`` directly and is consumed by a plain
    ``torch.utils.data.DataLoader``; ``torch.utils.data.Dataset`` may suffice —
    confirm nothing later in the notebook relies on PyG-specific behaviour.

    Args:
        cl_mat (`pd.DataFrame`): Cell-line gene features, indexed by cell-line
            name.
        drug_mat (`pd.DataFrame`): Drug fingerprints, indexed by drug ID.
        drm (`pd.DataFrame`): Drug response matrix with ``CELL_LINE_NAME``,
            ``DRUG_ID``, ``DRUG_NAME`` and ``LN_IC50`` columns. It is not
            modified; a re-indexed copy is used internally.
    """
    def __init__(self, cl_mat, drug_mat, drm):
        super().__init__()
        self.cl_mat = cl_mat
        self.drug_mat = drug_mat

        # Re-index a copy instead of resetting in place, so the caller's frame
        # (e.g. the global `drm`, or a split returned by `train_test_split`)
        # keeps its original index.
        drm = drm.reset_index(drop=True)
        self.cls = drm['CELL_LINE_NAME']
        self.drug_ids = drm['DRUG_ID']
        self.drug_names = drm['DRUG_NAME']
        self.ic50s = drm['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line-gene features, drug smiles fingerprints 
        and the corresponding ln(IC50) values for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns
            `Tuple[np.ndarray, np.ndarray, np.float64]]`: Tuple of cell-line 
                gene feature values, drug SMILES fingerprints and the 
                corresponding ln(IC50) target values.
        """
        return (self.cl_mat.loc[self.cls.iloc[idx]].values, 
                self.drug_mat.loc[self.drug_ids.iloc[idx]].values,
                self.ic50s.iloc[idx])

    def print_summary(self):
        """Print the number of observations, cell lines, drugs and genes."""
        # Genes are counted via their `_cnvg` column. All columns are scanned:
        # `cl_mat` is indexed by cell-line name, so its first column is a
        # feature column, not `CELL_LINE_NAME`.
        n_genes = sum('_cnvg' in col for col in self.cl_mat.columns)
        print("TabTabDataset Summary")
        print(21*'=')
        print("# observations :", len(self.ic50s))
        print("# cell-lines   :", len(np.unique(self.cls)))
        print("# drugs        :", len(np.unique(self.drug_names)))
        print("# genes        :", n_genes)

In [181]:
# Full (unsplit) dataset, built only to inspect its summary.
tabtab_dataset = TabTabDataset(drm=drm, drug_mat=drug_mat, cl_mat=cl_gene_mat)
tabtab_dataset.print_summary()

TabTabDataset Summary
# observations : 100972
# cell-lines   : 806
# drugs        : 152
# genes        : 858


## Hyperparameters

In [182]:
class Args:
    """Container for the training hyperparameters used in this notebook.

    Args:
        batch_size (`int`): Number of samples per batch.
        lr (`float`): Learning rate.
        train_ratio (`float`): Fraction of the drug response matrix used for
            training; the remaining ``1 - train_ratio`` is held out.
        val_ratio (`float`): Fraction of the held-out rows that form the
            validation set (the rest forms the test set).
        num_epochs (`int`): Number of training epochs.
        random_seed (`int`, optional): Seed for the train/test/val splits.
            Defaults to 12345.

    Raises:
        ValueError: If ``train_ratio`` or ``val_ratio`` is not strictly between
            0 and 1 (the splits could not be built otherwise).
    """
    def __init__(self, batch_size, lr, train_ratio, val_ratio, num_epochs,
                 random_seed=12345):
        for name, ratio in (("train_ratio", train_ratio), ("val_ratio", val_ratio)):
            if not 0 < ratio < 1:
                raise ValueError(f"{name} must be in (0, 1), got {ratio!r}")
        self.BATCH_SIZE = batch_size
        self.LR = lr
        self.TRAIN_RATIO = train_ratio
        self.TEST_VAL_RATIO = 1 - self.TRAIN_RATIO
        self.VAL_RATIO = val_ratio
        self.NUM_EPOCHS = num_epochs
        self.RANDOM_SEED = random_seed

args = Args(batch_size=1_000, 
            lr=0.001, 
            train_ratio=0.8, 
            val_ratio=0.5, 
            num_epochs=100)


## Create `DataLoader` datasets

In [191]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

def _collate_tab_tab(samples):
    cls, drugs, ic50s = map(list, zip(*samples))
    cls = [torch.tensor(cl, dtype=torch.float64) for cl in cls]
    drugs = [torch.tensor(drug, dtype=torch.float64) for drug in drugs]
    # print("\nCELL-LINES: ", cls[0])
    # print("\nDRUG:", drugs[0])
    # print("\nIC50: ", ic50s[0])
    
    return torch.stack(cls, 0), torch.stack(drugs, 0), torch.tensor(ic50s)

def create_datasets(drm, cl_mat, drug_mat, hparams=None):
    """Split the drug response matrix and build train/test/val `DataLoader`s.

    Both splits are stratified by ``CELL_LINE_NAME``. ``TEST_VAL_RATIO`` of the
    rows are held out first; ``VAL_RATIO`` of those form the validation set and
    the rest the test set.

    Args:
        drm (`pd.DataFrame`): Drug response matrix with a ``CELL_LINE_NAME``
            column plus the columns `TabTabDataset` reads.
        cl_mat (`pd.DataFrame`): Cell-line gene features indexed by cell-line
            name.
        drug_mat (`pd.DataFrame`): Drug fingerprints indexed by drug ID.
        hparams (`Args`, optional): Provides ``TEST_VAL_RATIO``, ``VAL_RATIO``,
            ``RANDOM_SEED`` and ``BATCH_SIZE``. Defaults to the notebook-level
            ``args``.

    Returns:
        `Tuple[DataLoader, DataLoader, DataLoader]`: train, test and
            validation loaders. Only the training loader shuffles.
    """
    if hparams is None:
        hparams = args

    train_set, test_val_set = train_test_split(
        drm, test_size=hparams.TEST_VAL_RATIO, random_state=hparams.RANDOM_SEED,
        stratify=drm['CELL_LINE_NAME'])
    # The second frame returned receives `test_size` of the rows, so it is the
    # validation set here.
    test_set, val_set = train_test_split(
        test_val_set, test_size=hparams.VAL_RATIO, random_state=hparams.RANDOM_SEED,
        stratify=test_val_set['CELL_LINE_NAME'])

    splits = {'train': train_set, 'test': test_set, 'val': val_set}
    for name, split in splits.items():
        print(f"{name}_set.shape:", split.shape)

    datasets = {
        name: TabTabDataset(cl_mat=cl_mat, drug_mat=drug_mat, drm=split)
        for name, split in splits.items()
    }
    for name, dataset in datasets.items():
        print(f"\n{name}_dataset")
        dataset.print_summary()

    # Only the training data needs reshuffling every epoch. The evaluation
    # loaders keep the dataset order, so per-sample predictions can be matched
    # back to rows of their split.
    train_loader, test_loader, val_loader = (
        DataLoader(dataset=dataset, batch_size=hparams.BATCH_SIZE,
                   shuffle=(name == 'train'), collate_fn=_collate_tab_tab)
        for name, dataset in datasets.items()
    )
    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drm, cl_gene_mat, drug_mat)

train_set.shape: (80777, 5)
test_set.shape: (10097, 5)
val_set.shape: (10098, 5)

train_dataset
TabTabDataset Summary
# observations : 80777
# cell-lines   : 806
# drugs        : 152
# genes        : 858

test_dataset
TabTabDataset Summary
# observations : 10097
# cell-lines   : 806
# drugs        : 151
# genes        : 858

val_dataset
TabTabDataset Summary
# observations : 10098
# cell-lines   : 806
# drugs        : 152
# genes        : 858


In [199]:
print("Number of batches per dataset:")
# Labels are padded so the colons line up in the output.
print("\n".join(
    f"  {label} : {len(loader)}"
    for label, loader in (("train", train_loader),
                          ("test ", test_loader),
                          ("val  ", val_loader))
))

Number of batches per dataset:
  train : 81
  test  : 11
  val   : 11


In [198]:
# Sanity-check batch shapes: print the first three batches and the last
# (possibly smaller) one, and only a progress marker for the batches between.
# The batch tensors get their own names so the global `drug_mat` DataFrame
# is not overwritten by a tensor (which would break any later cell using it).
for step, (cl_batch, drug_batch, ic50_batch) in enumerate(train_loader):
    if 2 < step < len(train_loader) - 1:
        if step % 10 == 0:
            print("... step", step)
        continue
    print(f'Step {step + 1}:')
    print('=======')
    print(cl_batch.shape)
    print(drug_batch.shape)
    print(ic50_batch.shape)

Step 1:
torch.Size([1000, 3432])
torch.Size([1000, 256])
torch.Size([1000])
Step 2:
torch.Size([1000, 3432])
torch.Size([1000, 256])
torch.Size([1000])
Step 3:
torch.Size([1000, 3432])
torch.Size([1000, 256])
torch.Size([1000])
... step 10
... step 20
... step 30
... step 40
... step 50
... step 60
... step 70
Step 81:
torch.Size([777, 3432])
torch.Size([777, 256])
torch.Size([777])


## Model development