## Test runner

### Import libraries

In [1]:
import os, sys
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [2]:
import pandas as pd
import numpy as np
import networkx as nx

In [3]:
from scipy.stats import wasserstein_distance

In [4]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [5]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [6]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [7]:
from ipywidgets import interact, interact_manual, FloatSlider

In [8]:
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

from src.utils import load_dataset, fetch_data, preprocess, create_dataset, \
                      DATA_FOLDERS, FILES, standardise_column_names
from src.models import DiffPool, BaselineGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

CONNECTION_DIR = '/Users/adhaene/Downloads/'

### Fetch data

How many patients do we have both pre-01, post-01, and post-02 for?

In [9]:
# sum(
#     shape[shape.study_name.isin(['pre-01', 'post-01', 'post-02'])].groupby(['gpcr_id']).study_name.unique()
#         .apply(len).to_numpy() > 2
# )

### Dataset creation

### Placing data into DataLoaders

In [10]:
test_size = 0.2
seed = 42
verbose = 0
connectivity = 'wasserstein'
distance = 3.0
dense=False

In [21]:
labels, lesions, patients = fetch_data(suspicious=0.5, verbose=verbose)

In [27]:
len(labels), len(lesions.reset_index().gpcr_id.unique()), len(patients.reset_index().gpcr_id.unique())

(63, 66, 63)

In [22]:
X_train, X_test, y_train, y_test = \
    preprocess(labels, lesions, patients,
               test_size=0., seed=12, verbose=verbose)

In [23]:
len(X_train.reset_index().gpcr_id.unique())

61

In [13]:
dataset_train, dataset_test = \
        load_dataset(connectivity='wasserstein', seed=12, test_size=0.,
                     suspicious=0.5, distance=2.0, verbose=1)

Post-1 study lesions extracted for 106 patients
Post-2 study labels added for 66 patients
The intersection of datasets showed 63 potential datapoints.
Final dataset split -> Train: 61 | Test: 0


In [13]:
# edge_index.split(split_sizes, dim=1), edge_weight.split(split_sizes, dim=0)

In [14]:
dataset_train = create_dataset(X=X_train, Y=y_train, distance=distance,
                               connectivity=connectivity, verbose=verbose)

dataset_test = create_dataset(X=X_test, Y=y_test, distance=distance,
                              connectivity=connectivity, verbose=verbose)

In [28]:
from torch.utils.data import random_split
from torch.utils.data import SubsetRandomSampler
from torch_geometric.loader import DataLoader

train_length = len(dataset_train)
lengths = [round(train_length * 0.8), train_length - round(train_length * 0.8)]

split = random_split(range(len(dataset_train)), lengths, generator=torch.Generator().manual_seed(42))

In [25]:
from sklearn.model_selection import KFold

kfold = KFold(n_splits=5, shuffle=True)

next(iter(kfold.split(dataset_train)))

(array([ 2,  4,  5,  6,  7,  8,  9, 10, 11, 12, 15, 16, 17, 19, 20, 22, 23,
        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 39, 41, 43,
        44, 45, 46, 47, 48, 49, 50, 51, 53, 55, 56, 57, 59, 61, 62, 63, 64,
        65, 67, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 84, 85]),
 array([ 0,  1,  3, 13, 14, 18, 21, 37, 40, 42, 52, 54, 58, 60, 66, 68, 70,
        80]))

In [32]:
len(split[0]), len(split[1])

(69, 17)

In [16]:
len(dataset_train)

86

In [33]:
len(DataLoader(dataset=dataset_train, batch_size=1, sampler=SubsetRandomSampler(split[0]))), \
    len(DataLoader(dataset=dataset_train, batch_size=1, sampler=SubsetRandomSampler(split[1])))

(69, 17)

In [34]:
loader = DataLoader(dataset=dataset_train, batch_size=1, sampler=SubsetRandomSampler(split[0]))

In [35]:
next(iter(loader))

