# Task 4 - Custom Message Passing

**Project Team 9**

In [None]:
import os
import random
import statistics
from tqdm import tqdm
import networkx as nx
import torch
#import torch_scatter
import numpy as np
import pandas as pd
from itertools import product


import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LSTM
import torch.nn.functional as F
from torch.utils.data import random_split

import torch_geometric
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
from torch_geometric.nn import GCNConv, SAGEConv


from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)
from torch_geometric.nn.pool import global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.utils.dropout import dropout_edge
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, to_dense_batch, sort_edge_index

from torch.nn import Parameter, Linear
#from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from pytorch_lightning import Trainer

import pytorch_lightning as pl


import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
from functools import partial
from sklearn.metrics import accuracy_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from torch_geometric.datasets import TUDataset

## Message Passing Layers

### Mean Aggregation

In [None]:
class MeanSAGE(MessagePassing):
    """GraphSAGE layer with mean neighbourhood aggregation.

    Computes ``out_i = lin_l(x_i) + lin_r(mean_{j -> i} x_j)`` and, if
    ``normalize`` is set, L2-normalises every output row.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        normalize: If True, L2-normalise each output row.
        bias: Accepted for signature compatibility; currently unused (both
            linear layers are created with their default bias).
        **kwargs: Ignored.
    """

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        # 'mean' divides the summed incoming messages by each target node's
        # in-degree, which is what the layer's name promises. (The previous
        # 'add' aggregation summed neighbours and the degree normaliser it
        # computed was never applied.)
        super(MeanSAGE, self).__init__(aggr='mean')
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.normalize = normalize

        self.lin_l = nn.Linear(in_channels, out_channels)  # projection of the central node
        self.lin_r = nn.Linear(in_channels, out_channels)  # projection of the neighbour mean

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear projections."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        """Apply the layer.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: ``[2, num_edges]`` connectivity; messages flow from
                ``edge_index[0]`` to ``edge_index[1]``.
            size: Optional ``(num_src, num_dst)`` forwarded to ``propagate``.

        Returns:
            Tensor with one ``out_channels``-sized row per node.
        """
        neighbour_mean = self.propagate(edge_index, x=x, size=size)

        out = self.lin_l(x) + self.lin_r(neighbour_mean)

        if self.normalize:
            out = torch.nn.functional.normalize(out, p=2.0, dim=1, eps=1e-12)
        return out

    def message(self, x_j):
        """Each neighbour sends its raw feature vector."""
        return x_j


### LSTM Aggregation

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, to_dense_batch, sort_edge_index


class LSTMSAGE(MessagePassing):
    """GraphSAGE-style layer whose neighbourhood aggregator is an LSTM.

    For every node, the features of its in-neighbours plus itself (via an
    added self-loop) are packed into a zero-padded sequence and run through an
    LSTM. The LSTM output at the node's last real (non-padded) neighbour
    becomes the node's new representation.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output vector (the LSTM hidden size).
        normalize: If True, L2-normalise each output row.
        bias: Accepted for signature compatibility; currently unused.
        **kwargs: Ignored.
    """

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        # aggr is irrelevant here because aggregate() is overridden.
        super(LSTMSAGE, self).__init__(aggr='add')

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.normalize = normalize

        # to_dense_batch yields [num_nodes, max_degree, in_channels]: one
        # sequence per node, neighbours as time steps -> batch_first=True.
        self.lstm = torch.nn.LSTM(in_channels, out_channels, batch_first=True)
        # NOTE(review): lin_l / lin_r are created and reset but not used in
        # forward(); kept so the parameter / state_dict layout is unchanged.
        self.lin_l = nn.Linear(in_channels, out_channels)
        self.lin_r = nn.Linear(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the LSTM and the linear layers."""
        self.lstm.reset_parameters()
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        """Aggregate each node's neighbourhood with the LSTM.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: ``[2, num_edges]`` connectivity; messages flow from
                ``edge_index[0]`` to ``edge_index[1]``.
            size: Optional ``(num_src, num_dst)`` forwarded to ``propagate``.

        Returns:
            Tensor with one ``out_channels``-sized row per node.
        """
        num_nodes = x.size(0)
        # Self-loops guarantee every node contributes at least one LSTM input.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Messages are grouped by edge_index[1] (the target) and to_dense_batch
        # requires that grouping index to be sorted, so sort by column.
        edge_index = sort_edge_index(edge_index, num_nodes=num_nodes, sort_by_row=False)

        out = self.propagate(edge_index, x=x, size=size)

        if self.normalize:
            out = torch.nn.functional.normalize(out, p=2.0, dim=1, eps=1e-12)
        return out

    def message(self, x_j):
        """Each neighbour sends its raw feature vector."""
        return x_j

    def aggregate(self, inputs, index, dim_size = None):
        """Run the LSTM over each target node's (sorted) incoming messages.

        Args:
            inputs: One message per edge, in edge order.
            index: Target node of each message; must be sorted ascending.
            dim_size: Number of output nodes (supplied by ``propagate``).

        Returns:
            ``[dim_size, out_channels]`` tensor of LSTM outputs taken at each
            node's last real message.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        # dense: [dim_size, max_degree, in_channels], zero padded; mask marks real entries.
        dense, mask = to_dense_batch(inputs, index, batch_size=dim_size)
        seq_out, _ = self.lstm(dense)
        # Read the output at each node's final real message, not at padding.
        last = (mask.sum(dim=1) - 1).clamp(min=0)
        rows = torch.arange(seq_out.size(0), device=seq_out.device)
        return seq_out[rows, last]





### Attention Weighted LSTM Aggregation

In [None]:
class AttnLSTMSAGE(MessagePassing):
    """LSTM neighbourhood aggregation with neighbours ordered by GAT attention.

    A ``GATConv`` first transforms the node features and scores every edge
    (including the self-loops it adds). Within each target node's
    neighbourhood, messages are then fed to an LSTM in ascending attention
    order, so the highest-scored neighbour is processed last. The LSTM output
    at that last real neighbour is the node's new representation.

    Args:
        in_channels: Size of each input node feature vector (GATConv keeps this size).
        out_channels: Size of each output vector (the LSTM hidden size).
        normalize: If True, L2-normalise each output row.
    """

    def __init__(self, in_channels, out_channels, normalize = True):
        super(AttnLSTMSAGE, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.normalize = normalize

        # One sequence per node, neighbours as time steps -> batch_first=True.
        self.lstm = torch.nn.LSTM(in_channels, out_channels, batch_first=True)
        # Attention used to order each node's neighbours.
        self.att = pyg_nn.GATConv(in_channels=in_channels, out_channels=in_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the LSTM and the attention layer."""
        self.lstm.reset_parameters()
        self.att.reset_parameters()  # was self.attn, which does not exist

    def forward(self, x, edge_index):
        """Order neighbours by attention, then aggregate them with the LSTM.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: ``[2, num_edges]`` connectivity (source row, target row).

        Returns:
            Tensor with one ``out_channels``-sized row per node.
        """
        # GATConv returns the edge list it actually scored (with self-loops)
        # together with one attention coefficient per edge and head.
        x, (att_edge_index, alpha) = self.att(x, edge_index, return_attention_weights=True)
        score = alpha.reshape(alpha.size(0), -1).mean(dim=-1)

        # Ascending attention order across all edges ...
        edge_index = att_edge_index[:, torch.argsort(score)]
        # ... then a stable grouping by target node, which to_dense_batch requires,
        # preserving the attention order inside each group.
        _, by_target = torch.sort(edge_index[1], stable=True)
        edge_index = edge_index[:, by_target]

        out = self.propagate(edge_index, x=x)

        if self.normalize:
            out = torch.nn.functional.normalize(out, p=2.0, dim=1, eps=1e-12)
        return out

    def message(self, x_j):
        """Each neighbour sends its (attention-transformed) feature vector."""
        return x_j

    def aggregate(self, inputs, index, dim_size = None):
        """Run the LSTM over each target node's ordered incoming messages.

        Args:
            inputs: One message per edge, in edge order.
            index: Target node of each message; must be sorted ascending.
            dim_size: Number of output nodes (supplied by ``propagate``).

        Returns:
            ``[dim_size, out_channels]`` tensor of LSTM outputs taken at each
            node's last real message.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        dense, mask = to_dense_batch(inputs, index, batch_size=dim_size)
        seq_out, _ = self.lstm(dense)
        # Read the output at each node's final real message, not at padding.
        last = (mask.sum(dim=1) - 1).clamp(min=0)
        rows = torch.arange(seq_out.size(0), device=seq_out.device)
        return seq_out[rows, last]



### Shortest Path Aggregation

In [None]:
class ShortestPathSAGE(MessagePassing):
    """Aggregate all reachable nodes, weighted by inverse shortest-path hop count.

    Computes ``out_i = lin_l(sum_{j != i} x_j / d(i, j)) + lin_r(x_i)``, where
    ``d(i, j)`` is the number of hops from ``i`` to ``j`` following edges from
    ``edge_index[0]`` to ``edge_index[1]``. Unreachable pairs have infinite
    distance and therefore weight 0.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        normalize: If True, L2-normalise each output row.
        bias: Whether the two linear projections use a bias.
        **kwargs: Ignored.
    """

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(ShortestPathSAGE, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_r = nn.Linear(in_channels, out_channels, bias=bias)  # central node
        self.lin_l = nn.Linear(in_channels, out_channels, bias=bias)  # distance-weighted aggregate
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear projections (previously lin_l was skipped)."""
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    @staticmethod
    def _hop_distances(edge_index, num_nodes, device=None):
        """All-pairs hop distances via Floyd-Warshall, vectorised per pivot (O(N^3) work)."""
        dist = torch.full((num_nodes, num_nodes), float('inf'), device=device)
        if edge_index.numel() > 0:
            dist[edge_index[0], edge_index[1]] = 1.0  # unit weight per hop
        dist.fill_diagonal_(0.0)
        for k in range(num_nodes):
            # Relax every (i, j) through pivot k in one broadcasted step.
            dist = torch.minimum(dist, dist[:, k:k + 1] + dist[k:k + 1, :])
        return dist

    def forward(self, x, edge_index, dist_mat = None, size = None):
        """Apply the layer.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: ``[2, num_edges]`` connectivity.
            dist_mat: Optional precomputed ``[num_nodes, num_nodes]`` hop-distance
                matrix (tensor or array, ``inf`` for unreachable pairs). If
                omitted it is computed with Floyd-Warshall, which is slow.
            size: Unused; kept for interface compatibility.

        Returns:
            Tensor with one ``out_channels``-sized row per node.
        """
        num_nodes = x.size(0)
        if dist_mat is None:
            dist_mat = self._hop_distances(edge_index, num_nodes, device=x.device)
        else:
            dist_mat = torch.as_tensor(dist_mat, device=x.device)

        # Inverse distances: 1/inf -> 0 for unreachable pairs; the diagonal
        # (1/0 -> inf) is zeroed so a node does not aggregate itself here.
        weights = dist_mat.to(x.dtype).reciprocal()
        weights.fill_diagonal_(0.0)

        out = self.lin_l(weights @ x)  # project the distance-weighted aggregate
        out = out + self.lin_r(x)      # project the central node

        if self.normalize:
            out = torch.nn.functional.normalize(out, p=2.0, dim=1, eps=1e-12)
        return out




