## Test runner

### Import libraries

In [3]:
import os, sys
from typing import List, Tuple
from collections.abc import Callable
import time
import datetime as dt
from tqdm.notebook import tqdm

In [4]:
import pandas as pd
import numpy as np
import networkx as nx

In [5]:
from scipy.stats import wasserstein_distance

In [6]:
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [7]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

from torch_geometric.nn import GraphConv, global_add_pool, DenseGraphConv, dense_diff_pool
import torch.nn.functional as F
from torch.nn import NLLLoss

from torch_geometric.utils import to_dense_adj, to_networkx
from torch_geometric.transforms import ToDense

In [8]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.figsize'] = 15, 8.27

import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'seaborn'

In [9]:
from ipywidgets import interact, interact_manual, FloatSlider

In [19]:
# Make the parent directory and its ``src`` sub-directory importable.
# Both paths are resolved relative to the notebook's current working directory.
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src/'))

# Project-local helpers: data access, model definitions, training and metrics.
from src.utils import load_dataset, fetch_data, preprocess, create_dataset
from src.models import DiffPool, BaselineGNN
from src.train import train
from src.metrics import evaluate, TrainingMetrics, TestingMetrics

# NOTE(review): hard-coded absolute path on one machine; it is not referenced
# by any other cell visible in this notebook — confirm where it is consumed.
CONNECTION_DIR = '/Users/arnauddhaene/Downloads/'

### Fetch data

In [96]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, LogSoftmax, ModuleList, Sequential, BatchNorm1d, ReLU

from torch_geometric.nn import GATv2Conv as GATConv, GraphConv, GINConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from models.custom import SparseModule


