In [11]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from tqdm                   import tqdm
from torch.utils.data       import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_SUMMARY_DATASETS
)

IMGS_PATH = 'imgs'
NOTEBOOKS_SUMMARY_FOLDER = 'v3/'
WITHOUT_MISSING_FOLDER = '/without_missing/'

torch.manual_seed(42)
sns.set_theme(style="white")

---

# Experiments on the `GraphTab` approach

For this notebook we will use the updated cell-line gene-gene interaction dataset created in [`07_v3_graph_dataset.ipynb`](07_v3_graph_dataset.ipynb). The resulting dictionary lies in `~/datasets_for_model_building/summary_datasets/v3/cl_graphs_dict.pkl`. For gene selection we used a `combined_score` threshold of `700`. The bi-modal network approach we are using in this notebook is as follows:

- replacing the cell-line branch by a GNN and 
- having the drug branch using tabular input.

In [3]:
# Reading the drug response matrix.
# This dataset got developed in `15_summary_datasets.ipynb`
with open(f'{PATH_SUMMARY_DATASETS}{NOTEBOOKS_SUMMARY_FOLDER}drug_response_matrix__gdsc2.pkl', 'rb') as f: 
    drug_cl = pickle.load(f)
print(drug_cl.shape)
print("Number of unique cell lines:", len(drug_cl.CELL_LINE_NAME.unique()))
print("Number of unique drug ids:", len(drug_cl.DRUG_ID.unique()))
print("Number of unique drug names:", len(drug_cl.DRUG_NAME.unique()))
drug_cl.head(3)

NameError: name 'PATH_SUMMARY_DATASETS' is not defined

In [7]:
# The following are the gene-gene interaction graphs per cell-line were the gene-gene tuples
# have a `combined_score` > 0.70*1_000=700.
# This dataset got developed in `07_v3_graph_dataset.ipynb`
with open(f'{PATH_SUMMARY_DATASETS}{NOTEBOOKS_SUMMARY_FOLDER}cl_graphs_dict.pkl', 'rb') as f:
    cl_graphs_v3 = pd.read_pickle(f)
print(f"Number of cell-lines/graphs: {len(list(cl_graphs_v3.keys()))}")
print(cl_graphs_v3['22RV1'])

Number of cell-lines/graphs: 983
Data(x=[696, 4], edge_index=[2, 7794])


### Pre-process

In this sub-section we are going to
- [x] select only the cell-line in the gene-gene graph dictionary which are in the drug-response-matrix

In [8]:
cl_graphs = dict((cl, cl_graphs_v3[cl]) for cl in drug_cl.CELL_LINE_NAME.unique().tolist())
print(f"Number of cell-lines/graphs: {len(list(cl_graphs.keys()))}")
print(cl_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[696, 4], edge_index=[2, 7794])


In [9]:
fails = []
cls = len(list(cl_graphs.keys()))
for cl, G in cl_graphs.items():
    if not (G.edge_index.max() < G.num_nodes):
        fails.append(cl)
print(f"Failed for {len(fails)} ({100*len(fails)/cls:2.2f} %) out of all {cls} cell-lines.")
del cls

if len(fails) > 0:
    for cl, G in tqdm(cl_graphs.items()):
        mapping = {}
        mapped_edge_index = []
        for (src, dst) in G.edge_index.t().tolist():
            if src not in mapping:
                mapping[src] = len(mapping)
            if dst not in mapping:
                mapping[dst] = len(mapping)
            mapped_edge_index.append([mapping[src], mapping[dst]])
        edge_index = torch.tensor(mapped_edge_index).t().contiguous()
        cl_graphs[cl].edge_index = edge_index

for cl, G in cl_graphs.items():
    assert G.edge_index.max() < G.num_nodes, f'FAIL for cell-line: {cl}'
print("All cell-lines succeeded according to this issue: ")
print("https://github.com/pyg-team/pytorch_geometric/issues/4588")  

Failed for 0 (0.00 %) out of all 732 cell-lines.
All cell-lines succeeded according to this issue: 
https://github.com/pyg-team/pytorch_geometric/issues/4588