##### Shortest-path aggregation is prohibitively slow to train on any non-trivial model, because computing all-pairs hop distances with Floyd–Warshall costs O(N^3) time per forward pass (N = number of nodes in the batch). We suggest using this aggregation only on small graphs, or precomputing the distance matrix and passing it in via `dist_mat`.

## Assertion Test for Message Passing

In [None]:

# Smoke test: every custom message-passing layer must map the features of a
# tiny 4-node graph from 8 to 2 dimensions, one output row per node.
node_embedings = torch.ones(4, 8)  # A graph with 4 nodes and 8 dimensional node features
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                           [1, 0, 1, 3, 2]], dtype=torch.long)  # Example edge index

# Mean, LSTM, attention-weighted LSTM and shortest-path aggregation, in that order.
for layer_cls in (MeanSAGE, LSTMSAGE, AttnLSTMSAGE, ShortestPathSAGE):
    custom_layer = layer_cls(in_channels=8, out_channels=2)
    transformed_embeddings = custom_layer(node_embedings, edge_index)
    assert transformed_embeddings.shape == torch.Size([4, 2]), "Incorrect shape returned."



## Training Pipeline

In [None]:
# Shuffle ENZYMES and split it 80% / 10% / 10% into train / validation / test.
# Sizes are derived from the dataset length; the previous hard-coded 540 / 60
# split consumed all 600 ENZYMES graphs and left the test split empty.
dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES')
dataset = dataset.shuffle()
num_graphs = len(dataset)
num_train = int(0.8 * num_graphs)
num_val = int(0.1 * num_graphs)
train_dataset = dataset[:num_train]
val_dataset = dataset[num_train:num_train + num_val]
test_dataset = dataset[num_train + num_val:]


