## Test runner

### Import libraries

In [1]:
import os
import sys
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [2]:
import pandas as pd
import numpy as np
import networkx as nx

In [3]:
from scipy.stats import wasserstein_distance

In [4]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [5]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [6]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [7]:
from ipywidgets import interact, interact_manual, FloatSlider

In [8]:
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

from src.utils import load_dataset, fetch_data, preprocess, create_dataset, \
                      DATA_FOLDERS, FILES, standardise_column_names
from src.models.baseline import TimeGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

CONNECTION_DIR = '/Users/adhaene/Downloads/'

### Dataset creation

### Placing data into DataLoaders

In [9]:
test_size = 0.0
seed = 12
verbose = 1
connectivity = 'wasserstein'
distance = 5.0
suspicious = 0.75

In [10]:
# labels, lesions, patients = fetch_data(suspicious=suspicious, verbose=verbose)

In [11]:
# len(labels), len(lesions.reset_index().gpcr_id.unique()), len(patients.reset_index().gpcr_id.unique())

In [12]:
# X_train, X_test, y_train, y_test = \
#     preprocess(labels, lesions, patients,
#                test_size=test_size, seed=12, verbose=verbose)

In [13]:
dataset_train, _ = \
        load_dataset(connectivity='wasserstein', seed=seed, test_size=test_size,
                     suspicious=suspicious, distance=distance, verbose=verbose)

Post-1 study lesions extracted for 94 patients
Post-2 study labels added for 58 patients
The intersection of datasets showed 56 potential datapoints.
Final dataset split -> Train: 53 | Test: 0


In [14]:
loader_train_args = dict(dataset=dataset_train[:round(len(dataset_train) * .8)], batch_size=1)
loader_valid_args = dict(dataset=dataset_train[round(len(dataset_train) * .8):], batch_size=1)
        
loader_train = DataLoader(**loader_train_args)
loader_valid = DataLoader(**loader_valid_args)

len(loader_train), len(loader_valid)

(42, 11)

## Modeling

### Model

In [57]:
import torch
# from torch import nn
import torch.nn.functional as F
from torch.nn import Linear, LogSoftmax, ModuleList, Sequential, BatchNorm1d, ReLU, LSTM

from torch_geometric.nn import GATv2Conv as GATConv, GraphConv, GINConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from models.custom import SparseModule


class BaselineGNN(SparseModule):
    def __init__(self, lesion_features_dim: int, hidden_dim: int,
                 layer_type: str = 'GraphConv', num_layers: int = 10):
        super(BaselineGNN, self).__init__()
        
        self.layer_type = layer_type
        self.num_layers = num_layers
        
        self.convs = ModuleList()

        feature_extractor = \
            self.create_layer(in_channels=lesion_features_dim, out_channels=hidden_dim)

        self.convs.append(feature_extractor)
        
        for step in range(num_layers - 1):
            layer = \
                self.create_layer(in_channels=hidden_dim, out_channels=hidden_dim)
            
            self.convs.append(layer)
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: torch.Tensor = None) -> torch.Tensor:

        # Wasserstein edge weights are added if GraphConv layers are used
        if self.layer_type == 'GNN':
            conv_kwargs = dict(edge_weight=edge_weight)
        # if self.layer_type == 'GAT':
            # conv_kwargs = dict(edge_attr=edge_weight)
        else:
            conv_kwargs = dict()
        
        for step in range(len(self.convs)):
            x = F.relu(self.convs[step](x, edge_index, **conv_kwargs))
        
        return x
        
    def create_layer(self, **kwargs):
        """Create layer based on type

        Args:
            type (str): layer type

        Raises:
            ValueError: if type is not accepted within framework

        Returns:
            nn.Module: layer
        """

        if self.layer_type == 'GraphConv':
            return GraphConv(**kwargs)
        elif self.layer_type == 'GAT':
            # return GATConv(heads=3, edge_dim=1, **kwargs)
            return GATConv(heads=1, **kwargs)
        elif self.layer_type == 'GIN':
            node_features, dim = kwargs['in_channels'], kwargs['out_channels']
            return GINConv(Sequential(
                Linear(node_features, dim), BatchNorm1d(dim), ReLU(),
                Linear(dim, dim), ReLU()))
        else:
            raise ValueError(f'{self.layer_type} is not a valid layer type')
        