Batch(x=[58, 28], edge_index=[2, 1386], y=[1], study_features=[58, 18], patient_features=[2], graph_sizes=[2], split_sizes=[2], edge_weight=[1386], num_nodes=58, batch=[58], ptr=[2])

In [147]:
loader_train_args = dict(dataset=dataset_train[:round(len(dataset_train) * .7)], batch_size=1)
loader_valid_args = dict(dataset=dataset_train[round(len(dataset_train) * .7):], batch_size=1)
        
loader_train = DenseDataLoader(**loader_train_args) if dense \
    else DataLoader(**loader_train_args)
loader_valid = DenseDataLoader(**loader_valid_args) if dense \
    else DataLoader(**loader_valid_args)

## Modeling

### Model

In [148]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, LogSoftmax, ModuleList, Sequential, BatchNorm1d, ReLU

from torch_geometric.nn import GATv2Conv as GATConv, GraphConv, GINConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from models.custom import SparseModule


class BaselineGNN(SparseModule):
    def __init__(self, lesion_features_dim: int, hidden_dim: int, 
                 layer_type: str = 'GraphConv', num_layers: int = 10):
        super(BaselineGNN, self).__init__()
        
        self.layer_type = layer_type
        self.num_layers = num_layers
        
        self.convs = ModuleList()

        feature_extractor = \
            self.create_layer(in_channels=lesion_features_dim, out_channels=hidden_dim)

        self.convs.append(feature_extractor)
        
        for step in range(num_layers - 1):
            layer = \
                self.create_layer(in_channels=hidden_dim, out_channels=hidden_dim)
            
            self.convs.append(layer)
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: torch.Tensor = None) -> torch.Tensor:

        # Wasserstein edge weights are added if GraphConv layers are used
        if self.layer_type == 'GNN':
            conv_kwargs = dict(edge_weight=edge_weight)
        else:
            conv_kwargs = dict()
        
        for step in range(len(self.convs)):
            x = F.relu(self.convs[step](x, edge_index, **conv_kwargs))
        
        return x
        
        
    def create_layer(self, **kwargs):
        """Create layer based on type

        Args:
            type (str): layer type

        Raises:
            ValueError: if type is not accepted within framework

        Returns:
            nn.Module: layer
        """

        if self.layer_type == 'GraphConv':
            return GraphConv(**kwargs)
        elif self.layer_type == 'GAT':
            return GATConv(**kwargs)
        elif self.layer_type == 'GIN':
            node_features, dim = kwargs['in_channels'], kwargs['out_channels']
            return GINConv(Sequential(
                Linear(node_features, dim), BatchNorm1d(dim), ReLU(),
                Linear(dim, dim), ReLU()))
        else:
            raise ValueError(f'{self.layer_type} is not a valid layer type')

In [149]:
data = next(iter(loader_train))

In [150]:
gnn = BaselineGNN(lesion_features_dim=28, hidden_dim=16)

In [162]:
data.study_features.repeat(2, 1)

tensor([[ 0.2495, -0.6099, -0.2123, -1.0354, -0.5818, -1.3415, -0.4875,  1.0000,
          0.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000],
        [ 0.5462, -0.6099, -0.8766, -0.0798, -0.3491, -0.1645, -0.2000,  1.0000,
          0.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000],
        [ 0.2495, -0.6099, -0.2123, -1.0354, -0.5818, -1.3415, -0.4875,  1.0000,
          0.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000],
        [ 0.5462, -0.6099, -0.8766, -0.0798, -0.3491, -0.1645, -0.2000,  1.0000,
          0.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000]],
       dtype=torch.float64)

In [73]:
data

Batch(x=[6, 28], edge_index=[2, 6], y=[1], study_features=[2, 16], patient_features=[2], graph_sizes=[2], split_sizes=[2], edge_weight=[6], batch=[6], ptr=[2])

In [74]:
study_features, patient_features = data.study_features, data.patient_features

xes = list(data.x.split(tuple(data.graph_sizes)))
edge_indices = list(data.edge_index.split(tuple(data.split_sizes), dim=1))
batches = list(data.batch.split(tuple(data.graph_sizes)))

