<a href="https://colab.research.google.com/github/AchrafAsh/gnn-receptive-fields/blob/main/fine_tuning_on_enzymes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

ENZYMES is a dataset of long-range graphs (molecules) for graph classification. This is what makes this experiment interesting.

As a reminder:
- over-smoothing happens on short-range graphs (because you can't apply too many layers)
- over-squashing happens on long-range graphs

In my opinion, over-smoothing should instead happen on long-range graphs: many layers are required there (at least 8 layers for a problem radius of 8), and stacking that many layers ultimately leads to over-smoothing. In short-range graphs, by contrast, we have the luxury of smoothing a lot because close neighbors are similar.

## **🚀Setting up**

In [None]:
import os, sys
import os.path as osp
from google.colab import drive
# Mount Google Drive and expose its "Colab Notebooks" folder on the import path.
drive.mount('/content/mnt')
nb_path = '/content/notebooks'
try:
    os.symlink('/content/mnt/My Drive/Colab Notebooks', nb_path)
except FileExistsError:
    # The link was already created by an earlier run of this cell.
    pass
except OSError as err:
    # Stay best-effort, but report the problem instead of hiding it.
    print(f"Could not create symlink {nb_path}: {err}")
# Guard keeps the cell idempotent: re-running must not add duplicate entries.
if nb_path not in sys.path:
    sys.path.insert(0, nb_path)  # or append(nb_path)

Drive already mounted at /content/mnt; to attempt to forcibly remount, call drive.mount("/content/mnt", force_remount=True).


In [None]:
# import everything
import copy
import itertools
import math
import random
import time
from functools import partial
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import yaml
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
from torch_geometric.nn import GCNConv, MessagePassing, Sequential
from torch_geometric.utils import degree, to_dense_adj, dense_to_sparse, add_self_loops, to_networkx
from torch_sparse import spmm, spspmm
from tqdm.notebook import tqdm

# Compute device shared by the cells below: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Show matplotlib figures inline and switch seaborn to its white-grid style.
%matplotlib inline
sns.set_style('whitegrid')

In [None]:
%%capture
!wget https://raw.githubusercontent.com/AchrafAsh/gnn-receptive-fields/main/data.py

from data import load_dataset

## **🎨Designing the model**

In [None]:
# Parameter initialization
def xavier(tensor):
    """Fill ``tensor`` in place with Xavier/Glorot-uniform values.

    Values are drawn from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
    where the two fans are the sizes of the last two dimensions.

    Args:
        tensor: Tensor with at least two dimensions, or ``None`` (no-op).
    """
    if tensor is not None:
        # The bound must combine BOTH trailing dimensions (fan_in + fan_out);
        # using the same dimension twice is only right for square matrices.
        bound = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-bound, bound)


def zeros(tensor):
    """Overwrite every entry of ``tensor`` with zero in place; ``None`` is ignored."""
    if tensor is None:
        return
    tensor.data.zero_()

