In [5]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from tqdm                   import tqdm
from torch.utils.data       import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_SUMMARY_DATASETS
)

WITHOUT_MISSING_FOLDER = '/without_missing/'

torch.manual_seed(42)
sns.set_theme(style="white")

---

# Experiments on the `GraphTab` approach

Here we used the sparsed graph dataset, which itself used the `combined_score` to select only gene-gene neighbor tuples with a score value of over 950.

In this notebook we are going to expirment the approach of 
- replacing the cell-line branch by a GNN and 
- having the drug branch using tabular input.

In [12]:
# Reading the drug response matrix.
# This dataset got developed in `15_summary_datasets.ipynb`
with open(f'{PATH_SUMMARY_DATASETS}{WITHOUT_MISSING_FOLDER}drug_response_matrix__gdsc2.pkl', 'rb') as f: 
    drug_cl = pickle.load(f)
print(drug_cl.shape)
print("Number of unique cell lines:", len(drug_cl.CELL_LINE_NAME.unique()))
print("Number of unique drug ids:", len(drug_cl.DRUG_ID.unique()))
print("Number of unique drug names:", len(drug_cl.DRUG_NAME.unique()))
drug_cl.head(3)

(91991, 5)
Number of unique cell lines: 732
Number of unique drug ids: 152
Number of unique drug names: 152


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


In [3]:
# The following are the gene-gene interaction graphs per cell-line were the gene-gene tuples
# have a `combined_score` > 0.95*1_000=950.
# This dataset got developed in `07_v2_graph_dataset.ipynb`
with open(f'{PATH_TO_FEATURES}cl_graphs_as_dict_SPARSE.pkl', 'rb') as f:
    cl_graphs_sparse = pd.read_pickle(f)
print(f"Number of cell-lines/graphs: {len(list(cl_graphs_sparse.keys()))}")
print(cl_graphs_sparse['22RV1'])

Number of cell-lines/graphs: 983
Data(x=[458, 4], edge_index=[2, 4760])


### Pre-process

In this sub-section we are going to
- [x] select only the cell-line in the gene-gene graph dictionary which are in the drug-response-matrix