study_embeddings = []

for i, (x, edge_index, batch) in enumerate(zip(xes, edge_indices, batches)):
    study_embeddings.append(torch.cat([
        gmp(gnn(x, edge_index), batch), gap(gnn(x, edge_index), batch),
        study_features[i, :].reshape(1, -1)], dim=1).t())


In [78]:
torch.cat(study_embeddings, dim=1).reshape(1, -1, 2).shape

torch.Size([1, 48, 2])

In [98]:

study_pooled = torch.cat(study_embeddings, dim=1).reshape(-1, len(xes)).mean(dim=1)
study_pooled.shape

torch.Size([48])

In [99]:

patient_pooled = torch.cat([study_pooled.flatten(), patient_features], dim=0)
patient_pooled.shape

torch.Size([50])

In [87]:

patient_pooled = F.relu(fc1(patient_pooled)).to(dtype=torch.float64)
patient_pooled = F.dropout(patient_pooled, p=0.5, training=True).to(dtype=torch.float64)
patient_pooled = fc2(patient_pooled)
patient_pooled.shape

torch.Size([2])

In [40]:
conv = torch.nn.Conv1d(in_channels=48, out_channels=16, kernel_size=4, padding=1).to(dtype=torch.float64)

In [46]:
fc1 = Linear(16 + data.patient_features.shape[0], 16).to(dtype=torch.float64)
fc2 = Linear(16, 2).to(dtype=torch.float64)

In [88]:
LogSoftmax(dim=-1)(patient_pooled)

tensor([-0.6564, -0.7313], dtype=torch.float64, grad_fn=<LogSoftmaxBackward>)

In [101]:
from torch import nn

class TimeGNN(SparseModule):
    def __init__(
        self,
        num_classes: int,
        hidden_dim: int,
        lesion_features_dim: int,
        study_features_dim: int,
        patient_features_dim: int,
        layer_type: str = 'GraphConv',
        num_layers: int = 10,
    ):
        super(TimeGNN, self).__init__()
        
        self.gnn = BaselineGNN(lesion_features_dim=lesion_features_dim, hidden_dim=hidden_dim,
                               layer_type=layer_type, num_layers=num_layers)
        
        # self.conv = nn.Conv1d(in_channels=(hidden_dim * 2 + study_features_dim), out_channels=hidden_dim,
        #                       kernel_size=4, padding=1).to(dtype=torch.float64)

        self.fc1 = Linear(hidden_dim * 2 + study_features_dim + patient_features_dim, hidden_dim).to(dtype=torch.float64)
        self.fc2 = Linear(hidden_dim, num_classes).to(dtype=torch.float64)

        self.readout = LogSoftmax(dim=-1).to(dtype=torch.float64)

    def forward(self, data):
        
        study_features, patient_features = data.study_features, data.patient_features
        
        xes = list(data.x.split(tuple(data.graph_sizes)))
        edge_indices = list(data.edge_index.split(tuple(data.split_sizes), dim=1))
        batches = list(data.batch.split(tuple(data.graph_sizes)))
        
        study_embeddings = []
        
        for i, (x, edge_index, batch) in enumerate(zip(xes, edge_indices, batches)):
            study_embeddings.append(torch.cat([
                gmp(self.gnn(x, edge_index), batch), gap(self.gnn(x, edge_index), batch),
                study_features[i, :].reshape(1, -1)], dim=1).t())
        
        study_pooled = torch.cat(study_embeddings, dim=1).reshape(-1, len(xes)).mean(dim=1)
        
        patient_pooled = torch.cat([study_pooled.flatten(), patient_features], dim=0)
        
        patient_pooled = F.relu(self.fc1(patient_pooled)).to(dtype=torch.float64)
        patient_pooled = F.dropout(patient_pooled, p=0.5, training=self.training).to(dtype=torch.float64)
        patient_pooled = self.fc2(patient_pooled)
        
        return self.readout(patient_pooled).reshape(1, -1)
    
    def __str__(self) -> str:
        """Representation"""
        return f'Baseline GNN with {self.num_layers} {self.layer_type} layers'