In [11]:
with open(f'{PATH_SUMMARY_DATASETS}drug_smiles_fingerprints_matrix.pkl', 'rb') as f:
    drug_name_smiles = pickle.load(f)
# drug_name_smiles.set_index(['drug_name'], inplace=True)
print(drug_name_smiles.shape)
# TODO: Note that yet there are some DRUG_NAME's which have >1 DRUG_ID
drug_name_smiles.head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
1,1073,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
810,1910,1,1,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
1562,1913,0,1,1,0,1,0,0,0,0,...,0,0,0,0,1,1,0,0,0,0
2314,1634,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,1,0,0,0,1
3044,2045,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


---

## Dataset summary

We got three datasets now: 
- Drug response matrix
- fingerprint matrix
- cell-line graph dictionary

In [12]:
# Drug response matrix holding the ln(IC50) values for each cell-line drug tuple.
drug_response_matrix = copy.deepcopy(drug_cl)
print(drug_response_matrix.shape)
drug_response_matrix.sort_values(['DRUG_ID']).head(5)

(91991, 5)


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3445814,QGP-1,1003,Camptothecin,GDSC2,-0.534857
3446843,RC-K8,1003,Camptothecin,GDSC2,-3.324229
3447648,RCC-JW,1003,Camptothecin,GDSC2,-2.573593
3440830,CHP-134,1003,Camptothecin,GDSC2,-3.817771


In [13]:
# The drug matrix holding the fingerprints for the drugs in the drug response matrix.
fingerprints = copy.deepcopy(drug_name_smiles)
print(fingerprints.shape)
fingerprints.sort_values(['DRUG_ID']).head(5)

(152, 257)


Unnamed: 0,DRUG_ID,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
21273,1003,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
94088,1004,1,0,0,0,0,0,0,0,1,...,0,1,0,1,0,1,0,0,0,0
25052,1006,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
39836,1010,1,0,0,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
61065,1011,0,1,0,0,0,0,0,0,0,...,1,0,0,0,1,1,0,0,0,1


In [14]:
# Drugs as dictionary.
fingerprints_dict = fingerprints.set_index('DRUG_ID').T.to_dict('list')
print("Number of drug fingerprints:", len(fingerprints_dict.keys()))
print("Length of each SMILES fingerprint:", len(fingerprints_dict[1003]))

Number of drug fingerprints: 152
Length of each SMILES fingerprint: 256


In [15]:
# The cell-line graphs holding for each cell-line in the drug-response-matrix the corresponding 
# graph with the cell-line level gene features.
cell_line_graphs = copy.deepcopy(cl_graphs)
print(f"Number of cell-lines/graphs: {len(list(cell_line_graphs.keys()))}")
print(cell_line_graphs['22RV1'])

Number of cell-lines/graphs: 732
Data(x=[696, 4], edge_index=[2, 7794])


In [16]:
fails = []
cls = len(list(cell_line_graphs.keys()))
for cl, G in cell_line_graphs.items():
    if not(G.edge_index.max() < G.num_nodes):
        fails.append(cl)
print(f"Failed for {len(fails)} ({100*len(fails)/cls:2.2f} %) out of all {cls} cell-lines.")
del cls

c = 0
sum_per = []
for k, v in cell_line_graphs.items():
    if v.x.isnan().any(): 
        sum_per.append(v.x.isnan().sum())
        c += 1
print(c)

Failed for 0 (0.00 %) out of all 732 cell-lines.
0


---

## Build PyTorch dataset

In this subsection we are going to create the dataset which holds the correct drug, cell-line graph and corresponding `ln(IC50)` value for a given index.

In [20]:
from torch_geometric.data import Dataset