In [10]:
cl_graphs = dict((cl, cl_graphs_sparse[cl]) for cl in drug_cl.CELL_LINE_NAME.unique().tolist())
print(f"Number of cell-lines/graphs: {len(list(cl_graphs.keys()))}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[458, 4], edge_index=[2, 4760])


In [11]:
fails = []
cls = len(list(cl_graphs.keys()))
for cl, G in cl_graphs.items():
    if not (G.edge_index.max() < G.num_nodes):
        fails.append(cl)
print(f"Failed for {len(fails)} ({100*len(fails)/cls:2.2f} %) out of all {cls} cell-lines.")
del cls

if len(fails) > 0:
    for cl, G in tqdm(cl_graphs.items()):
        mapping = {}
        mapped_edge_index = []
        for (src, dst) in G.edge_index.t().tolist():
            if src not in mapping:
                mapping[src] = len(mapping)
            if dst not in mapping:
                mapping[dst] = len(mapping)
            mapped_edge_index.append([mapping[src], mapping[dst]])
        edge_index = torch.tensor(mapped_edge_index).t().contiguous()
        cl_graphs[cl].edge_index = edge_index

for cl, G in cl_graphs.items():
    assert G.edge_index.max() < G.num_nodes, f'FAIL for cell-line: {cl}'
print("All cell-lines succeeded according to this issue: ")
print("https://github.com/pyg-team/pytorch_geometric/issues/4588")  

Failed for 732 (100.00 %) out of all 732 cell-lines.


100%|██████████| 732/732 [00:09<00:00, 74.03it/s]

All cell-lines succeeded according to this issue: 
https://github.com/pyg-team/pytorch_geometric/issues/4588





In [13]:
print(f"Shape: {drug_cl.shape}")
print(f"Number of different drug id's   : {len(np.unique(drug_cl.DRUG_ID.values))}")
print(f"Number of different drug name's : {len(np.unique(drug_cl.DRUG_NAME.values))}")
drug_cl.head(3)

Shape: (91991, 5)
Number of different drug id's   : 152
Number of different drug name's : 152


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


In [14]:
with open(f'{PATH_SUMMARY_DATASETS}drug_smiles_fingerprints_matrix.pkl', 'rb') as f:
    drug_name_smiles = pickle.load(f)
# drug_name_smiles.set_index(['drug_name'], inplace=True)
print(drug_name_smiles.shape)
# TODO: Note that yet there are some DRUG_NAME's which have >1 DRUG_ID
drug_name_smiles.head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


---

## Dataset summary

We got three datasets now: 
- Drug response matrix
- fingerprint matrix
- cell-line graph dictionary

In [15]:
# Drug response matrix holding the ln(IC50) values for each cell-line drug tuple.
drug_response_matrix = copy.deepcopy(drug_cl)
print(drug_response_matrix.shape)
drug_response_matrix.sort_values(['DRUG_ID']).head(5)

(91991, 5)


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3445814,QGP-1,1003,Camptothecin,GDSC2,-0.534857
3446843,RC-K8,1003,Camptothecin,GDSC2,-3.324229
3447648,RCC-JW,1003,Camptothecin,GDSC2,-2.573593
3440830,CHP-134,1003,Camptothecin,GDSC2,-3.817771


In [16]:
# The drug matrix holding the fingerprints for the drugs in the drug response matrix.
fingerprints = copy.deepcopy(drug_name_smiles)
print(fingerprints.shape)
fingerprints.sort_values(['DRUG_ID']).head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
21273,1003,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
94088,1004,1,0,0,0,0,0,0,0,1,...,0,1,0,1,0,1,0,0,0,0
25052,1006,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
39836,1010,1,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
61065,1011,0,1,0,0,0,0,0,0,0,...,1,0,0,0,1,1,0,0,0,1


In [17]:
# Drugs as dictionary.
fingerprints_dict = fingerprints.set_index('DRUG_ID').T.to_dict('list')
print(len(fingerprints_dict.keys()))
print(len(fingerprints_dict[1003]))

152
256


In [18]:
# The cell-line graphs holding for each cell-line in the drug-response-matrix the corresponding 
# graph with the cell-line level gene features.
cell_line_graphs = copy.deepcopy(cl_graphs)
print(f"Number of cell-lines/graphs: {len(list(cell_line_graphs.keys()))}")
print(cell_line_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[458, 4], edge_index=[2, 4760])


In [19]:
fails = []
cls = len(list(cell_line_graphs.keys()))
for cl, G in cell_line_graphs.items():
    if not(G.edge_index.max() < G.num_nodes):
        fails.append(cl)
print(f"Failed for {len(fails)} ({100*len(fails)/cls:2.2f} %) out of all {cls} cell-lines.")
del cls

c = 0
sum_per = []
for k, v in cell_line_graphs.items():
    if v.x.isnan().any(): 
        sum_per.append(v.x.isnan().sum())
        c += 1
print(c)

Failed for 0 (0.00 %) out of all 732 cell-lines.
0


---

## Create PyTorch dataset

dataset should be 

| Tuple Datapoint | Cell Branch Input | Drug Branch Input | Target | Example Dataset |
| --------------- | ----------------- | ----------------- | ------ | --------------- |
| `(cl_1, drug_1)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_1\text{drug}_1}$ | Train |
| `(cl_1, ...   )`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_1\text{drug}_{...}}$ | Train | 
| `(cl_1, drug_m)`| $\text{graph}_{\text{cl}_1}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_1\text{drug}_m}$ | Test | 
| `(cl_i, drug_1)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_i\text{drug}_1}$ | Test |
| `(cl_i, ...   )`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_i\text{drug}_{...}}$ | Train | 
| `(cl_i, drug_m)`| $\text{graph}_{\text{cl}_i}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_i\text{drug}_m}$ | Train |
| `(cl_n, drug_1)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_1}$ | $ln(IC50)_{\text{cl}_n\text{drug}_1}$ | Test | 
| `(cl_n, ...   )`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_{...}}$ | $ln(IC50)_{\text{cl}_n\text{drug}_{...}}$ | Test |
| `(cl_n, drug_m)`| $\text{graph}_{\text{cl}_n}$ | $\text{smiles}_{\text{drug}_m}$ | $ln(IC50)_{\text{cl}_n\text{drug}_m}$ | Train | 

In [20]:
from typing import Tuple
from torch_geometric.data import Data, Dataset

class GraphTabDataset(Dataset): 
    def __init__(self, cl_graphs, drugs, drug_response_matrix):
        super().__init__()

        # SMILES fingerprints of the drugs and cell-line graphs.
        self.drugs = drugs
        self.cell_line_graphs = cl_graphs

        # Lookup datasets for the response values.
        drug_response_matrix.reset_index(drop=True, inplace=True)
        self.cell_lines = drug_response_matrix['CELL_LINE_NAME']
        self.drug_ids = drug_response_matrix['DRUG_ID']
        self.drug_names = drug_response_matrix['DRUG_NAME']
        self.ic50s = drug_response_matrix['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line, drug and the corresponding ln(IC50)
        value for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns:
            `Tuple[torch_geometric.data.data.Data, np.ndarray, np.float64]`:
            Tuple of a cell-line graph, drug SMILES fingerprint and the 
            corresponding ln(IC50) value.
        """
        return (self.cell_line_graphs[self.cell_lines.iloc[idx]], 
                self.drugs[self.drug_ids.iloc[idx]],
                self.ic50s.iloc[idx])

    def print_dataset_summary(self):
        print(f"GraphTabDataset Summary")
        print(f"{23*'='}")
        print(f"# observations : {len(self.ic50s)}")
        print(f"# cell-lines   : {len(np.unique(self.cell_lines))}")
        print(f"# drugs        : {len(np.unique(self.drug_names))}")
        print(f"# genes        : {self.cell_line_graphs[next(iter(self.cell_line_graphs))].x.shape[0]}")

In [21]:
graph_tab_dataset = GraphTabDataset(cl_graphs=cell_line_graphs,
                                    drugs=fingerprints_dict,
                                    drug_response_matrix=drug_response_matrix)

graph_tab_dataset.print_dataset_summary()                                    

GraphTabDataset Summary
# observations : 91991
# cell-lines   : 732
# drugs        : 152
# genes        : 458


In [22]:
print(graph_tab_dataset.cell_lines.iloc[1])
print(graph_tab_dataset.drug_ids.iloc[1])
print(graph_tab_dataset.drug_names.iloc[1])
print(graph_tab_dataset.cell_line_graphs['22RV1'])
print(graph_tab_dataset.ic50s.iloc[1])
print(graph_tab_dataset[1])
drug_response_matrix.iloc[1]

22RV1
1004
Vinblastine
Data(x=[458, 4], edge_index=[2, 4760])
-4.459259
(Data(x=[458, 4], edge_index=[2, 4760]), [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0], -4.459259)


CELL_LINE_NAME          22RV1
DRUG_ID                  1004
DRUG_NAME         Vinblastine
DATASET                 GDSC2
LN_IC50             -4.459259
Name: 1, dtype: object

In [23]:
random_seed = 12345

BATCH_SIZE = 10 # TODO: tune batch_size
LR = 0.001 # TODO: tune this or implement lr decay
TRAIN_RATIO = 0.8 
TEST_VAL_RATIO = 1-TRAIN_RATIO # How much of all data is for the test and validation set.
VAL_RATIO = 0.5 # How much of the of the test and validation set is only for validation.
NUM_EPOCHS = 100

In [24]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
# from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader

def _collate_graph_tab(samples):
    """
    Collates a list of cell-line graphs, SMILES drug fingerprints and 
    ln(IC50) values to pytorch handable formats.

    Args: 
        samples (`List[Tuple[torch.Data, List[int], float]]`): List of tuples of 
            cell-line graphs, SMILES drug fingerprints and ln(IC50) values.
            Example:
            >>> [(Data(x=[858, 4], edge_index=[2, 83126]), [0, 0, 1, ..., 0], 4.792643),
                 (Data(x=[858, 4], edge_index=[2, 83126]), [1, 0, 1, ..., 1], 5.857639),
                 ...]
            
    Returns:
        `DataBatch`: All graphs in the batch as one big (disconnected) graph.
            Example: 
            >>> DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])
                # x=[858*BATCH_SIZE, 4], edge_index=[2, 83126*BATCH_SIZE], batch=[858*BATCH_SIZE]
        `torch.tensor(List[List[int]])`: tensor of list of drug fingerprints.
            Example:
            >>> torch.tensor([[0,0,1,...,0], 
                              [1,0,1,...,1],
                              ...])
        `torch.tensor(float)`: tensor of ln(IC50)'s.
            Example: 
            >>> torch.tensor(4.792643, 5.857639, ...)
    """
    cells, drugs, targets = map(list, zip(*samples))
    drugs = [torch.tensor(drug_fp, dtype=torch.float64) for drug_fp in drugs] # list of fingerprint tensors

    return Batch.from_data_list(cells), torch.stack(drugs, 0), torch.tensor(targets)
    # return "hallo", torch.stack(drugs, 0), torch.tensor(targets)

def create_datasets(drm, cl_dict, drug_dict):
    print(f"Full     shape: {drm.shape}")
    train_set, test_val_set = train_test_split(drm, 
                                               test_size=TEST_VAL_RATIO, 
                                               random_state=random_seed,
                                               stratify=drm['CELL_LINE_NAME'])
    test_set, val_set = train_test_split(test_val_set,
                                         test_size=VAL_RATIO,
                                         random_state=random_seed,
                                         stratify=test_val_set['CELL_LINE_NAME'])
    print(f"train    shape: {train_set.shape}")
    print(f"test_val shape: {test_val_set.shape}")
    print(f"test     shape: {test_set.shape}")
    print(f"val      shape: {val_set.shape}")

    train_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                    drugs=drug_dict,
                                    drug_response_matrix=train_set)
    test_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                   drugs=drug_dict,
                                   drug_response_matrix=test_set)
    val_dataset = GraphTabDataset(cl_graphs=cl_dict,
                                  drugs=drug_dict,
                                  drug_response_matrix=val_set)

    print("\ntrain_dataset:")
    train_dataset.print_dataset_summary()
    print("\n\ntest_dataset:")
    test_dataset.print_dataset_summary()
    print("\n\nval_dataset:")
    val_dataset.print_dataset_summary()

    # TODO: try out different `num_workers`.
    train_loader = DataLoader(dataset=train_dataset, 
                              batch_size=BATCH_SIZE,
                              shuffle=True)
                              #collate_fn=_collate_graph_tab)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True)
                             #collate_fn=_collate_graph_tab)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True)
                            #collate_fn=_collate_graph_tab)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drug_response_matrix, 
                                                        cl_graphs,
                                                        fingerprints_dict)

Full     shape: (91991, 5)
train    shape: (73592, 5)
test_val shape: (18399, 5)
test     shape: (9199, 5)
val      shape: (9200, 5)

train_dataset:
GraphTabDataset Summary
# observations : 73592
# cell-lines   : 732
# drugs        : 152
# genes        : 458


test_dataset:
GraphTabDataset Summary
# observations : 9199
# cell-lines   : 732
# drugs        : 152
# genes        : 458


val_dataset:
GraphTabDataset Summary
# observations : 9200
# cell-lines   : 732
# drugs        : 152
# genes        : 458


In [25]:
for step, data in enumerate(train_loader):
    cell, drugs, targets = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    # print(drugs)
    # print(targets)
    print()

    if step == 3:
        break

Step 1:
Number of graphs in the current batch: 10
DataBatch(x=[4580, 4], edge_index=[2, 47600], batch=[4580], ptr=[11])

Step 2:
Number of graphs in the current batch: 10
DataBatch(x=[4580, 4], edge_index=[2, 47600], batch=[4580], ptr=[11])

Step 3:
Number of graphs in the current batch: 10
DataBatch(x=[4580, 4], edge_index=[2, 47600], batch=[4580], ptr=[11])

Step 4:
Number of graphs in the current batch: 10
DataBatch(x=[4580, 4], edge_index=[2, 47600], batch=[4580], ptr=[11])



In [26]:
# for batch_cell_graph, batch_drugs, batch_ic50s in train_loader
print("Number of batches per dataset:")
print(f"  train : {len(train_loader)}")
print(f"  test  : {len(test_loader)}")
print(f"  val   : {len(val_loader)}")

Number of batches per dataset:
  train : 7360
  test  : 920
  val   : 920


By using `Batch.from_data_list(cells)` we batched all graphs in a batch into a single giant graph.

__TODO__: 
- [ ] We can also use the class `torch_geometric.data.DataLoader` directly to do this.
- example notebook: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=G-DwBYkquRUN

- PyTorch.data doc: https://pytorch.org/docs/stable/data.html

In [22]:
for data in train_loader: 
    cell, drug, ic = data
    print(drug)
    print(drug[0][0])
    break

[tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 0, 0, 0, 0, 1, 1, 0, 0, 1]), tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 0, 1, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 0, 0, 0, 0, 0, 1, 0, 0, 0]), tensor([0, 1, 0, 0, 1, 0, 0, 0, 0, 0]), tensor([0, 1, 0, 0, 1, 0, 0, 0, 0, 0]), tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0]), tensor([0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 1]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),

In [23]:
Batch.to_data_list(cell)[:10]

[Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126]),
 Data(x=[858, 4], edge_index=[2, 83126])]

## Model development

In [122]:
from tqdm import tqdm
from time import sleep
import torch.nn.functional as F
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import pearsonr

class BuildModel():
    def __init__(self, model, criterion, optimizer, num_epochs, 
        train_loader, test_loader, val_loader, device):
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

    def train(self, loader): 
        train_epoch_losses, val_epoch_losses = [], []
        all_batch_losses = [] # TODO: this is just for monitoring
        n_batches = len(loader)

        self.model = self.model.float() # TODO: maybe remove
        for epoch in range(self.num_epochs):
            self.model.train()
            batch_losses = []
            for i, data in enumerate(tqdm(loader, desc='Iteration')):
                sleep(0.01)
                cell, drug, targets = data
                drug = torch.stack(drug, 0).transpose(1, 0) # Note that this is only neede when geometric 
                                                            # Dataloader is used and no collate.
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                self.optimizer.zero_grad()

                #print('cell.shape    : ', cell.size)
                # print('drug.shape    : ', drug.shape)
                # print('targets.size  : ', targets.shape)

                # Models predictions of the ic50s for a batch of cell-lines and drugs
                preds = self.model(cell, drug.float()).unsqueeze(1)
                # print(100*"=")
                # print(targets)
                # print(targets.view(-1, 1))
                loss = self.criterion(preds, targets.view(-1, 1).float()) # =train_loss
                batch_losses.append(loss)

                loss.backward()
                self.optimizer.step()

            all_batch_losses.append(batch_losses) # TODO: this is just for monitoring
            total_epoch_loss = sum(batch_losses)
            train_epoch_losses.append(total_epoch_loss / n_batches)

            mse, _, _, _, _ = self.validate(self.val_loader)
            val_epoch_losses.append(mse)

            print("=====Epoch ", epoch)
            print(f"Train      | MSE: {train_epoch_losses[-1]:2.5f}")
            print(f"Validation | MSE: {mse:2.5f}")

        return train_epoch_losses, val_epoch_losses            

    def validate(self, loader):
        self.model.eval()
        y_true, y_pred = [], []
        total_loss = 0
        with torch.no_grad():
            for data in tqdm(loader, desc='Iter', position=0, leave=True):
                sleep(0.01)
                cl, dr, ic50 = data
                dr = torch.stack(dr, 0).transpose(1, 0)

                preds = self.model(cl, dr.float()).unsqueeze(1)
                ic50 = ic50.to(self.device)
                total_loss += self.criterion(preds, ic50.view(-1,1).float())
                # total_loss += F.mse_loss(preds, ic50.view(-1, 1).float(), reduction='sum')
                y_true.append(ic50.view(-1, 1))
                y_pred.append(preds)
        
        y_true = torch.cat(y_true, dim=0)
        y_pred = torch.cat(y_pred, dim=0)
        mse = total_loss / len(loader)
        rmse = torch.sqrt(mse)
        mae = mean_absolute_error(y_true.cpu(), y_pred.cpu())
        r2 = r2_score(y_true.cpu(), y_pred.cpu())
        pearson_corr_coef, _ = pearsonr(y_true.cpu().numpy().flatten(), 
                                        y_pred.cpu().numpy().flatten())

        return mse, rmse, mae, r2, pearson_corr_coef

In [123]:
%load_ext autoreload
%autoreload
# from v3_GCN import GraphTab_v1
# from my_utils.model_helpers import train_and_test_model
from torch_geometric.nn import Sequential, GCNConv, global_mean_pool, global_max_pool


class GraphTab_v1(torch.nn.Module):
    def __init__(self):
        super(GraphTab_v1, self).__init__()
        torch.manual_seed(12345)
        # Cell-line graph branch. Obtains node embeddings.
        # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.sequential.Sequential
        self.cell_emb = Sequential('x, edge_index, batch', 
            [
                (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x1'), # TODO: GATConv() vs GCNConv()
                nn.ReLU(inplace=True),
                ## nn.BatchNorm1d(num_features=128),
                ## nn.Dropout(self.dropout_p),
                (GCNConv(in_channels=256, out_channels=256), 'x1, edge_index -> x2'),
                nn.ReLU(inplace=True),
                (global_mean_pool, 'x2, batch -> x3'), 
                # Start embedding
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(p=0.1),
                nn.Linear(128, 128),
                nn.ReLU()
            ]
        )

        self.drug_emb = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()          
        )

        self.fcn = nn.Sequential(
            nn.Linear(2*128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

    def forward(self, cell, drug):
        drug_emb = self.drug_emb(drug)
        cell_emb = self.cell_emb(cell.x.float(), cell.edge_index, cell.batch)
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred

torch.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = GraphTab_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr

build_model = BuildModel(model=model,
                         criterion=loss_func,
                         optimizer=optimizer,
                         num_epochs=50,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         val_loader=val_loader, 
                         device=device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
device: cpu


In [124]:
train_losses, val_losses = build_model.train(build_model.train_loader)

Iteration: 100%|██████████| 7360/7360 [31:12<00:00,  3.93it/s]
Iter: 100%|██████████| 920/920 [02:03<00:00,  7.44it/s]


=====Epoch  0
Train      | MSE: 3.19222
Validation | MSE: 2.41571


Iteration: 100%|██████████| 7360/7360 [30:52<00:00,  3.97it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.11it/s]


=====Epoch  1
Train      | MSE: 2.79864
Validation | MSE: 2.43306


Iteration: 100%|██████████| 7360/7360 [30:58<00:00,  3.96it/s]
Iter: 100%|██████████| 920/920 [02:04<00:00,  7.38it/s]


=====Epoch  2
Train      | MSE: 2.72068
Validation | MSE: 2.26106


Iteration: 100%|██████████| 7360/7360 [29:49<00:00,  4.11it/s]
Iter: 100%|██████████| 920/920 [02:10<00:00,  7.05it/s]


=====Epoch  3
Train      | MSE: 2.58156
Validation | MSE: 2.27341


Iteration: 100%|██████████| 7360/7360 [30:45<00:00,  3.99it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.10it/s]


=====Epoch  4
Train      | MSE: 2.49507
Validation | MSE: 3.44554


Iteration: 100%|██████████| 7360/7360 [30:08<00:00,  4.07it/s]  
Iter: 100%|██████████| 920/920 [02:12<00:00,  6.96it/s]


=====Epoch  5
Train      | MSE: 2.43572
Validation | MSE: 3.39140


Iteration:  20%|█▉        | 1439/7360 [06:30<28:53,  3.42it/s]

### Run Only With a Sample

In [120]:
from torch_geometric.loader import DataLoader

sample = drug_response_matrix.sample(1_000)
train_set, test_val_set = train_test_split(sample, test_size=0.8, random_state=random_seed)
sample_dataset = GraphTabDataset(cl_graphs=cl_graphs, drugs=fingerprints_dict, drug_response_matrix=train_set)
print("\ntrain_dataset:")
sample_dataset.print_dataset_summary()
sample_loader = DataLoader(dataset=sample_dataset, batch_size=2, shuffle=True)


train_dataset:
GraphTabDataset Summary
# observations : 200
# cell-lines   : 175
# drugs        : 108
# genes        : 458


In [121]:
train_losses, val_losses = build_model.train(sample_loader)

Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.48it/s]
Iter: 100%|██████████| 920/920 [02:17<00:00,  6.71it/s]


=====Epoch  0
Train      | MSE: 13.98202
Validation | MSE: 11.01190


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.67it/s]
Iter: 100%|██████████| 920/920 [02:16<00:00,  6.74it/s]


=====Epoch  1
Train      | MSE: 11.09647
Validation | MSE: 9.97083


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.81it/s]
Iter: 100%|██████████| 920/920 [02:24<00:00,  6.37it/s]


=====Epoch  2
Train      | MSE: 8.56703
Validation | MSE: 8.11504


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.24it/s]
Iter: 100%|██████████| 920/920 [02:33<00:00,  5.99it/s]


=====Epoch  3
Train      | MSE: 6.49640
Validation | MSE: 7.55463


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.81it/s]
Iter: 100%|██████████| 920/920 [02:23<00:00,  6.42it/s]


=====Epoch  4
Train      | MSE: 6.88781
Validation | MSE: 7.54679


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.33it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.56it/s]


=====Epoch  5
Train      | MSE: 7.12907
Validation | MSE: 7.46634


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.67it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.57it/s]


=====Epoch  6
Train      | MSE: 6.83407
Validation | MSE: 10.60364


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.90it/s]
Iter: 100%|██████████| 920/920 [02:22<00:00,  6.48it/s]


=====Epoch  7
Train      | MSE: 6.86842
Validation | MSE: 7.34576


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.91it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.53it/s]


=====Epoch  8
Train      | MSE: 6.24524
Validation | MSE: 7.07446


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.86it/s]
Iter: 100%|██████████| 920/920 [02:14<00:00,  6.85it/s]


=====Epoch  9
Train      | MSE: 6.77860
Validation | MSE: 7.24841


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.56it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.10it/s]


=====Epoch  10
Train      | MSE: 6.18162
Validation | MSE: 6.95470


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.57it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.08it/s]


=====Epoch  11
Train      | MSE: 6.43730
Validation | MSE: 6.98433


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.45it/s]
Iter: 100%|██████████| 920/920 [02:08<00:00,  7.15it/s]


=====Epoch  12
Train      | MSE: 5.98261
Validation | MSE: 7.09261


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.22it/s]
Iter: 100%|██████████| 920/920 [02:08<00:00,  7.14it/s]


=====Epoch  13
Train      | MSE: 6.42925
Validation | MSE: 7.40595


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.47it/s]
Iter: 100%|██████████| 920/920 [02:05<00:00,  7.31it/s]


=====Epoch  14
Train      | MSE: 6.39245
Validation | MSE: 7.20384


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.56it/s]
Iter: 100%|██████████| 920/920 [02:07<00:00,  7.23it/s]


=====Epoch  15
Train      | MSE: 6.58797
Validation | MSE: 7.04990


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.19it/s]
Iter: 100%|██████████| 920/920 [02:19<00:00,  6.62it/s]


=====Epoch  16
Train      | MSE: 6.48583
Validation | MSE: 7.16788


Iteration: 100%|██████████| 100/100 [00:07<00:00, 13.69it/s]
Iter: 100%|██████████| 920/920 [02:25<00:00,  6.34it/s]


=====Epoch  17
Train      | MSE: 6.09398
Validation | MSE: 7.05967


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.39it/s]
Iter: 100%|██████████| 920/920 [02:22<00:00,  6.47it/s]


=====Epoch  18
Train      | MSE: 6.40673
Validation | MSE: 7.07927


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.17it/s]
Iter: 100%|██████████| 920/920 [02:18<00:00,  6.66it/s]


=====Epoch  19
Train      | MSE: 6.35017
Validation | MSE: 7.07385


In [34]:
for step, data in enumerate(train_loader):
    cell, drug, ic50 = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    print(ic50)
    print()

    if step == 1:
        break

Step 1:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 4.5438,  4.4568, -0.9637, -2.6823, -3.9976,  3.0180, -2.8586,  4.7727,
        -2.9860, -1.4190], dtype=torch.float64)

Step 2:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 3.1815,  2.9466,  0.6865,  4.0391,  0.9185,  4.0514,  2.4978,  6.8512,
        -3.0862,  4.6164], dtype=torch.float64)



In [304]:
# 858000
cell.x[858000]

IndexError: index 858000 is out of bounds for dimension 0 with size 858000

In [305]:
cell.x[858000-1]

tensor([ 4.2288, -1.0000,  2.0000,  0.0000], dtype=torch.float64)