### Instantiate model

In [127]:
model_args = dict(
    num_classes=2,
    hidden_dim=64,
    lesion_features_dim=loader_train.dataset[0].x.shape[1],
    study_features_dim=loader_train.dataset[0].study_features.shape[1],
    patient_features_dim=loader_train.dataset[0].patient_features.shape[0],
    num_layers=5)

# model = BaselineGNN(layer_type='GAT', **model_args)

model = TimeGNN(**model_args)

model.reset()

criterion = NLLLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, weight_decay=0.01)

device = None
dense = model.is_dense()

In [128]:
model.param_count()

46210

## Train model

In [129]:
metrics = TrainingMetrics()

for epoch in tqdm(range(200)):
    epoch_loss = 0.

    model.train()
    
    for data in loader_train:
        data.to(device)
        
        output = model(data)
        
        loss = criterion(output, data.y.flatten())

        optimizer.zero_grad()
        
        epoch_loss += loss.item()
        loss.backward()
        
        optimizer.step()
        
    acc_train, _ = evaluate(model, loader_train)
    acc_valid, loss_valid = evaluate(model, loader_valid, validation=True)

    
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Train: {acc_train:.3f}, Val: {acc_valid:.3f}, Loss: {epoch_loss:.3f}, Val. loss: {loss_valid:.3f}')
        
    with torch.no_grad():
        metrics.log_metric('Loss - training', epoch_loss, step=epoch)
        metrics.log_metric('Loss - validation', loss_valid, step=epoch)
        metrics.log_metric('Accuracy - training', acc_train, step=epoch)
        metrics.log_metric('Accuracy - validation', acc_valid, step=epoch)

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 000, Train: 0.690, Val: 0.692, Loss: 573.415, Val. loss: 1.881
Epoch: 010, Train: 0.897, Val: 0.538, Loss: 64.572, Val. loss: 53.927
Epoch: 020, Train: 0.862, Val: 0.538, Loss: 10.109, Val. loss: 3.818
Epoch: 030, Train: 1.000, Val: 0.615, Loss: 7.279, Val. loss: 205.207
Epoch: 040, Train: 0.690, Val: 0.462, Loss: 2659.334, Val. loss: 1166.282
Epoch: 050, Train: 0.862, Val: 0.462, Loss: 21.632, Val. loss: 307.849
Epoch: 060, Train: 0.931, Val: 0.462, Loss: 9.257, Val. loss: 323.130
Epoch: 070, Train: 0.897, Val: 0.615, Loss: 13.790, Val. loss: 438.445
Epoch: 080, Train: 0.931, Val: 0.615, Loss: 8.046, Val. loss: 119.201
Epoch: 090, Train: 1.000, Val: 0.615, Loss: 5.940, Val. loss: 3.788
Epoch: 100, Train: 0.966, Val: 0.615, Loss: 7.037, Val. loss: 124.940
Epoch: 110, Train: 0.966, Val: 0.538, Loss: 5.325, Val. loss: 147.924
Epoch: 120, Train: 1.000, Val: 0.615, Loss: 4.597, Val. loss: 160.087
Epoch: 130, Train: 1.000, Val: 0.615, Loss: 4.365, Val. loss: 166.396
Epoch: 140, Train

 ## Evaluate training metrics

In [131]:
training_metrics = pd.DataFrame(metrics.storage)

In [132]:
px.line(training_metrics, x='step', y='value', color='metric')

## Evaluate testing metrics

In [107]:
loader_test_args = dict(dataset=dataset_test, batch_size=len(dataset_test))

loader_test = DenseDataLoader(**loader_test_args) if model.is_dense() \
    else DataLoader(**loader_test_args)

test_metrics = TestingMetrics(epoch=75)
test_metrics.compute_metrics(model, loader_test)
pd.DataFrame(test_metrics.storage)

RuntimeError: torch.cat(): Sizes of tensors must match except in dimension 1. Got 2 and 1 in dimension 0 (The offending index is 2)