class TimeGNN(SparseModule):
    def __init__(
        self,
        num_classes: int,
        hidden_dim: int,
        lesion_features_dim: int,
        study_features_dim: int,
        patient_features_dim: int,
        layer_type: str = 'GraphConv',
        num_layers: int = 10,
    ):
        super(TimeGNN, self).__init__()
        
        self.layer_type = layer_type
        self.num_layers = num_layers
        
        self.gnn = BaselineGNN(lesion_features_dim=lesion_features_dim, hidden_dim=hidden_dim,
                               layer_type=layer_type, num_layers=num_layers)

        self.rnn = LSTM(input_size=(hidden_dim * 2 + study_features_dim),
                        hidden_size=hidden_dim).to(dtype=torch.float64)
        
        self.fc1 = Linear(hidden_dim + patient_features_dim, hidden_dim).to(dtype=torch.float64)
        self.fc2 = Linear(hidden_dim, num_classes).to(dtype=torch.float64)

        self.readout = LogSoftmax(dim=-1).to(dtype=torch.float64)

    def forward(self, data):
        
        study_features, patient_features = data.study_features, data.patient_features
        
        xes = list(data.x.split(tuple(data.graph_sizes)))
        edge_indices = list(data.edge_index.split(tuple(data.split_sizes), dim=1))
        batches = list(data.batch.split(tuple(data.graph_sizes)))
        
        study_embeddings = []
        
        for i, (x, edge_index, batch) in enumerate(zip(xes, edge_indices, batches)):
            
            pooled_lesions_features = [gmp(self.gnn(x, edge_index), batch), gap(self.gnn(x, edge_index), batch)]
                
            # Size of which will be (len(xes), hidden_dim * 3)
            study_embeddings.append(torch.cat([
                *pooled_lesions_features, study_features[i, :].reshape(1, -1)
            ], dim=1).t())
        
        # len(xes) here is the sequence length
        rnn_output, _ = self.rnn(torch.cat(study_embeddings, dim=1).t().reshape(len(xes), 1, -1))
        study_pooled = rnn_output[-1, :, :].flatten()
        
        patient_pooled = torch.cat([study_pooled, patient_features], dim=0)

        patient_pooled = F.relu(self.fc1(patient_pooled)).to(dtype=torch.float64)
        patient_pooled = F.dropout(patient_pooled, p=0.25, training=self.training).to(dtype=torch.float64)
        patient_pooled = self.fc2(patient_pooled)

        return self.readout(patient_pooled).reshape(1, -1)
    
    def __str__(self) -> str:
        """Representation"""
        return f'TimeGNN with {self.num_layers} {self.layer_type} layers'

### Experimentation

In [33]:
lesion_features_dim = loader_train.dataset[0].x.shape[1]
study_features_dim = loader_train.dataset[0].study_features.shape[1]
patient_features_dim = loader_train.dataset[0].patient_features.shape[0]
hidden_dim = 64
layer_type = "GAT"
num_layers = 5
num_classes = 2

gnn = BaselineGNN(lesion_features_dim=lesion_features_dim, hidden_dim=hidden_dim,
                               layer_type=layer_type, num_layers=num_layers)

fc1 = Linear(hidden_dim * 2, hidden_dim) \
    .to(dtype=torch.float64)
fc2 = Linear(hidden_dim, num_classes).to(dtype=torch.float64)

sl1 = Linear(study_features_dim, hidden_dim).to(dtype=torch.float64)
sl2 = Linear(hidden_dim, hidden_dim).to(dtype=torch.float64)

pl1 = Linear(patient_features_dim, hidden_dim).to(dtype=torch.float64)
pl2 = Linear(hidden_dim, hidden_dim).to(dtype=torch.float64)

readout = LogSoftmax(dim=-1).to(dtype=torch.float64)