class GraphTabDataset(Dataset): 
    def __init__(self, cl_graphs, drugs, drug_response_matrix):
        super().__init__()

        # SMILES fingerprints of the drugs and cell-line graphs.
        self.drugs = drugs
        self.cell_line_graphs = cl_graphs

        # Lookup datasets for the response values.
        drug_response_matrix.reset_index(drop=True, inplace=True)
        self.cell_lines = drug_response_matrix['CELL_LINE_NAME']
        self.drug_ids = drug_response_matrix['DRUG_ID']
        self.drug_names = drug_response_matrix['DRUG_NAME']
        self.ic50s = drug_response_matrix['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line, drug and the corresponding ln(IC50)
        value for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns:
            `Tuple[torch_geometric.data.data.Data, np.ndarray, np.float64]`:
            Tuple of a cell-line graph, drug SMILES fingerprint and the 
            corresponding ln(IC50) value.
        """
        return (self.cell_line_graphs[self.cell_lines.iloc[idx]], 
                self.drugs[self.drug_ids.iloc[idx]],
                self.ic50s.iloc[idx])

    def print_dataset_summary(self):
        print(f"GraphTabDataset Summary")
        print(f"{23*'='}")
        print(f"# observations : {len(self.ic50s)}")
        print(f"# cell-lines   : {len(np.unique(self.cell_lines))}")
        print(f"# drugs        : {len(np.unique(self.drug_names))}")
        print(f"# genes        : {self.cell_line_graphs[next(iter(self.cell_line_graphs))].x.shape[0]}")

In [21]:
graph_tab_dataset = GraphTabDataset(cl_graphs=cell_line_graphs,
                                    drugs=fingerprints_dict,
                                    drug_response_matrix=drug_response_matrix)
graph_tab_dataset.print_dataset_summary()                                    

GraphTabDataset Summary
# observations : 91991
# cell-lines   : 732
# drugs        : 152
# genes        : 696


## Set hyperparameters

In this subsection we are setting the hyperparameters for the model building.

In [22]:
class Args:
    def __init__(self, batch_size, lr, train_ratio, val_ratio, num_epochs):
        self.BATCH_SIZE = batch_size
        self.LR = lr
        self.TRAIN_RATIO = train_ratio
        self.TEST_VAL_RATIO = 1-self.TRAIN_RATIO
        self.VAL_RATIO = val_ratio
        self.NUM_EPOCHS = num_epochs
        self.RANDOM_SEED = 12345      

args = Args(batch_size=1_000, 
            lr=0.0001, 
            train_ratio=0.8, 
            val_ratio=0.5, 
            num_epochs=5)

## Create pytorch geometric `DataLoader` datasets

In [23]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader as PyG_DataLoader


def create_datasets(drm, cl_graphs, drug_mat, args):
    print(f"Full     shape: {drm.shape}")
    train_set, test_val_set = train_test_split(drm, 
                                               test_size=args.TEST_VAL_RATIO, 
                                               random_state=args.RANDOM_SEED,
                                               stratify=drm['CELL_LINE_NAME'])
    test_set, val_set = train_test_split(test_val_set,
                                         test_size=args.VAL_RATIO,
                                         random_state=args.RANDOM_SEED,
                                         stratify=test_val_set['CELL_LINE_NAME'])
    print(f"train    shape: {train_set.shape}")
    print(f"test_val shape: {test_val_set.shape}")
    print(f"test     shape: {test_set.shape}")
    print(f"val      shape: {val_set.shape}")

    train_dataset = GraphTabDataset(cl_graphs=cl_graphs, drugs=drug_mat, drug_response_matrix=train_set)
    test_dataset = GraphTabDataset(cl_graphs=cl_graphs, drugs=drug_mat, drug_response_matrix=test_set)
    val_dataset = GraphTabDataset(cl_graphs=cl_graphs, drugs=drug_mat, drug_response_matrix=val_set)

    print("\ntrain_dataset:")
    train_dataset.print_dataset_summary()
    print("\n\ntest_dataset:")
    test_dataset.print_dataset_summary()
    print("\n\nval_dataset:")
    val_dataset.print_dataset_summary()

    # TODO: try out different `num_workers`.
    train_loader = PyG_DataLoader(dataset=train_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
    test_loader = PyG_DataLoader(dataset=test_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
    val_loader = PyG_DataLoader(dataset=val_dataset, batch_size=args.BATCH_SIZE, shuffle=True)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drug_response_matrix, cl_graphs, fingerprints_dict, args)

Full     shape: (91991, 5)
train    shape: (73592, 5)
test_val shape: (18399, 5)
test     shape: (9199, 5)
val      shape: (9200, 5)

train_dataset:
GraphTabDataset Summary
# observations : 73592
# cell-lines   : 732
# drugs        : 152
# genes        : 696


test_dataset:
GraphTabDataset Summary
# observations : 9199
# cell-lines   : 732
# drugs        : 152
# genes        : 696


val_dataset:
GraphTabDataset Summary
# observations : 9200
# cell-lines   : 732
# drugs        : 152
# genes        : 696


In [28]:
# for batch_cell_graph, batch_drugs, batch_ic50s in train_loader
print("Number of batches per dataset:")
print(f"  train : {len(train_loader)}")
print(f"  test  : {len(test_loader)}")
print(f"  val   : {len(val_loader)}")

Number of batches per dataset:
  train : 74
  test  : 10
  val   : 10


In [31]:
for step, data in enumerate(train_loader):
    if (step > 2) & (step < len(train_loader)-1):
        if step % 10 == 0: 
            print("... step", step) 
        continue
    else:    
        cl_graphs, drugs, targets = data
        print(f'Step {step + 1}:')
        print(f'=======')
        print("Number of cell-line graphs in the batch:", cl_graphs.num_graphs)
        print("Gene-gene interaction topology per batch:", cl_graphs)
        print("Number of ln(IC50) targets per batch:", targets.shape)

Step 1:
Number of cell-line graphs in the batch: 1000
Gene-gene interaction topology per batch: DataBatch(x=[696000, 4], edge_index=[2, 7794000], batch=[696000], ptr=[1001])
Number of ln(IC50) targets per batch: torch.Size([1000])
Step 2:
Number of cell-line graphs in the batch: 1000
Gene-gene interaction topology per batch: DataBatch(x=[696000, 4], edge_index=[2, 7794000], batch=[696000], ptr=[1001])
Number of ln(IC50) targets per batch: torch.Size([1000])
Step 3:
Number of cell-line graphs in the batch: 1000
Gene-gene interaction topology per batch: DataBatch(x=[696000, 4], edge_index=[2, 7794000], batch=[696000], ptr=[1001])
Number of ln(IC50) targets per batch: torch.Size([1000])
... step 10
... step 20
... step 30
... step 40
... step 50
... step 60
... step 70
Step 74:
Number of cell-line graphs in the batch: 592
Gene-gene interaction topology per batch: DataBatch(x=[412032, 4], edge_index=[2, 4614048], batch=[412032], ptr=[593])
Number of ln(IC50) targets per batch: torch.Size([

## Model development

In this subsection we are going to train and test the GraphTab model.

In [32]:
from tqdm import tqdm
from time import sleep
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import pearsonr

class BuildModel():
    def __init__(self, model, criterion, optimizer, num_epochs, 
        train_loader, test_loader, val_loader, device):
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

    def train(self, loader): 
        train_epoch_losses, val_epoch_losses = [], []
        train_epoch_rmse, val_epoch_rmse = [], []
        train_epoch_mae, val_epoch_mae = [], []
        train_epoch_r2, val_epoch_r2 = [], []
        train_epoch_pcorr, val_epoch_pcorr = [], []
        y_true, y_pred = [], []
        all_batch_losses = [] # TODO: this is just for monitoring
        n_batches = len(loader)

        self.model = self.model.float() # TODO: maybe remove
        for epoch in range(self.num_epochs):
            self.model.train()
            batch_losses = []
            for i, data in enumerate(tqdm(loader, desc='Iteration')):
                sleep(0.01)
                cell, drug, ic50s = data
                drug = torch.stack(drug, 0).transpose(1, 0) # Note that this is only neede when geometric 
                                                            # Dataloader is used and no collate.
                cell, drug, ic50s = cell.to(self.device), drug.to(self.device), ic50s.to(self.device)

                self.optimizer.zero_grad()

                # Models predictions of the ic50s for a batch of cell-lines and drugs
                preds = self.model(cell, drug.float()).unsqueeze(1)
                loss = self.criterion(preds, ic50s.view(-1, 1).float()) # =train_loss
                batch_losses.append(loss)

                
                y_true.append(ic50s.view(-1, 1))
                y_pred.append(preds)             

                loss.backward()
                self.optimizer.step()

            all_batch_losses.append(batch_losses) # TODO: this is just for monitoring
            total_epoch_loss = sum(batch_losses)
            train_epoch_losses.append(total_epoch_loss / n_batches)

            y_true = torch.cat(y_true, dim=0)
            y_pred = torch.cat(y_pred, dim=0)
            train_mse = train_epoch_losses[-1]
            train_epoch_rmse.append(torch.sqrt(train_mse))
            train_epoch_mae.append(mean_absolute_error(y_true.cpu(), y_pred.cpu()))
            train_epoch_r2.append(r2_score(y_true.cpu(), y_pred.cpu()))
            train_epoch_pcorr.append(pearsonr(y_true.cpu().numpy().flatten(), y_pred.cpu().numpy().flatten()))
                     
            mse, rmse, mae, r2, pcorr = self.validate(self.val_loader)
            val_epoch_losses.append(mse)
            val_epoch_rmse.append(rmse)
            val_epoch_mae.append(mae)
            val_epoch_r2.append(r2)
            val_epoch_pcorr.append(pcorr)

            print("=====Epoch ", epoch)
            print(f"Train      | MSE: {train_mse:2.5f}")
            print(f"Validation | MSE: {mse:2.5f}")

        performance_stats = {
            'train': {
                'mse': train_epoch_losses,
                'rmse': train_epoch_rmse,
                'mae': train_epoch_mae,
                'r2': train_epoch_r2,
                'pcorr': train_epoch_pcorr
            },
            'val': {
                'mse': val_epoch_losses,
                'rmse': val_epoch_rmse,
                'mae': val_epoch_mae,
                'r2': val_epoch_r2,
                'pcorr': val_epoch_pcorr
            }            
        }

        return performance_stats           

    def validate(self, loader):
        self.model.eval()
        y_true, y_pred = [], []
        total_loss = 0
        with torch.no_grad():
            for data in tqdm(loader, desc='Iter', position=0, leave=True):
                sleep(0.01)
                cl, dr, ic50 = data
                dr = torch.stack(dr, 0).transpose(1, 0)

                preds = self.model(cl, dr.float()).unsqueeze(1)
                ic50 = ic50.to(self.device)
                total_loss += self.criterion(preds, ic50.view(-1,1).float())
                # total_loss += F.mse_loss(preds, ic50.view(-1, 1).float(), reduction='sum')
                y_true.append(ic50.view(-1, 1))
                y_pred.append(preds)
        
        y_true = torch.cat(y_true, dim=0)
        y_pred = torch.cat(y_pred, dim=0)
        mse = total_loss / len(loader)
        rmse = torch.sqrt(mse)
        mae = mean_absolute_error(y_true.cpu(), y_pred.cpu())
        r2 = r2_score(y_true.cpu(), y_pred.cpu())
        pearson_corr_coef, _ = pearsonr(y_true.cpu().numpy().flatten(), 
                                        y_pred.cpu().numpy().flatten())

        return mse, rmse, mae, r2, pearson_corr_coef

In [33]:
%load_ext autoreload
%autoreload

from torch_geometric.nn import Sequential, GCNConv, global_mean_pool, global_max_pool


class GraphTab_v1(torch.nn.Module):
    def __init__(self):
        super(GraphTab_v1, self).__init__()

        # Cell-line graph branch. Obtains node embeddings.
        # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.sequential.Sequential
        self.cell_emb = Sequential('x, edge_index, batch', 
            [
                (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x1'), # TODO: GATConv() vs GCNConv()
                nn.ReLU(inplace=True),
                ## nn.BatchNorm1d(num_features=128),
                ## nn.Dropout(self.dropout_p),
                (GCNConv(in_channels=256, out_channels=256), 'x1, edge_index -> x2'),
                nn.ReLU(inplace=True),
                (global_mean_pool, 'x2, batch -> x3'), 
                # Start embedding
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(p=0.1),
                nn.Linear(128, 128),
                nn.ReLU()
            ]
        )

        self.drug_emb = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()          
        )

        self.fcn = nn.Sequential(
            nn.Linear(2*128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

    def forward(self, cell, drug):
        drug_emb = self.drug_emb(drug)
        cell_emb = self.cell_emb(cell.x.float(), cell.edge_index, cell.batch)
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred

torch.manual_seed(args.RANDOM_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = GraphTab_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.LR) # TODO: include weight_decay of lr

build_model = BuildModel(model=model,
                         criterion=loss_func,
                         optimizer=optimizer,
                         num_epochs=3,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         val_loader=val_loader, 
                         device=device)

device: cpu


In [34]:
performance_stats = build_model.train(build_model.train_loader)

Iteration:   0%|          | 0/74 [00:00<?, ?it/s]

### Run Only With a Sample

In [13]:

import torch
import torch.nn as nn
from torch_geometric.nn import Sequential, GCNConv, global_mean_pool, global_max_pool


class GraphTab_v1(torch.nn.Module):
    def __init__(self):
        super(GraphTab_v1, self).__init__()

        # Cell-line graph branch. Obtains node embeddings.
        # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.sequential.Sequential
        self.cell_emb = Sequential('x, edge_index, batch', 
            [
                (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x1'), # TODO: GATConv() vs GCNConv()
                nn.ReLU(inplace=True),
                ## nn.BatchNorm1d(num_features=128),
                ## nn.Dropout(self.dropout_p),
                (GCNConv(in_channels=256, out_channels=256), 'x1, edge_index -> x2'),
                nn.ReLU(inplace=True),
                (global_mean_pool, 'x2, batch -> x3'), 
                # Start embedding
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(p=0.1),
                nn.Linear(128, 128),
                nn.ReLU()
            ]
        )

        self.drug_emb = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()          
        )

        self.fcn = nn.Sequential(
            nn.Linear(2*128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

    def forward(self, cell, drug):
        drug_emb = self.drug_emb(drug)
        cell_emb = self.cell_emb(cell.x.float(), cell.edge_index, cell.batch)
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred


model = GraphTab_v1()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

checkpoint = torch.load(f'{PATH_SUMMARY_DATASETS}GraphTab/v3/model_performance')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_perf = checkpoint['train_performances']
val_perf = checkpoint['val_performances']
print(train_perf)

# model.eval()

{'mse': [tensor(14.3344, requires_grad=True)], 'rmse': [tensor(3.7861, requires_grad=True)], 'mae': [3.3293848], 'r2': [-1.195433379156095], 'pcorr': [(0.0004950899936630969, 0.9944485685845041)]}


In [1]:
from torch_geometric.loader import DataLoader

sample = drug_response_matrix.sample(1_000)
train_set, test_val_set = train_test_split(sample, test_size=0.8, random_state=42)
sample_dataset = GraphTabDataset(cl_graphs=cl_graphs, drugs=fingerprints_dict, drug_response_matrix=train_set)
print("\ntrain_dataset:")
sample_dataset.print_dataset_summary()
sample_loader = DataLoader(dataset=sample_dataset, batch_size=2, shuffle=True)

NameError: name 'drug_response_matrix' is not defined

In [121]:
train_losses, val_losses = build_model.train(sample_loader)

Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.48it/s]
Iter: 100%|██████████| 920/920 [02:17<00:00,  6.71it/s]


=====Epoch  0
Train      | MSE: 13.98202
Validation | MSE: 11.01190


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.67it/s]
Iter: 100%|██████████| 920/920 [02:16<00:00,  6.74it/s]


=====Epoch  1
Train      | MSE: 11.09647
Validation | MSE: 9.97083


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.81it/s]
Iter: 100%|██████████| 920/920 [02:24<00:00,  6.37it/s]


=====Epoch  2
Train      | MSE: 8.56703
Validation | MSE: 8.11504


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.24it/s]
Iter: 100%|██████████| 920/920 [02:33<00:00,  5.99it/s]


=====Epoch  3
Train      | MSE: 6.49640
Validation | MSE: 7.55463


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.81it/s]
Iter: 100%|██████████| 920/920 [02:23<00:00,  6.42it/s]


=====Epoch  4
Train      | MSE: 6.88781
Validation | MSE: 7.54679


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.33it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.56it/s]


=====Epoch  5
Train      | MSE: 7.12907
Validation | MSE: 7.46634


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.67it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.57it/s]


=====Epoch  6
Train      | MSE: 6.83407
Validation | MSE: 10.60364


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.90it/s]
Iter: 100%|██████████| 920/920 [02:22<00:00,  6.48it/s]


=====Epoch  7
Train      | MSE: 6.86842
Validation | MSE: 7.34576


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.91it/s]
Iter: 100%|██████████| 920/920 [02:20<00:00,  6.53it/s]


=====Epoch  8
Train      | MSE: 6.24524
Validation | MSE: 7.07446


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.86it/s]
Iter: 100%|██████████| 920/920 [02:14<00:00,  6.85it/s]


=====Epoch  9
Train      | MSE: 6.77860
Validation | MSE: 7.24841


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.56it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.10it/s]


=====Epoch  10
Train      | MSE: 6.18162
Validation | MSE: 6.95470


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.57it/s]
Iter: 100%|██████████| 920/920 [02:09<00:00,  7.08it/s]


=====Epoch  11
Train      | MSE: 6.43730
Validation | MSE: 6.98433


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.45it/s]
Iter: 100%|██████████| 920/920 [02:08<00:00,  7.15it/s]


=====Epoch  12
Train      | MSE: 5.98261
Validation | MSE: 7.09261


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.22it/s]
Iter: 100%|██████████| 920/920 [02:08<00:00,  7.14it/s]


=====Epoch  13
Train      | MSE: 6.42925
Validation | MSE: 7.40595


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.47it/s]
Iter: 100%|██████████| 920/920 [02:05<00:00,  7.31it/s]


=====Epoch  14
Train      | MSE: 6.39245
Validation | MSE: 7.20384


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.56it/s]
Iter: 100%|██████████| 920/920 [02:07<00:00,  7.23it/s]


=====Epoch  15
Train      | MSE: 6.58797
Validation | MSE: 7.04990


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.19it/s]
Iter: 100%|██████████| 920/920 [02:19<00:00,  6.62it/s]


=====Epoch  16
Train      | MSE: 6.48583
Validation | MSE: 7.16788


Iteration: 100%|██████████| 100/100 [00:07<00:00, 13.69it/s]
Iter: 100%|██████████| 920/920 [02:25<00:00,  6.34it/s]


=====Epoch  17
Train      | MSE: 6.09398
Validation | MSE: 7.05967


Iteration: 100%|██████████| 100/100 [00:06<00:00, 15.39it/s]
Iter: 100%|██████████| 920/920 [02:22<00:00,  6.47it/s]


=====Epoch  18
Train      | MSE: 6.40673
Validation | MSE: 7.07927


Iteration: 100%|██████████| 100/100 [00:06<00:00, 16.17it/s]
Iter: 100%|██████████| 920/920 [02:18<00:00,  6.66it/s]


=====Epoch  19
Train      | MSE: 6.35017
Validation | MSE: 7.07385


In [34]:
for step, data in enumerate(train_loader):
    cell, drug, ic50 = data
    print(f'Step {step + 1}:')
    print(f'=======')
    print(f'Number of graphs in the current batch: {cell.num_graphs}')
    print(cell)
    print(ic50)
    print()

    if step == 1:
        break

Step 1:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 4.5438,  4.4568, -0.9637, -2.6823, -3.9976,  3.0180, -2.8586,  4.7727,
        -2.9860, -1.4190], dtype=torch.float64)

Step 2:
Number of graphs in the current batch: 10
DataBatch(x=[8580, 4], edge_index=[2, 831260], batch=[8580], ptr=[11])
tensor([ 3.1815,  2.9466,  0.6865,  4.0391,  0.9185,  4.0514,  2.4978,  6.8512,
        -3.0862,  4.6164], dtype=torch.float64)



In [304]:
# 858000
cell.x[858000]

IndexError: index 858000 is out of bounds for dimension 0 with size 858000

In [305]:
cell.x[858000-1]

tensor([ 4.2288, -1.0000,  2.0000,  0.0000], dtype=torch.float64)