In [None]:
class OurGCNConv(MessagePassing):
    """GCN-style layer combining raw input features with the current hidden state.

    Messages are built from the linearly transformed input features ``x`` of
    the neighbours, scaled by a symmetric degree normalisation; the linearly
    transformed hidden state ``h_i`` is added to every message before the
    messages are summed.

    Args:
        num_features: Size of the input features ``x``.
        in_channels: Size of the hidden state ``h``.
        out_channels: Size of the embeddings produced by the layer.
        k: Position in the list of edge indices given to ``forward``; selects
            the neighbourhood this layer aggregates over.
    """

    def __init__(self, num_features:int, in_channels:int, out_channels:int, k:int):
        super().__init__(aggr='add')  # "Add" aggregation
        self.k = k
        self.lin_neb = torch.nn.Linear(num_features, out_channels)
        self.lin_trgt = torch.nn.Linear(in_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear maps (Xavier weights, zero biases)."""
        xavier(self.lin_neb.weight)
        zeros(self.lin_neb.bias)

        xavier(self.lin_trgt.weight)
        zeros(self.lin_trgt.bias)

    def forward(self, x, h, edge_index):
        """Run one round of message passing.

        Args:
            x: Input features, shape [N, num_features].
            h: Hidden state, shape [N, in_channels].
            edge_index: Indexable collection of edge indices, each of shape
                [2, E]; only ``edge_index[self.k]`` is used.

        Returns:
            New node embeddings of shape [N, out_channels].
        """
        # step 1: linearly transform node feature matrices
        x = self.lin_neb(x)
        h = self.lin_trgt(h)

        # steps 2-3: propagate messages. The edges follow the features' own
        # device, so the layer does not depend on a notebook-level `device`.
        return self.propagate(edge_index[self.k].to(x.device), x=x, h=h)

    def message(self, x_j, h_i, edge_index, size):
        """Build one message per edge: normalised ``x_j`` plus ``h_i``.

        ``x_j`` has shape [E, out_channels] and has already been multiplied by
        the weight matrix in ``forward``.
        """
        # step 2: symmetric normalisation 1 / sqrt(deg(row) * deg(col))
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Degree-0 nodes yield inf; zero them so they contribute nothing.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = norm.view(-1, 1) * x_j

        # h_i is added per edge, so after the sum aggregation it is counted
        # once for every incoming edge.
        return out + h_i

    def update(self, aggr_out):
        """Return the aggregated messages (shape [N, out_channels]) unchanged."""
        # step 3: return new node embeddings
        return aggr_out

In [None]:
# the real deal
class OurModel(torch.nn.Module):
    """``OurGCNConv`` stack followed by APPNP-style propagation and a linear read-out.

    Args:
        num_layers: Number of ``OurGCNConv`` layers; layer ``k`` is built with
            ``k=k`` and therefore reads the k-th entry of the edge-index list.
        hidden_dim: Width of the input embedding and of every conv layer.
        num_features: Size of the raw node features.
        num_classes: Number of output classes.
        propagation_steps: Number of iterations performed by ``propagate``.
        dropout: Dropout probability applied after every conv layer.
    """

    def __init__(self, num_layers:int, hidden_dim:int, num_features:int,
                 num_classes:int, propagation_steps:int=2, dropout:float=0.5):
        super().__init__()
        self.propagation_steps = propagation_steps

        # Learnable mixing weight between propagated and initial features.
        self.alpha = torch.nn.Parameter(torch.tensor(0.5), requires_grad=True)
        # Embedding input features
        self.in_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=num_features, out_features=hidden_dim),
            torch.nn.ReLU()
        )
        # Convolutional layers
        self.conv_layers = self.create_layers(num_layers=num_layers,
                                              hidden_dim=hidden_dim,
                                              dropout=dropout)
        # Readout function
        self.readout = torch.nn.Sequential(
            torch.nn.Linear(in_features=hidden_dim, out_features=num_classes),
            torch.nn.LogSoftmax(dim=1)
        )

    def create_layers(self, num_layers:int, hidden_dim:int, dropout:float):
        """Assemble the conv / ReLU / dropout stack as a PyG ``Sequential``."""
        # The first layer has no hidden state yet, so the embeddings `x` are
        # passed for both the feature and the hidden-state argument.
        layers = [
            (OurGCNConv(num_features=hidden_dim, in_channels=hidden_dim, out_channels=hidden_dim, k=0), "x, x, edge_index -> h"),
            torch.nn.ReLU(inplace=True),
            (torch.nn.Dropout(p=dropout), "h -> h"),
        ]

        for k in range(1, num_layers):
            layers += [
                (OurGCNConv(num_features=hidden_dim, in_channels=hidden_dim, out_channels=hidden_dim, k=k), "x, h, edge_index -> h"),
                torch.nn.ReLU(inplace=True),
                (torch.nn.Dropout(p=dropout), "h -> h"),
            ]
        return Sequential("x, edge_index", layers)

    def reset_parameters(self):
        """Reset the conv stack.

        NOTE(review): ``in_mlp``, ``readout`` and ``alpha`` are not reset here;
        confirm whether that is intended.
        """
        self.conv_layers.reset_parameters()

    def propagate(self, x, edge_index):
        """Apply APPNP-style propagation to ``x``.

        Iterates ``z <- (1 - alpha) * A_hat @ z + alpha * x`` for
        ``propagation_steps`` steps, where ``A_hat`` is the adjacency with
        self-loops and symmetric degree normalisation.

        Args:
            x: Node representations, shape [N, F].
            edge_index: A single edge index of shape [2, E].
        """
        # add self loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # normalize
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # APPNP propagation scheme (an alternative considered earlier was the
        # plain sum of the first `propagation_steps` powers of the adjacency).
        num_nodes = x.size(0)
        z = x.clone()
        for _ in range(self.propagation_steps):
            z = spmm(edge_index, norm, num_nodes, num_nodes, z) * (1-self.alpha) + x * self.alpha
        return z

    def forward(self, x, edge_index):
        """Compute hidden representations and class log-probabilities.

        Args:
            x: Raw node features, shape [N, num_features].
            edge_index: Indexable collection of edge indices; entry ``k`` is
                consumed by conv layer ``k`` and entry 0 by ``propagate``.

        Returns:
            ``(h, log_probs)``: ``h`` is the output of the conv stack (before
            propagation); ``log_probs`` is the read-out of the propagated
            representations.
        """
        embeddings = self.in_mlp(x)
        h = self.conv_layers(embeddings, edge_index)
        # Follow the activations' device instead of a notebook-level global.
        out = self.propagate(h, edge_index[0].to(h.device))
        return h, self.readout(out)

## **🧰Utility functions**

In [None]:
def get_k_neighbors(k: int, edge_index: torch.Tensor):
    """Build one edge index per hop level, from the plain adjacency up to depth ``k``.

    Entry 0 of the result is ``edge_index`` itself. Entry ``l`` (``l >= 1``)
    holds the node pairs joined by a walk of length ``l + 1`` that do not
    appear in any earlier entry.

    Args:
        - k (int): size of the maximum neighborhood
        - edge_index (torch.Tensor): edge index of the graph, as accepted by
          ``to_dense_adj``

    Returns:
        list of edge-index tensors with ``max(k, 1)`` entries
    """
    adjacency = to_dense_adj(edge_index).squeeze(0)
    walk_counts = adjacency.clone()     # adjacency raised to the current power
    level_masks = [adjacency.clone()]   # dense mask of every level found so far
    hop_edges = [edge_index]

    for _ in tqdm(range(1, k)):
        walk_counts = torch.mm(adjacency, walk_counts)
        reachable = torch.where(walk_counts > 0, 1, 0)
        # Keep only the pairs no previous level has claimed.
        new_level = torch.where(reachable - sum(level_masks) > 0, 1, 0)
        level_masks.append(new_level)
        hop_edges.append(dense_to_sparse(new_level)[0])

    return hop_edges

In [None]:
# count model parameters
def count_parameters(model: torch.nn.Module):
    """Print and return the number of trainable parameters of ``model``."""
    total_parameters = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total_parameters += param.numel()

    print(f"The model has {total_parameters:,} parameters")
    return total_parameters

In [None]:
def make(config, dataset=None):
    """Build the model, loss and optimizer for one hyper-parameter configuration.

    Args:
        config: dict providing 'num_layers', 'hidden_dim',
            'propagation_steps', 'dropout', 'learning_rate' and
            'weight_decay'.
        dataset: Object exposing ``num_features`` and ``num_classes``.
            Defaults to the notebook-level ``enzymes`` dataset.

    Returns:
        ``(model, criterion, optimizer)``
    """
    if dataset is None:
        # NOTE(review): this used to read `cora_dataset`, which this notebook
        # never defines. ENZYMES is the only dataset loaded here; it is assumed
        # to expose num_features / num_classes -- confirm against load_dataset.
        dataset = enzymes

    # Make the model
    model = OurModel(num_layers=config['num_layers'],
                     hidden_dim=config['hidden_dim'],
                     num_features=dataset.num_features,
                     num_classes=dataset.num_classes,
                     propagation_steps=config['propagation_steps'],
                     dropout=config['dropout']).to(device)

    # Make the loss and optimizer (NLLLoss pairs with the model's LogSoftmax read-out)
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['learning_rate'],
                                 weight_decay=config['weight_decay'])

    return model, criterion, optimizer

In [None]:
def train(model, all_edge_index, criterion, optimizer, config):
    """Train ``model`` and log the evaluation results after every epoch.

    Args:
        model: Model to optimise.
        all_edge_index: List of edge indices forwarded to the model.
        criterion: Loss function.
        optimizer: Optimizer updating ``model``.
        config: dict providing 'runs', 'epochs', 'id' and the hyper-parameters
            copied into each log row.

    Returns:
        pandas.DataFrame with one row per (run, epoch): the outputs of
        ``test`` plus the hyper-parameters, the run index and the summed
        training loss of the epoch.
    """
    logged_hyperparameters = ('id', 'hidden_dim', 'weight_decay', 'num_layers',
                              'learning_rate', 'dropout', 'propagation_steps')
    outputs = []

    # NOTE(review): the model is not re-initialised between runs, so later
    # runs continue from the previous run's weights -- confirm intent.
    for run in range(config['runs']):
        for epoch in tqdm(range(config['epochs'])):
            total_loss = 0.0
            # NOTE(review): `data` is a notebook-level name that is not defined
            # in the cells visible here; it must be an iterable of graphs.
            for graph in data:
                loss = train_step(model, all_edge_index, graph.x, optimizer, criterion)
                # float() keeps only the number, so no autograd history is
                # accumulated across steps.
                total_loss += float(loss)

            # test the model (arguments follow test's signature:
            # model, all_edge_index, criterion)
            outs = test(model, all_edge_index, criterion)
            outs['epoch'] = epoch
            for key in logged_hyperparameters:
                outs[key] = config[key]
            outs['run'] = run
            outs['total_loss'] = total_loss

            outputs.append(outs)

    return pd.DataFrame(outputs)

In [None]:
def train_step(model, all_edge_index, x, optimizer, criterion):
    """Performs one training step

    Args:
        model: Model to optimise; called as ``model(x, all_edge_index)``.
        all_edge_index: List of edge indices forwarded to the model.
        x: Node features; moved to ``device`` before the forward pass.
        optimizer: Optimizer updating ``model``.
        criterion: Loss function applied to the training nodes.

    Returns:
        The training loss as a tensor detached from the autograd graph.
    """
    model.train()

    # Forward pass
    _, out = model(x.to(device), all_edge_index)
    # NOTE(review): labels and the training mask come from the notebook-level
    # graph `G`, which is not defined in the cells visible here.
    # Labels are aligned with the output's device (no-op when already there).
    targets = G.y[G.train_mask].to(out.device)
    loss = criterion(out[G.train_mask], targets)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach so callers that accumulate the loss do not keep autograd history.
    return loss.detach()

In [None]:
def test(model, all_edge_index, criterion, metrics=None):
    """Evaluate ``model`` on the train / val / test masks of the graph ``G``.

    Metrics is a list of tuple ('metric_name', metric_func) where the metric 
    function takes the last representation matrix and returns a scalar.

    Args:
        model: Model to evaluate; called as ``model(x, all_edge_index)``.
        all_edge_index: List of edge indices forwarded to the model.
        criterion: Loss function.
        metrics: Optional list of ``(name, func)`` pairs (see above).

    Returns:
        dict with ``{split}_loss`` and ``{split}_acc`` for each split, plus
        one entry per metric.
    """
    model.eval()

    # Run the model on some test examples
    # NOTE(review): `G` is a notebook-level graph that is not defined in the
    # cells visible here.
    with torch.no_grad():
        # Features go to the compute device, as train_step does.
        h, logits = model(G.x.to(device), all_edge_index)

    outs = {}
    h = h.detach().cpu()
    for (name, metric) in (metrics or []):
        outs[name] = metric(h)

    for key in ['train', 'val', 'test']:
        mask = G[f'{key}_mask']
        split_logits = logits[mask]
        # Labels are aligned with the logits' device (no-op when already there).
        targets = G.y[mask].to(logits.device)

        loss = criterion(split_logits, targets).item()
        pred = split_logits.max(1)[1]
        acc = pred.eq(targets).sum().item() / mask.sum().item()

        outs[f'{key}_loss'] = loss
        outs[f'{key}_acc'] = acc

    return outs

In [None]:
def model_pipeline(config):
    """Build, train and log one model for the given configuration.

    Args:
        config: Hyper-parameter dict consumed by ``make`` and ``train``; the
            key 'total_parameters' is added to it in place.

    Returns:
        ``(model, logs, config)`` where ``logs`` is the DataFrame returned by
        ``train``.
    """
    # create the model
    model, criterion, optimizer = make(config)
    config['total_parameters'] = count_parameters(model)

    # compute different depth edge_index
    # NOTE(review): get_k_neighbors requires the graph's edge_index, which the
    # previous call omitted (TypeError). `G` is the notebook-level graph that
    # train_step/test also read -- confirm it is the intended graph.
    all_edge_index = get_k_neighbors(config['num_layers'], G.edge_index)

    # train the model for different parameters
    logs = train(model, all_edge_index, criterion, optimizer, config)

    # repr = tsne_plot(model, all_edge_index, title="Last hidden representations")

    return model, logs, config

## **🕸 ENZYMES**

In [None]:
%%capture
# Load ENZYMES through the helper fetched from data.py.
# `path` points at ./data under the current working directory
# (presumably where the dataset is stored -- confirm in data.py).
path = osp.join(os.getcwd(), 'data')
enzymes = load_dataset(path, 'ENZYMES')

## **🔍Fine tuning**

In [None]:
# Hyper-parameter search space: list/array entries are swept, scalars are fixed.
grid = {
    'num_layers': [1],
    'hidden_dim': [16, 24],
    'propagation_steps': [6, 7, 8, 9],
    'learning_rate': np.logspace(-3, -2, 5),
    'weight_decay': np.logspace(-3, -1, 5),
    'dropout': np.linspace(0.4, 0.5, 4),
    'epochs': 200,
    'runs': 2,
}

In [None]:
# Exhaustive grid search: one model is trained per combination of the swept
# hyper-parameters in `grid`, and the per-epoch logs are gathered in `all_logs`.
# NOTE: `id` shadows the builtin; the name is kept because it is also the log
# column used in the queries.
id = 0
all_logs = None
log_frames = []
search_axes = ('num_layers', 'hidden_dim', 'propagation_steps',
               'learning_rate', 'weight_decay', 'dropout')

try:
    # product() varies the last axis fastest, i.e. the same order as the
    # equivalent nested for-loops.
    for values in itertools.product(*(grid[axis] for axis in search_axes)):
        model, logs, hyperparameters = model_pipeline({
            'id': id,
            **dict(zip(search_axes, values)),
            # Read the budget from the grid instead of repeating the literals.
            'epochs': grid['epochs'],
            'runs': grid['runs'],
        })

        # Show the final-epoch results of this configuration.
        print(logs.query(f"id == {id} & epoch == {grid['epochs'] - 1}"))

        log_frames.append(logs)
        id += 1
finally:
    # Runs even when the search is interrupted, so the configurations that
    # already finished stay available; concatenating once avoids the quadratic
    # cost of growing a DataFrame inside the loop.
    if log_frames:
        all_logs = pd.concat(log_frames, ignore_index=True)

In [None]:
# Final-epoch rows (epoch 199) with perfect training accuracy, test accuracy
# above 0.8 and validation accuracy above 0.76, best test/validation first.
all_logs.query('epoch == 199 & train_acc == 1.0 & test_acc > 0.8 & val_acc > 0.76').sort_values(by=["test_acc", "val_acc"], ascending=False)