In [None]:
# Mini-batch loaders (64 graphs per batch); only the training loader reshuffles every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)




In [None]:
class GNNWrapper(pl.LightningModule):
    """
    A PyTorch Lightning wrapper that trains a graph classifier.

    The wrapped ``model`` is called as ``model(x, edge_index, batch)``, and its
    output is scored with cross-entropy against ``data.y``. Per-epoch
    loss/accuracy histories are kept in ``epoch_train_*`` / ``epoch_val_*``.

    Args:
        dataset: Stored on ``self.data``; not otherwise used by this class.
        model: The network to train.
        learning_rate: Adam learning rate.
        use_lr_scheduler: If True, attach a ReduceLROnPlateau scheduler monitoring ``train/loss``.
        weight_decay: Adam weight decay.
        schedule_patience: Patience (in epochs) of the plateau scheduler.
        use_node_attr, use_edge_attr: Stored as hyperparameters; not read by this class.
    """

    def __init__(self, dataset, model, learning_rate = 0.01, use_lr_scheduler = False, weight_decay = 0.0001,
                 schedule_patience = 20, use_node_attr = False, use_edge_attr = False) -> None:
        super().__init__()

        # Store the dataset and define loss
        self.data = dataset
        self.model = model
        self.loss = nn.CrossEntropyLoss()

        # Store hyperparameters
        self.use_node_attr = use_node_attr
        self.use_edge_attr = use_edge_attr
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.use_lr_scheduler = use_lr_scheduler
        self.patience = schedule_patience

        # Per-batch buffers, flushed at the end of every epoch
        self.train_loss = []
        self.val_loss = []
        self.train_proba = []
        self.train_labels = []
        self.val_proba = []
        self.val_labels = []
        self.val_hidden = []
        self.val_hidden_plot_patience = 0
        # Per-epoch histories, used for plotting and best-accuracy reporting
        self.epoch_train_loss = []
        self.epoch_train_accuracy = []
        self.epoch_val_loss = []
        self.epoch_val_accuracy = []

        self.save_hyperparameters(ignore=["model", "dataset"])

    def forward(self, x, edge_index, batch):
        """Return the model's logits for a batch of graphs."""
        logits = self.model(x, edge_index, batch)
        return logits

    def compute_loss(self, logits, labels) -> torch.Tensor:
        """Cross-entropy between the logits and the integer class labels."""
        return self.loss(logits, labels)

    @staticmethod
    def _summarise(losses, probas, labels):
        """Mean batch loss and overall accuracy from accumulated epoch buffers."""
        mean_loss = float(np.mean(losses))
        predictions = np.argmax(np.concatenate(probas), axis=-1)
        accuracy = accuracy_score(np.concatenate(labels), predictions)
        return mean_loss, accuracy

    def training_step(self, batch, batch_idx):
        """Forward a batch, buffer its loss/probabilities/labels, and return the loss."""
        data = batch

        logits = self(data.x, data.edge_index, data.batch)
        loss = self.compute_loss(logits, data.y)

        train_proba = F.softmax(logits, dim=-1).detach().cpu().numpy()
        train_labels = data.y.detach().cpu().numpy()

        self.train_loss.append(loss.detach().cpu().numpy())
        self.train_proba.append(train_proba)
        self.train_labels.append(train_labels)

        return loss

    def on_train_epoch_end(self) -> None:
        """Log the epoch's training loss/accuracy, record them, and reset the buffers."""
        if self.train_loss:  # guard: np.concatenate([]) raises on an empty epoch
            train_loss, train_acc = self._summarise(self.train_loss, self.train_proba, self.train_labels)
            self.log("train/loss", train_loss, prog_bar=True)
            self.log("train/accuracy", train_acc, prog_bar=False)

            self.epoch_train_loss.append(train_loss)
            self.epoch_train_accuracy.append(train_acc)

        self.train_loss.clear()
        self.train_proba.clear()
        self.train_labels.clear()

    def validation_step(self, batch, batch_idx):
        """Forward a validation batch and buffer its loss/probabilities/labels."""
        data = batch

        logits = self(data.x, data.edge_index, data.batch)
        loss = self.compute_loss(logits, data.y)

        val_proba = F.softmax(logits, dim=-1).detach().cpu().numpy()
        val_labels = data.y.detach().cpu().numpy()

        self.val_loss.append(loss.detach().cpu().numpy())
        self.val_proba.append(val_proba)
        self.val_labels.append(val_labels)

    def on_validation_epoch_end(self) -> None:
        """Log the epoch's validation loss/accuracy, record them, and reset the buffers."""
        if self.val_loss:  # guard: np.concatenate([]) raises on an empty epoch
            val_loss, val_acc = self._summarise(self.val_loss, self.val_proba, self.val_labels)
            self.log("val/loss", val_loss, prog_bar=True)
            self.log("val/accuracy", val_acc, prog_bar=False)

            # Lightning's pre-training sanity check runs only a few batches through
            # an untrained model; keep it out of the per-epoch history so it does
            # not skew plots or the reported best validation accuracy.
            if not self.trainer.sanity_checking:
                self.epoch_val_loss.append(val_loss)
                self.epoch_val_accuracy.append(val_acc)

        self.val_loss.clear()
        self.val_proba.clear()
        self.val_labels.clear()

    def configure_optimizers(self):
        """Adam, optionally paired with ReduceLROnPlateau on the training loss."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        if self.use_lr_scheduler:
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=self.patience, min_lr=0.0001)
            return {
                'optimizer': optimizer,
                'lr_scheduler': scheduler,
                'monitor': 'train/loss'
            }

        return optimizer


In [None]:
def plot_results(results, seed ):
    """Draw loss (left) and accuracy (right) curves for one seed's run.

    Args:
        results: Mapping keyed by ``f"{seed}"`` whose values hold the lists
            ``"train_loss"``, ``"val_loss"``, ``"train_accuracy"`` and
            ``"val_accuracy"``.
        seed: Seed whose curves are plotted.

    The training and validation histories may differ in length, so each gets
    its own 1-based epoch axis. Both training curves use the length of
    ``train_loss``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))
    run = results[f"{seed}"]
    train_epochs = range(1, len(run["train_loss"]) + 1)

    panels = (
        (loss_ax, "train_loss", "val_loss", "Train Loss", "Validation Loss", "Loss"),
        (acc_ax, "train_accuracy", "val_accuracy", "Train Accuracy", "Validation Accuracy", "Accuracy"),
    )
    for ax, train_key, val_key, train_label, val_label, y_label in panels:
        val_curve = run[val_key]
        ax.plot(train_epochs, run[train_key], 'b-', label=train_label)
        ax.plot(range(1, len(val_curve) + 1), val_curve, 'r-', label=val_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.legend()

    plt.tight_layout()
    plt.show()

## Residual Network Architectures

In [None]:
# Residual network architectures from Task 3.
# Each block wraps one message-passing layer: MeanSAGE, LSTMSAGE or AttnLSTMSAGE, or PyG's GCNConv as a baseline.
class MeanBlock(torch.nn.Module):
    """MeanSAGE conv -> (+ residual projection) -> BatchNorm -> ReLU -> Dropout.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        dropout_rate: Dropout probability applied after the activation.
        batch_norm: If True, apply BatchNorm1d to the conv output.
        residual: If True, add a bias-free linear projection of the block input
            (which also handles an in/out size mismatch).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5, batch_norm=True, residual=False):
        super(MeanBlock, self).__init__()

        self.conv = MeanSAGE(in_channels, out_channels)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.batch_norm = batch_norm
        self.residual = residual

        if self.residual:
            self.res_connection = torch.nn.Linear(in_channels, out_channels, bias=False)

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Apply the block to node features ``x``."""
        res = x
        x = self.conv(x, edge_index)
        if self.residual:
            # Out-of-place add: in-place updates of autograd outputs are fragile.
            x = x + self.res_connection(res)
        if self.batch_norm:
            x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        return x
class LSTMBlock(torch.nn.Module):
    """LSTMSAGE conv -> (+ residual projection) -> BatchNorm -> ReLU -> Dropout.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        dropout_rate: Dropout probability applied after the activation.
        batch_norm: If True, apply BatchNorm1d to the conv output.
        residual: If True, add a bias-free linear projection of the block input
            (which also handles an in/out size mismatch).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5, batch_norm=True, residual=False):
        super(LSTMBlock, self).__init__()

        self.conv = LSTMSAGE(in_channels, out_channels)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.batch_norm = batch_norm
        self.residual = residual

        if self.residual:
            self.res_connection = torch.nn.Linear(in_channels, out_channels, bias=False)

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Apply the block to node features ``x``."""
        res = x
        x = self.conv(x, edge_index)
        if self.residual:
            # Out-of-place add: in-place updates of autograd outputs are fragile.
            x = x + self.res_connection(res)
        if self.batch_norm:
            x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        return x
class AttnLSTMBlock(torch.nn.Module):
    """AttnLSTMSAGE conv -> (+ residual projection) -> BatchNorm -> ReLU -> Dropout.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        dropout_rate: Dropout probability applied after the activation.
        batch_norm: If True, apply BatchNorm1d to the conv output.
        residual: If True, add a bias-free linear projection of the block input
            (which also handles an in/out size mismatch).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5, batch_norm=True, residual=False):
        super(AttnLSTMBlock, self).__init__()

        self.conv = AttnLSTMSAGE(in_channels, out_channels)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.batch_norm = batch_norm
        self.residual = residual

        if self.residual:
            self.res_connection = torch.nn.Linear(in_channels, out_channels, bias=False)

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Apply the block to node features ``x``."""
        res = x
        x = self.conv(x, edge_index)
        if self.residual:
            # Out-of-place add: in-place updates of autograd outputs are fragile.
            x = x + self.res_connection(res)
        if self.batch_norm:
            x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        return x
class BaseBlock(torch.nn.Module):
    """Baseline block: GCNConv -> (+ residual projection) -> BatchNorm -> ReLU -> Dropout.

    Args:
        in_channels: Input feature size.
        out_channels: Output feature size.
        dropout_rate: Dropout probability applied after the activation.
        batch_norm: If True, apply BatchNorm1d to the conv output.
        residual: If True, add a bias-free linear projection of the block input
            (which also handles an in/out size mismatch).
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5, batch_norm=True, residual=False):
        super(BaseBlock, self).__init__()

        self.conv = pyg_nn.GCNConv(in_channels, out_channels)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.batch_norm = batch_norm
        self.residual = residual

        if self.residual:
            self.res_connection = torch.nn.Linear(in_channels, out_channels, bias=False)

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Apply the block to node features ``x``."""
        res = x
        x = self.conv(x, edge_index)
        if self.residual:
            # Out-of-place add: in-place updates of autograd outputs are fragile.
            x = x + self.res_connection(res)
        if self.batch_norm:
            x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        return x

class ResGCN(torch.nn.Module):
    """Stack of message-passing blocks, global mean pooling, and an MLP classifier.

    Args:
        num_features: Input node feature size (used by the first block).
        layer_configs: One dict per block. Every entry needs ``'out_channels'``,
            ``'dropout_rate'`` and ``'batch_norm'``. Entries after the first may
            also give ``'in_channels'`` (default: the previous block's
            ``'out_channels'``) and ``'residual'`` (default: True). The first
            block never uses a residual connection; its ``'residual'`` entry is
            not read.
        num_classes: Number of output logits per graph.
        type: Block family: ``'Mean'``, ``'LSTM'``, ``'AttnLSTM'`` or ``'BASE'``.

    Raises:
        KeyError: If ``type`` is not one of the known block families.
    """

    def __init__(self, num_features, layer_configs, num_classes, type = 'Mean'):
        super(ResGCN, self).__init__()

        initial_layer = layer_configs[0]
        LAYER_DICT = {
            'Mean': MeanBlock,
            'LSTM': LSTMBlock,
            'AttnLSTM': AttnLSTMBlock,
            'BASE': BaseBlock,
        }
        try:
            self.block = LAYER_DICT[type]
        except KeyError:
            raise KeyError(f"Unknown block type {type!r}; expected one of {sorted(LAYER_DICT)}") from None

        self.initial = self.block(num_features, initial_layer['out_channels'],
                                  initial_layer['dropout_rate'], initial_layer['batch_norm'])

        self.hidden_layers = torch.nn.ModuleList()
        prev_out = initial_layer['out_channels']
        for layer_config in layer_configs[1:]:
            self.hidden_layers.append(self.block(
                layer_config.get('in_channels', prev_out),
                layer_config['out_channels'],
                layer_config['dropout_rate'],
                layer_config['batch_norm'],
                residual=layer_config.get('residual', True),  # previously always True, ignoring the config
            ))
            prev_out = layer_config['out_channels']

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(layer_configs[-1]['out_channels'], 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.BatchNorm1d(64),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x, edge_index, batch):
        """Node blocks -> per-graph mean pooling (via ``batch``) -> class logits."""
        x = self.initial(x, edge_index)
        for layer in self.hidden_layers:
            x = layer(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.mlp(x)
        return x

# Block configuration consumed by ResGCN: the first entry describes the input
# block (no in_channels, since ResGCN feeds it num_features); the second
# describes a 128 -> 128 hidden block.
layer_configs = [
    dict(out_channels=128, dropout_rate=0.5, batch_norm=True, residual=True),
    dict(in_channels=128, out_channels=128, dropout_rate=0.5, batch_norm=True, residual=True),
]


## Training and Results

### Base Performance

In [None]:
# Seeds for the train/test split; only one is used here (add more to estimate variance).
DATA_PERTURB_SEEDS = [42]

# hyperparameters
train_size = 500
test_size = 100
batch_size = 256
learning_rate = 0.001
weight_decay = 0.005
scheduler = True
schedule_patience = 10


# Load dataset. use_node_attr / use_edge_attr are TUDataset constructor options;
# assigning them on an already-loaded dataset does not change its features.
dataset = TUDataset(name='ENZYMES', root='data/TUDataset',
                    use_node_attr=True, use_edge_attr=True)

base_results = {}

for perturb_seed in DATA_PERTURB_SEEDS:

    # Seeds both the split below and the model's weight initialisation.
    torch.manual_seed(perturb_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate training wrapper
    model = ResGCN(dataset.num_features, layer_configs, dataset.num_classes, type = 'BASE')
    model_wrapper = GNNWrapper(dataset, model, learning_rate=learning_rate, use_lr_scheduler = scheduler,
                                schedule_patience = schedule_patience, weight_decay=weight_decay)

    # Early stopping to prevent overfitting
    early_stopping = EarlyStopping(monitor="val/loss", mode="min", patience=50)

    # Instantiate Pytorch Lightning trainer
    trainer = pl.Trainer(
      accelerator="cpu",
      max_epochs=100,
      callbacks=[early_stopping],
      check_val_every_n_epoch=1,
    )

    # Train the model.
    # NOTE(review): the held-out split is also the validation set used for early
    # stopping and best-epoch selection, so best_val_accuracy is optimistic.
    trainer.fit(model_wrapper, train_loader, test_loader)
    trainer.validate(model_wrapper, test_loader)

    # Save the results
    key = f"{perturb_seed}"

    base_results[key] = {
        "train_loss": model_wrapper.epoch_train_loss,
        "train_accuracy": model_wrapper.epoch_train_accuracy,
        "val_loss": model_wrapper.epoch_val_loss,
        "val_accuracy": model_wrapper.epoch_val_accuracy,
        "best_val_accuracy": max(model_wrapper.epoch_val_accuracy)
      }

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResGCN           | 42.7 K
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
42.7 K    Trainable params
0         Non-trainable params
42.7 K    Total params
0.171     Total estimated model params size (MB)


                                                                           

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 17.99it/s, v_num=104, val/loss=1.630, train/loss=1.630]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 17.36it/s, v_num=104, val/loss=1.630, train/loss=1.630]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 104.22it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val/accuracy                 0.32
        val/loss             1.628420114517212
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Summarise the best validation accuracy reached in each seeded run.
best_val_accuracies = np.array([run["best_val_accuracy"] for run in base_results.values()])
mean_accuracy = best_val_accuracies.mean()
std_dev_accuracy = best_val_accuracies.std()
print(f"Best Validation Accuracies: {best_val_accuracies}")
print(f"Mean Best Validation Accuracy: {mean_accuracy:.4f}")
print(f"Standard Deviation of Best Validation Accuracies: {std_dev_accuracy:.4f}")

Best Validation Accuracies: [0.35]
Mean Best Validation Accuracy: 0.3500
Standard Deviation of Best Validation Accuracies: 0.0000


### Mean Residual GCN

In [None]:
# Seeds for the train/test split; only one is used here (add more to estimate variance).
DATA_PERTURB_SEEDS = [42]

# hyperparameters
train_size = 500
test_size = 100
batch_size = 256
learning_rate = 0.01
weight_decay = 0.005
scheduler = True
schedule_patience = 10


# Load dataset. use_node_attr / use_edge_attr are TUDataset constructor options;
# assigning them on an already-loaded dataset does not change its features.
dataset = TUDataset(name='ENZYMES', root='data/TUDataset',
                    use_node_attr=True, use_edge_attr=True)

# Results of the Mean model (previously stored in lstm_results, while the
# summary cell below reads mean_results).
mean_results = {}

for perturb_seed in DATA_PERTURB_SEEDS:

    # Seeds both the split below and the model's weight initialisation.
    torch.manual_seed(perturb_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate training wrapper
    model = ResGCN(dataset.num_features, layer_configs, dataset.num_classes, type = 'Mean')
    model_wrapper = GNNWrapper(dataset, model, learning_rate=learning_rate, use_lr_scheduler = scheduler,
                                schedule_patience = schedule_patience, weight_decay=weight_decay)

    # Early stopping to prevent overfitting
    early_stopping = EarlyStopping(monitor="val/loss", mode="min", patience=50)

    # Instantiate Pytorch Lightning trainer
    trainer = pl.Trainer(
      accelerator="cpu",
      max_epochs=100,
      callbacks=[early_stopping],
      check_val_every_n_epoch=1,
    )

    # Train the model.
    # NOTE(review): the held-out split is also the validation set used for early
    # stopping and best-epoch selection, so best_val_accuracy is optimistic.
    trainer.fit(model_wrapper, train_loader, test_loader)
    trainer.validate(model_wrapper, test_loader)

    # Save the results
    key = f"{perturb_seed}"

    mean_results[key] = {
        "train_loss": model_wrapper.epoch_train_loss,
        "train_accuracy": model_wrapper.epoch_train_accuracy,
        "val_loss": model_wrapper.epoch_val_loss,
        "val_accuracy": model_wrapper.epoch_val_accuracy,
        "best_val_accuracy": max(model_wrapper.epoch_val_accuracy)
      }

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResGCN           | 59.7 K
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
59.7 K    Trainable params
0         Non-trainable params
59.7 K    Total params
0.239     Total estimated model params size (MB)


                                                                            

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 20.02it/s, v_num=91, val/loss=1.610, train/loss=1.360]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 2/2 [00:00<00:00, 19.18it/s, v_num=91, val/loss=1.610, train/loss=1.360]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 117.60it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val/accuracy                 0.38
        val/loss            1.6121156215667725
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Summarise the best validation accuracy reached in each seeded run of the
# mean-aggregation model (one entry per key of `mean_results`).
best_val_accuracies = np.array([metrics["best_val_accuracy"] for metrics in mean_results.values()])
mean_accuracy = np.mean(best_val_accuracies)
# np.std defaults to the population std (ddof=0); with a single seed it is always 0.
std_dev_accuracy = np.std(best_val_accuracies)
print(f"Best Validation Accuracies: {best_val_accuracies}")
print(f"Mean Best Validation Accuracy: {mean_accuracy:.4f}")
print(f"Standard Deviation of Best Validation Accuracies: {std_dev_accuracy:.4f}")

Best Validation Accuracies: [0.36]
Mean Best Validation Accuracy: 0.3600
Standard Deviation of Best Validation Accuracies: 0.0000


### LSTM Residual Model

In [None]:
# Seeds for the train/test split and model initialisation; one training run per seed.
# Add more seeds (e.g. [0, 1, 42]) to get a meaningful mean/std across runs.
DATA_PERTURB_SEEDS = [42]

# hyperparameters
train_size = 500
test_size = 100
batch_size = 256
learning_rate = 0.01
weight_decay = 0.005
scheduler = True
schedule_patience = 10


# Load dataset. The attribute flags must be given to the constructor:
# TUDataset drops the node/edge attribute columns inside __init__ unless
# use_node_attr / use_edge_attr are passed, so assigning them afterwards
# (dataset.use_node_attr = True) has no effect on the loaded data.
dataset = TUDataset(name='ENZYMES', root='data/TUDataset',
                    use_node_attr=True, use_edge_attr=True)

lstm_results = {}

for perturb_seed in DATA_PERTURB_SEEDS:

    # Seed the global torch RNG so the split, shuffling and weight init are reproducible.
    torch.manual_seed(perturb_seed)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate model (LSTM neighbour aggregation) and the Lightning training wrapper
    model = ResGCN(dataset.num_features, layer_configs, dataset.num_classes, type='LSTM')
    model_wrapper = GNNWrapper(dataset, model, learning_rate=learning_rate, use_lr_scheduler=scheduler,
                               schedule_patience=schedule_patience, weight_decay=weight_decay)

    # Early stopping to prevent overfitting
    early_stopping = EarlyStopping(monitor="val/loss", mode="min", patience=50)

    # Instantiate Pytorch Lightning trainer
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=100,
        callbacks=[early_stopping],
        check_val_every_n_epoch=1,
    )

    # Train the model.
    # NOTE: the held-out split is passed as the validation loader, so it drives
    # early stopping and `best_val_accuracy`; treat these numbers as optimistic
    # estimates of test performance.
    trainer.fit(model_wrapper, train_loader, test_loader)
    trainer.validate(model_wrapper, test_loader)

    # Save the per-epoch curves recorded by the wrapper, keyed by seed
    key = f"{perturb_seed}"

    lstm_results[key] = {
        "train_loss": model_wrapper.epoch_train_loss,
        "train_accuracy": model_wrapper.epoch_train_accuracy,
        "val_loss": model_wrapper.epoch_val_loss,
        "val_accuracy": model_wrapper.epoch_val_accuracy,
        "best_val_accuracy": max(model_wrapper.epoch_val_accuracy)
    }

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResGCN           | 259 K 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
259 K     Trainable params
0         Non-trainable params
259 K     Total params
1.040     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 99: 100%|██████████| 2/2 [00:05<00:00,  0.39it/s, v_num=90, val/loss=1.660, train/loss=1.610]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 2/2 [00:05<00:00,  0.39it/s, v_num=90, val/loss=1.660, train/loss=1.610]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  3.87it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val/accuracy                 0.29
        val/loss             1.657952070236206
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Summarise the best validation accuracy reached in each seeded run of the
# LSTM-aggregation model (one entry per key of `lstm_results`).
best_val_accuracies = np.array([metrics["best_val_accuracy"] for metrics in lstm_results.values()])
mean_accuracy = np.mean(best_val_accuracies)
# np.std defaults to the population std (ddof=0); with a single seed it is always 0.
std_dev_accuracy = np.std(best_val_accuracies)
print(f"Best Validation Accuracies: {best_val_accuracies}")
print(f"Mean Best Validation Accuracy: {mean_accuracy:.4f}")
print(f"Standard Deviation of Best Validation Accuracies: {std_dev_accuracy:.4f}")

