## Test runner

### Import libraries

In [1]:
import os, sys
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [2]:
import pandas as pd
import numpy as np
import networkx as nx

In [3]:
from scipy.stats import wasserstein_distance

In [4]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [5]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [6]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [7]:
from ipywidgets import interact, interact_manual, FloatSlider

In [8]:
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

from src.utils import load_dataset, fetch_data, preprocess, create_dataset, \
                      DATA_FOLDERS, FILES, standardise_column_names
from src.models import DiffPool, BaselineGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

CONNECTION_DIR = '/Users/adhaene/Downloads/'

### Fetch data

How many patients do we have both pre-01, post-01, and post-02 for?

In [9]:
# sum(
#     shape[shape.study_name.isin(['pre-01', 'post-01', 'post-02'])].groupby(['gpcr_id']).study_name.unique()
#         .apply(len).to_numpy() > 2
# )

### Dataset creation

### Placing data into DataLoaders

In [10]:
test_size = 0.2
seed = 42
verbose = 0
connectivity = 'wasserstein'
distance = 0.8
dense=True

In [11]:
labels, lesions, patients = fetch_data(verbose)

In [13]:
X_train, X_test, y_train, y_test = \
    preprocess(labels, lesions, patients,
               test_size=test_size, seed=seed, verbose=verbose)



In [10]:
dataset_train = create_dataset(X=X_train, Y=y_train, dense=dense, distance=distance,
                               connectivity=connectivity, verbose=verbose)

dataset_test = create_dataset(X=X_test, Y=y_test, dense=dense, distance=distance,
                              connectivity=connectivity, verbose=verbose)

In [None]:
dataset_train

In [11]:
loader_train_args = dict(dataset=dataset_train[:round(len(dataset_train) * .8)], batch_size=4)
loader_valid_args = dict(dataset=dataset_train[round(len(dataset_train) * .8):], batch_size=4)
        
loader_train = DenseDataLoader(**loader_train_args) if dense \
    else DataLoader(**loader_train_args)
loader_valid = DenseDataLoader(**loader_valid_args) if dense \
    else DataLoader(**loader_valid_args)

## Modeling

### Model

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, LogSoftmax, ModuleList, Sequential, BatchNorm1d, ReLU

from torch_geometric.nn import GATv2Conv as GATConv, GraphConv, GINConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from models.custom import SparseModule


class BaselineGNN(SparseModule):
    def __init__(
        self,
        num_classes: int,
        hidden_dim: int,
        node_features_dim: int,
        graph_features_dim: int,
        layer_type: str = 'GraphConv',
        num_layers: int = 10,
    ):
        super(BaselineGNN, self).__init__()
        self.layer_type = layer_type
        self.num_layers = num_layers
        
        self.convs = ModuleList()

        feature_extractor = \
            self.create_layer(in_channels=node_features_dim, out_channels=hidden_dim)

        self.convs.append(feature_extractor)
        
        for step in range(num_layers - 1):
            layer = \
                self.create_layer(in_channels=hidden_dim, out_channels=hidden_dim)
            
            self.convs.append(layer)

        self.fc1 = Linear(hidden_dim * 2 + graph_features_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, num_classes)

        self.readout = LogSoftmax(dim=-1)

    def forward(self, data):
        
        x, edge_index, batch, graph_features, edge_weight = \
            data.x, data.edge_index, data.batch, data.graph_features, data.edge_weight

        # Wasserstein edge weights are added if GraphConv layers are used
        if self.layer_type == 'GNN':
            conv_kwargs = dict(edge_weight=edge_weight)
        else:
            conv_kwargs = dict()
        
        for step in range(len(self.convs)):
            x = F.relu(self.convs[step](x, edge_index, **conv_kwargs))
        
        # Concatenate pooling from graph embeddings with graph features
        x = torch.cat(
            [gmp(x, batch), gap(x, batch), graph_features.reshape(batch.unique().size(0), -1)], dim=1)
        
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        
        return self.readout(x)
    
    def create_layer(self, **kwargs):
        """Create layer based on type

        Args:
            type (str): layer type

        Raises:
            ValueError: if type is not accepted within framework

        Returns:
            nn.Module: layer
        """

        if self.layer_type == 'GraphConv':
            return GraphConv(**kwargs)
        elif self.layer_type == 'GAT':
            return GATConv(**kwargs)
        elif self.layer_type == 'GIN':
            node_features, dim = kwargs['in_channels'], kwargs['out_channels']
            return GINConv(Sequential(
                Linear(node_features, dim), BatchNorm1d(dim), ReLU(),
                Linear(dim, dim), ReLU()))
        else:
            raise ValueError(f'{self.layer_type} is not a valid layer type')
    
    def __str__(self) -> str:
        """Representation"""
        return f'Baseline GNN with {self.num_layers} {self.layer_type} layers'