class BaselineGNN(SparseModule):
    """Baseline message-passing GNN for graph-level classification.

    The network stacks ``num_layers`` graph convolutions (each followed by a
    ReLU), pools the node embeddings per graph with both max- and mean-pooling,
    concatenates graph-level features to the pooled vector and classifies it
    with two linear layers followed by a log-softmax.

    Every row of ``data.x`` is expected to hold ``node_features_dim`` node-level
    features followed by ``graph_features_dim`` graph-level features. The
    graph-level part is read from the first node of each graph only.

    NOTE: in this notebook the definition shadows the ``BaselineGNN`` imported
    from ``src.models``.

    Args:
        num_classes (int): number of output classes
        hidden_dim (int): width of every convolution and of the hidden
            linear layer
        node_features_dim (int): number of leading columns of ``data.x`` that
            are node-level features
        layer_type (str): one of ``'GraphConv'``, ``'GAT'`` or ``'GIN'``
        num_layers (int): total number of convolution layers
        graph_features_dim (int): number of trailing columns of ``data.x``
            that are graph-level features
    """

    def __init__(
        self,
        num_classes: int,
        hidden_dim: int,
        node_features_dim: int,
        layer_type: str = 'GraphConv',
        num_layers: int = 5,
        graph_features_dim: int = 26
    ):
        super().__init__()
        self.layer_type = layer_type
        self.num_layers = num_layers
        # Kept so that forward() splits data.x consistently with the layers below.
        self.node_features_dim = node_features_dim
        self.graph_features_dim = graph_features_dim

        self.convs = ModuleList()

        # First layer maps raw node features to the hidden width ...
        self.convs.append(
            self.create_layer(in_channels=node_features_dim, out_channels=hidden_dim))

        # ... the remaining ones keep the hidden width.
        for _ in range(num_layers - 1):
            self.convs.append(
                self.create_layer(in_channels=hidden_dim, out_channels=hidden_dim))

        # Input: max-pooled + mean-pooled embeddings + graph-level features.
        self.fc1 = Linear(hidden_dim * 2 + graph_features_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, num_classes)

        self.readout = LogSoftmax(dim=-1)

    def forward(self, data) -> torch.Tensor:
        """Compute per-graph class log-probabilities.

        Args:
            data: object exposing ``x``, ``edge_index`` and ``batch``.
                Assumes nodes of the same graph are contiguous in ``batch``
                (as produced by the PyG loaders) — TODO confirm for other
                callers.

        Returns:
            torch.Tensor: log-softmax scores, one row per graph
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if batch is None:
            # Single, un-batched graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x, graph_x = x[:, :self.node_features_dim], x[:, self.node_features_dim:]

        # Row index of the first node of every graph, i.e. 0 plus every
        # position where the graph id changes.
        # If batch is [0, 0, 1, 1, 1, 1, 2, 2, 2] this gives [0, 2, 6].
        # Done with torch ops so it stays on the tensors' device.
        boundaries = torch.nonzero(batch[1:] != batch[:-1]).flatten() + 1
        first_node = torch.cat([boundaries.new_zeros(1), boundaries])

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        x = torch.cat([gmp(x, batch), gap(x, batch), graph_x[first_node]], dim=1)

        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return self.readout(x)

    def create_layer(self, **kwargs):
        """Create a convolution layer of the type given by ``self.layer_type``

        Args:
            **kwargs: forwarded to the layer constructor; the model passes
                ``in_channels`` and ``out_channels``

        Raises:
            ValueError: if ``self.layer_type`` is not accepted within framework

        Returns:
            nn.Module: layer
        """
        if self.layer_type == 'GraphConv':
            return GraphConv(**kwargs)

        if self.layer_type == 'GAT':
            # GATConv is the file's import alias for GATv2Conv.
            return GATConv(**kwargs)

        if self.layer_type == 'GIN':
            # GINConv wraps an MLP instead of taking channel sizes directly.
            node_features, dim = kwargs['in_channels'], kwargs['out_channels']
            return GINConv(Sequential(
                Linear(node_features, dim), BatchNorm1d(dim), ReLU(),
                Linear(dim, dim), ReLU()))

        raise ValueError(f'{self.layer_type} is not a valid layer type')

    def __str__(self) -> str:
        """Short human-readable description of the architecture"""
        return f'Baseline GNN with {self.num_layers} {self.layer_type} layers'


In [110]:
# Rebuild the first training graph with its feature matrix split in two:
# the first 18 columns stay as per-node features ``x``; the remaining columns,
# read from node 0 only, become a single graph-level vector ``graph_x``.
Data(x=dataset_train[0].x[:, :18], edge_index=dataset_train[0].edge_index, graph_x=dataset_train[0].x[0, 18:])

Data(x=[60, 18], edge_index=[2, 1690], graph_x=[26])

In [80]:
# Toy batch-assignment vector: four graphs with 7, 6, 7 and 6 nodes.
batch = np.repeat([0, 1, 2, 3], [7, 6, 7, 6])

# Index of the first node of each graph: position 0 plus every position whose
# graph id differs from the one just before it.
is_boundary = batch[1:] != batch[:-1]
iddd = np.concatenate(([0], np.flatnonzero(is_boundary) + 1))

In [46]:
# Experiment configuration; every value is forwarded to the project helpers below.
test_size, seed, verbose = 0.2, 27, 0
connectivity, distance, dense = 'wasserstein', 0.8, False

# Raw inputs as returned by the project's fetch helper.
labels, lesions, patients = fetch_data(verbose)

# Train / test split of features and targets.
X_train, X_test, y_train, y_test = preprocess(
    labels, lesions, patients, test_size=test_size, seed=seed, verbose=verbose)

# Both datasets are built with exactly the same graph-construction settings.
graph_options = dict(
    dense=dense, distance=distance, connectivity=connectivity, verbose=verbose)

dataset_train = create_dataset(X=X_train, Y=y_train, **graph_options)
dataset_test = create_dataset(X=X_test, Y=y_test, **graph_options)

In [101]:
# The first 80% of the training graphs (in stored order) are used for fitting,
# the remaining 20% for validation.
split_at = round(len(dataset_train) * .8)

loader_train_args = dict(dataset=dataset_train[:split_at], batch_size=4)
loader_valid_args = dict(dataset=dataset_train[split_at:], batch_size=4)

# The loader class follows the model's own dense / sparse flag.
# NOTE: ``model`` must already exist when this cell is executed.
loader_cls = DenseDataLoader if model.is_dense() else DataLoader

loader_train = loader_cls(**loader_train_args)
loader_valid = loader_cls(**loader_valid_args)

In [135]:
# Hyper-parameters shared by every model variant tried in this notebook.
model_args = {
    'num_classes': 2,
    'hidden_dim': 16,
    'node_features_dim': 18,
}

# 'GAT' selects the attention-based convolution in BaselineGNN.create_layer.
model = BaselineGNN(layer_type='GAT', **model_args)

# reset() is not defined in this notebook; presumably it re-initialises the
# parameters — confirm against the base module.
model.reset()

In [136]:
# No explicit device is selected; the training loop passes this to ``data.to``.
device = None
# Overwrites the ``dense`` flag from the configuration cell with the model's own.
dense = model.is_dense()

# The model emits log-probabilities (LogSoftmax readout), hence NLLLoss.
criterion = NLLLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=0.01,
)

In [137]:
# Collector for the per-epoch loss / accuracy values that the training loop
# records through ``metrics.log_metric``.
metrics = TrainingMetrics()

In [138]:
for epoch in tqdm(range(50)):
    # Training mode: the model's dropout is only active while model.training is set.
    model.train()

    batch_losses = []

    for data in loader_train:
        data.to(device)

        # Clear stale gradients, then run one forward / backward / update step.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y.flatten())
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Training loss is the SUM of the per-batch losses over the epoch.
    epoch_loss = sum(batch_losses, 0.)

    # evaluate() returns an (accuracy, loss) pair; the training loss it returns is unused.
    acc_train, _ = evaluate(model, loader_train)
    acc_valid, loss_valid = evaluate(model, loader_valid, validation=True)

    # Console report every tenth epoch only.
    if not epoch % 10:
        print(f'Epoch: {epoch:03d}, Train: {acc_train:.3f}, Val: {acc_valid:.3f}, Loss: {epoch_loss:.3f}, Val. loss: {loss_valid:.3f}')

    # Every epoch is logged, in this fixed order.
    epoch_values = {
        'Loss - training': epoch_loss,
        'Loss - validation': loss_valid,
        'Accuracy - training': acc_train,
        'Accuracy - validation': acc_valid,
    }
    with torch.no_grad():
        for metric_name, metric_value in epoch_values.items():
            metrics.log_metric(metric_name, metric_value, step=epoch)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 000, Train: 0.600, Val: 0.692, Loss: 8.568, Val. loss: 0.186
Epoch: 010, Train: 0.800, Val: 0.692, Loss: 7.419, Val. loss: 0.226
Epoch: 020, Train: 0.860, Val: 0.615, Loss: 4.993, Val. loss: 0.240
Epoch: 030, Train: 0.860, Val: 0.692, Loss: 5.106, Val. loss: 0.259
Epoch: 040, Train: 0.940, Val: 0.538, Loss: 3.898, Val. loss: 0.287


In [139]:
# Tabulate everything logged during training, ready for plotting.
training_metrics = pd.DataFrame(metrics.storage) 

In [140]:
# One line per logged metric, value against epoch; relies on the frame having
# 'step', 'value' and 'metric' columns.
px.line(training_metrics, x='step', y='value', color='metric')

In [141]:
# The whole test set is served as one single batch.
loader_test_args = dict(dataset=dataset_test, batch_size=len(dataset_test))

# The loader class follows the model's own dense / sparse flag.
loader_cls = DenseDataLoader if model.is_dense() else DataLoader
loader_test = loader_cls(**loader_test_args)

# NOTE(review): epoch=75 does not match the 50 epochs run by the training loop
# in this notebook — confirm what TestingMetrics uses this value for.
test_metrics = TestingMetrics(epoch=75)
test_metrics.compute_metrics(model, loader_test)

# Last expression of the cell: rendered as the results table.
pd.DataFrame(test_metrics.storage)

Unnamed: 0,metric,value
0,Accuracy - testing,0.444444
1,ROC AUC - testing,0.425
2,Precision - testing,0.333333
3,Recall - testing,0.25
4,Fscore - testing,0.285714