Best Validation Accuracies: [0.42]
Mean Best Validation Accuracy: 0.4200
Standard Deviation of Best Validation Accuracies: 0.0000


### Attention Ordered LSTM Aggregation Model

In [None]:
# Seeds for the train/test split and model initialisation; one training run per seed.
# Add more seeds (e.g. [0, 1, 42]) to get a meaningful mean/std across runs.
DATA_PERTURB_SEEDS = [42]

# hyperparameters
train_size = 500
test_size = 100
batch_size = 256
learning_rate = 0.0001
weight_decay = 0.0005
scheduler = True
schedule_patience = 10


# Load dataset. The attribute flags must be given to the constructor:
# TUDataset drops the node/edge attribute columns inside __init__ unless
# use_node_attr / use_edge_attr are passed, so assigning them afterwards
# (dataset.use_node_attr = True) has no effect on the loaded data.
dataset = TUDataset(name='ENZYMES', root='data/TUDataset',
                    use_node_attr=True, use_edge_attr=True)

attn_lstm_results = {}

for perturb_seed in DATA_PERTURB_SEEDS:

    # Seed the global torch RNG so the split, shuffling and weight init are reproducible.
    torch.manual_seed(perturb_seed)

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate model (attention-ordered LSTM aggregation) and the Lightning training wrapper
    model = ResGCN(dataset.num_features, layer_configs, dataset.num_classes, type='AttnLSTM')
    model_wrapper = GNNWrapper(dataset, model, learning_rate=learning_rate, use_lr_scheduler=scheduler,
                               schedule_patience=schedule_patience, weight_decay=weight_decay)

    # Early stopping to prevent overfitting
    early_stopping = EarlyStopping(monitor="val/loss", mode="min", patience=50)

    # Instantiate Pytorch Lightning trainer
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=100,
        callbacks=[early_stopping],
        check_val_every_n_epoch=1,
    )

    # Train the model.
    # NOTE: the held-out split is passed as the validation loader, so it drives
    # early stopping and `best_val_accuracy`; treat these numbers as optimistic
    # estimates of test performance.
    trainer.fit(model_wrapper, train_loader, test_loader)
    trainer.validate(model_wrapper, test_loader)

    # Save the per-epoch curves recorded by the wrapper, keyed by seed
    key = f"{perturb_seed}"

    attn_lstm_results[key] = {
        "train_loss": model_wrapper.epoch_train_loss,
        "train_accuracy": model_wrapper.epoch_train_accuracy,
        "val_loss": model_wrapper.epoch_val_loss,
        "val_accuracy": model_wrapper.epoch_val_accuracy,
        "best_val_accuracy": max(model_wrapper.epoch_val_accuracy)
    }

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResGCN           | 242 K 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
242 K     Trainable params
0         Non-trainable params
242 K     Total params
0.971     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 67: 100%|██████████| 2/2 [00:05<00:00,  0.35it/s, v_num=100, val/loss=1.810, train/loss=1.860]

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  3.42it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val/accuracy                 0.16
        val/loss            1.8188143968582153
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Summarise the best validation accuracy reached in each seeded run of the
# attention-ordered LSTM model (one entry per key of `attn_lstm_results`).
best_val_accuracies = np.array([metrics["best_val_accuracy"] for metrics in attn_lstm_results.values()])
mean_accuracy = np.mean(best_val_accuracies)
# np.std defaults to the population std (ddof=0); with a single seed it is always 0.
std_dev_accuracy = np.std(best_val_accuracies)
print(f"Best Validation Accuracies: {best_val_accuracies}")
print(f"Mean Best Validation Accuracy: {mean_accuracy:.4f}")
print(f"Standard Deviation of Best Validation Accuracies: {std_dev_accuracy:.4f}")

Best Validation Accuracies: [0.24]
Mean Best Validation Accuracy: 0.2400
Standard Deviation of Best Validation Accuracies: 0.0000