ModuleNotFoundError: No module named 'models'

### Instantiate model

In [54]:
model_args = dict(
    num_classes=2,
    hidden_dim=16,
    graph_features_dim=26,
    node_features_dim=18)

# model = BaselineGNN(layer_type='GAT', **model_args)

model = DiffPool(num_nodes=[9], **model_args)

model.reset()

criterion = NLLLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.01, weight_decay=0.01)

device = None
dense = model.is_dense()

## Train model

In [55]:
metrics = TrainingMetrics()

for epoch in tqdm(range(50)):
    epoch_loss = 0.

    model.train()
    
    for data in loader_train:
        data.to(device)
        
        output = model(data)
        
        loss = criterion(output, data.y.flatten())

        optimizer.zero_grad()
        
        epoch_loss += loss.item()
        loss.backward()
        
        optimizer.step()
        
    acc_train, _ = evaluate(model, loader_train)
    acc_valid, loss_valid = evaluate(model, loader_valid, validation=True)

    
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Train: {acc_train:.3f}, Val: {acc_valid:.3f}, Loss: {epoch_loss:.3f}, Val. loss: {loss_valid:.3f}')
        
    with torch.no_grad():
        metrics.log_metric('Loss - training', epoch_loss, step=epoch)
        metrics.log_metric('Loss - validation', loss_valid, step=epoch)
        metrics.log_metric('Accuracy - training', acc_train, step=epoch)
        metrics.log_metric('Accuracy - validation', acc_valid, step=epoch)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 000, Train: 0.580, Val: 0.692, Loss: 126886090900608.812, Val. loss: 0.168
Epoch: 010, Train: 0.740, Val: 0.462, Loss: 7.331, Val. loss: 0.214
Epoch: 020, Train: 0.780, Val: 0.538, Loss: 6.795, Val. loss: 0.223
Epoch: 030, Train: 0.800, Val: 0.538, Loss: 6.592, Val. loss: 0.227
Epoch: 040, Train: 0.800, Val: 0.538, Loss: 6.536, Val. loss: 0.226


 ## Evaluate training metrics

In [48]:
training_metrics = pd.DataFrame(metrics.storage)

Unnamed: 0,metric,value,step,run
0,Loss - training,7019288000000000.0,0,0
4,Loss - training,4974318000000000.0,1,0
8,Loss - training,2850838000000000.0,2,0
12,Loss - training,2.461774e+16,3,0
16,Loss - training,1452034000000000.0,4,0
20,Loss - training,2594106000000000.0,5,0
24,Loss - training,5328337000000000.0,6,0
28,Loss - training,1964872000000000.0,7,0
32,Loss - training,928468300000000.0,8,0
36,Loss - training,6782248000000000.0,9,0


In [44]:
px.line(training_metrics, x='step', y='value', color='metric')

## Evaluate testing metrics

In [141]:
loader_test_args = dict(dataset=dataset_test, batch_size=len(dataset_test))

loader_test = DenseDataLoader(**loader_test_args) if model.is_dense() \
    else DataLoader(**loader_test_args)

test_metrics = TestingMetrics(epoch=75)
test_metrics.compute_metrics(model, loader_test)
pd.DataFrame(test_metrics.storage)

Unnamed: 0,metric,value
0,Accuracy - testing,0.444444
1,ROC AUC - testing,0.425
2,Precision - testing,0.333333
3,Recall - testing,0.25
4,Fscore - testing,0.285714