rnn = torch.nn.RNN(input_size=(hidden_dim * 3), hidden_size=hidden_dim, num_layers=1, nonlinearity='relu', batch_first=False, dropout=0.).to(dtype=torch.float64)

### Instantiate model

In [58]:
model_args = dict(
    num_classes=2,
    hidden_dim=64,
    lesion_features_dim=loader_train.dataset[0].x.shape[1],
    study_features_dim=loader_train.dataset[0].study_features.shape[1],
    patient_features_dim=loader_train.dataset[0].patient_features.shape[0],
    num_layers=5, layer_type='GAT')

model = TimeGNN(**model_args)

model.reset()

criterion = NLLLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.01, weight_decay=0.01)

device = None

In [59]:
model.param_count()

97602

## Train model

In [60]:
metrics = TrainingMetrics()

for epoch in tqdm(range(100)):
    epoch_loss = 0.

    model.train()
    
    for data in loader_train:
        data.to(device)
        
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, data.y.flatten())
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        
    with torch.no_grad():
        acc_train, _ = evaluate(model, loader_train)
        acc_valid, loss_valid = evaluate(model, loader_valid, validation=True)

        if epoch % 5 == 0:
            print(f'Epoch: {epoch:03d}, Train: {acc_train:.3f}, Val: {acc_valid:.3f}, Loss: {epoch_loss:.3f}, Val. loss: {loss_valid:.3f}')
        
        metrics.log_metric('Loss - training', epoch_loss, step=epoch)
        metrics.log_metric('Loss - validation', loss_valid, step=epoch)
        metrics.log_metric('Accuracy - training', acc_train, step=epoch)
        metrics.log_metric('Accuracy - validation', acc_valid, step=epoch)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 000, Train: 0.619, Val: 0.545, Loss: 29.064, Val. loss: 0.762
Epoch: 005, Train: 0.810, Val: 0.545, Loss: 24.371, Val. loss: 0.702
Epoch: 010, Train: 0.929, Val: 0.636, Loss: 12.643, Val. loss: 0.729
Epoch: 015, Train: 1.000, Val: 0.455, Loss: 9.786, Val. loss: 1.054
Epoch: 020, Train: 1.000, Val: 0.545, Loss: 5.288, Val. loss: 1.184
Epoch: 025, Train: 1.000, Val: 0.545, Loss: 2.804, Val. loss: 1.154
Epoch: 030, Train: 1.000, Val: 0.545, Loss: 2.017, Val. loss: 1.148
Epoch: 035, Train: 1.000, Val: 0.455, Loss: 2.063, Val. loss: 1.260


KeyboardInterrupt: 

In [88]:
model

TimeGNN(
  (gnn): BaselineGNN(
    (convs): ModuleList(
      (0): GATv2Conv(28, 64, heads=1)
      (1): GATv2Conv(64, 64, heads=1)
      (2): GATv2Conv(64, 64, heads=1)
      (3): GATv2Conv(64, 64, heads=1)
      (4): GATv2Conv(64, 64, heads=1)
    )
  )
  (rnn): LSTM(151, 64)
  (fc_readout_1): Linear(in_features=66, out_features=64, bias=True)
  (fc_readout_2): Linear(in_features=64, out_features=2, bias=True)
  (readout): LogSoftmax(dim=-1)
)

In [75]:
pd.Series(list(map(lambda d: d.graph_sizes.shape[0], loader_train.dataset))).value_counts()

1    28
2    14
dtype: int64

 ## Evaluate training metrics

In [76]:
training_metrics = pd.DataFrame(metrics.storage)

In [77]:
px.line(training_metrics, x='step', y='value', color='metric')

## Evaluate testing metrics

In [None]:
loader_test_args = dict(dataset=dataset_test, batch_size=len(dataset_test))

loader_test = DataLoader(**loader_test_args)

test_metrics = TestingMetrics(epoch=200)
test_metrics.compute_metrics([model, model], loader_test)
pd.DataFrame(test_metrics.storage)

RuntimeError: torch.cat(): Sizes of tensors must match except in dimension 1. Got 2 and 1 in dimension 0 (The offending index is 2)