<a href="https://colab.research.google.com/github/AchrafAsh/gnn-receptive-fields/blob/main/02_benchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup environment

In [None]:
# setup colab environment
import os, sys
import os.path as osp
from google.colab import drive
# Mount Google Drive and expose its "Colab Notebooks" folder as an importable
# directory at /content/notebooks.
drive.mount('/content/mnt')
nb_path = '/content/notebooks'
# os.symlink raises FileExistsError when the link is already there, which
# happens whenever this cell is re-run in the same session; only create it once.
if not osp.lexists(nb_path):
    os.symlink('/content/mnt/My Drive/Colab Notebooks', nb_path)
# Likewise avoid stacking duplicate sys.path entries on re-runs.
if nb_path not in sys.path:
    sys.path.insert(0, nb_path)  # or append(nb_path)

### Load Cora dataset

In [None]:
%%capture
# Fetch the repository's data.py helper into the working directory so that
# `from data import load_dataset` in the next cell can find it.
!wget https://raw.githubusercontent.com/AchrafAsh/gnn-receptive-fields/main/data.py

In [None]:
from data import load_dataset

# Dataset location: a `data` folder under the current working directory.
path = osp.join(os.getcwd(), 'data')
# load_dataset comes from the downloaded data.py; presumably it downloads/caches
# the dataset under `path` -- confirm against data.py.
cora_dataset = load_dataset(path, 'Cora')
G = cora_dataset[0] # only graph of the dataset

## Utils

In [None]:
import torch
import torch.nn.functional as F

from typing import Tuple, Dict, List
from tqdm import tqdm

In [None]:
def mean_average_distance(x, mask=None) -> torch.Tensor:
    """Compute the Mean Average Distance (MAD) between the rows of ``x``.

    The pairwise cosine distance ``D[i, j] = 1 - cos(x[i], x[j])`` is computed
    for every pair of rows. Each row is then averaged over its strictly
    positive entries, and the result is the average of the strictly positive
    row averages.

    Args:
        x: 2-D floating point tensor with one representation per row.
        mask: optional tensor broadcastable to ``(n, n)``; it is multiplied
            element-wise with the distance matrix so that only the selected
            pairs contribute (entries set to 0 are ignored).

    Returns:
        A 0-dim tensor holding the MAD value (call ``.item()`` for a float).
        Returns 0 instead of NaN when no pair has a positive distance.
    """
    norms = torch.norm(x, dim=1, keepdim=True)
    # Outer product of the row norms -> (n, n) matrix of |x_i| * |x_j|.
    # Clamped so that all-zero rows do not trigger a division by zero.
    denom = torch.matmul(norms, torch.transpose(norms, 0, 1))
    denom = denom.clamp(min=torch.finfo(x.dtype).tiny)
    cosine = torch.div(torch.matmul(x, torch.transpose(x, 0, 1)), denom)
    # Cosine distance is non-negative; clamp away floating point noise.
    D = (1 - cosine).clamp(min=0)
    if mask is not None:
        D = D * mask
    # Average each row over its non-zero distances (at least 1 to avoid 0/0).
    pair_counts = torch.sum(D > 0, dim=1).clamp(min=1)
    D_hat = torch.div(torch.sum(D, dim=1), pair_counts)
    # Average over the rows that have at least one non-zero distance.
    row_count = torch.sum(D_hat > 0).clamp(min=1)
    return torch.sum(D_hat) / row_count

In [None]:
def count_parameters(model: torch.nn.Module):
    """Print how many trainable parameters ``model`` has.

    Only parameters with ``requires_grad`` set are counted. Nothing is
    returned; the total is written to stdout with thousands separators.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print(f"The model has {trainable:,} parameters")

In [None]:
def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
          data: type(G), edge_index) -> Tuple[torch.tensor, float]:
    """Perform a single optimisation step on the training nodes.

    The model is called as ``model(data.x, edge_index)`` and must return a
    pair whose second element is fed to ``F.nll_loss``. The loss is computed
    only on the entries selected by ``data.train_mask``.

    Returns:
        The first element of the model output and the loss tensor of this
        step (not converted to a Python float).
    """
    model.train()
    optimizer.zero_grad()

    embeddings, log_probs = model(data.x, edge_index)

    labelled = data.train_mask
    objective = F.nll_loss(log_probs[labelled], data.y[labelled])
    objective.backward()
    optimizer.step()

    return embeddings, objective

In [None]:
def evaluate(model: torch.nn.Module, data: type(G), edge_index) -> Dict[str, float]:
    """Compute loss and accuracy on the train, val and test masks.

    The model is put in eval mode and run once without gradient tracking;
    the second element of its output is used for both the NLL loss and the
    predictions.

    Returns:
        A dict with the keys ``{split}_loss`` and ``{split}_acc`` for each of
        ``train``, ``val`` and ``test``.
    """
    model.eval()

    with torch.no_grad():
        _hidden, log_probs = model(data.x, edge_index)

    metrics = {}
    for split in ('train', 'val', 'test'):
        split_mask = data[f'{split}_mask']
        split_scores = log_probs[split_mask]
        targets = data.y[split_mask]

        # Index of the highest score per node is the predicted class.
        _, predicted = split_scores.max(1)
        n_correct = predicted.eq(targets).sum().item()

        metrics[f'{split}_loss'] = F.nll_loss(split_scores, targets).item()
        metrics[f'{split}_acc'] = n_correct / split_mask.sum().item()

    return metrics

In [None]:
def run(data: type(G), model: torch.nn.Module,
        edge_index,
        runs: int, epochs: int, lr: float,
        weight_decay: float,
        early_stopping: int = 0,
        initialize:bool=True) -> Tuple[List[float], List[float], List[float]]:
    """Train and evaluate ``model`` for ``runs`` independent runs.

    Each run trains with Adam for up to ``epochs`` epochs, recording the
    training loss, the test accuracy and the mean average distance (MAD) of
    the first model output after every epoch. A summary over all runs (best
    validation loss, final test accuracy, wall-clock duration) is printed at
    the end.

    Args:
        data: graph object providing ``x``, ``y`` and the train/val/test masks.
        model: module called as ``model(x, edge_index)`` returning a pair.
        edge_index: connectivity passed to the model; moved to the same
            device as ``data`` and ``model``.
        runs: number of independent training runs.
        epochs: maximum number of epochs per run.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        early_stopping: window size for early stopping; ``0`` disables it.
        initialize: when true, ``model.reset_parameters()`` is called at the
            start of every run.

    Returns:
        ``(train_losses, test_accs, MADs)`` -- the per-epoch histories of the
        LAST run only (empty lists when ``runs`` is 0).
    """
    # Local import: `time` is not imported anywhere at the top of the notebook.
    import time

    # Use the GPU when present, matching the cuda synchronisation done below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)
    # The connectivity is passed separately from `data`, so it has to be moved
    # explicitly or the forward pass fails with a device mismatch on GPU.
    edge_index = edge_index.to(device)
    model.to(device)

    # Defined up-front so the return statement is valid even when runs == 0.
    train_losses, test_accs, MADs = [], [], []
    val_losses, accs, durations = [], [], []
    for _ in range(runs):
        if initialize:
            model.reset_parameters()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=weight_decay)

        # Make sure pending GPU work does not leak into the timed section.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        best_val_loss = float('inf')
        val_loss_history, train_losses, test_accs, MADs = [], [], [], []
        # Reported as the run's accuracy; stays NaN if no epoch is executed.
        test_acc = float('nan')

        for epoch in range(1, epochs+1):
            hidden_state, train_loss = train(model, optimizer, data, edge_index)

            # MAD is only a diagnostic: detach so that no autograd graph is
            # built for the pairwise distance matrix.
            MAD = mean_average_distance(x=hidden_state.detach()).item()
            eval_info = evaluate(model, data, edge_index)

            test_acc = eval_info['test_acc']
            val_loss_history.append(eval_info['val_loss'])
            train_losses.append(train_loss.item())
            test_accs.append(test_acc)
            MADs.append(MAD)

            # Track the lowest validation loss seen in this run; it is the
            # value averaged over runs in the final "Val Loss" summary.
            if eval_info['val_loss'] < best_val_loss:
                best_val_loss = eval_info['val_loss']

            # Early stopping (second half of training only): stop once the
            # current validation loss exceeds the mean of the previous
            # `early_stopping` validation losses.
            if early_stopping > 0 and epoch > epochs // 2:
                tmp = torch.tensor(val_loss_history[-(early_stopping + 1):-1])
                if eval_info['val_loss'] > tmp.mean().item():
                    break

            if epoch % 10 == 0:
                print(f"Epoch: [{epoch} / {epochs}] | Loss: {train_loss} | Test accuracy: {test_acc} | MAD: {MAD}")

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()

        val_losses.append(best_val_loss)
        accs.append(test_acc)
        durations.append(t_end - t_start)

    loss, acc, duration = torch.tensor(val_losses), torch.tensor(accs), torch.tensor(durations)

    print(f"Val Loss: {loss.mean().item():.4f}, Test Accuracy: {acc.mean().item():.3f} ± {acc.std().item():.3f}, Duration: {duration.mean().item():.3f}")
    return train_losses, test_accs, MADs

## Benchmark Models

### Simple GCN

In [None]:
from torch_geometric.nn import Sequential, GCNConv

In [None]:
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers followed by a log-softmax.

    Every convolution except the last is followed by ReLU and dropout. The
    forward pass returns both the raw output of the last convolution and its
    log-softmax, so callers can inspect the representation as well as the
    class log-probabilities.

    Args:
        num_layers: number of convolutions. Values below 2 still build two
            convolutions (input -> hidden, hidden -> classes).
        hidden_dim: width of the hidden representations.
        num_features: dimension of the input node features.
        num_classes: dimension of the output.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, num_layers:int, hidden_dim:int, num_features:int, num_classes:int, dropout:float=0.5):
        super().__init__()
        self.conv_layers = self.create_layers(num_layers=num_layers,
                                              num_features=num_features,
                                              num_classes=num_classes,
                                              hidden_dim=hidden_dim,
                                              dropout=dropout)

        self.log_softmax = torch.nn.LogSoftmax(dim=1)


    def create_layers(self, num_layers:int, num_features:int, num_classes:int,
                      hidden_dim:int, dropout:float):
        """Build the ``Sequential`` container holding all convolutions.

        NOTE: ``GCNConv`` is constructed without the former ``k=...`` keyword.
        ``GCNConv`` has no such parameter and forwards unknown keyword
        arguments to its ``MessagePassing`` base class, which rejects them.
        """
        layers = []

        # first layer: input features -> hidden
        layers += [
                (GCNConv(in_channels=num_features, out_channels=hidden_dim), "x, edge_index -> x"),
                (torch.nn.ReLU(), "x -> x"),
                (torch.nn.Dropout(p=dropout), "x -> x")
        ]

        # intermediate layers: hidden -> hidden (num_layers - 2 of them)
        for _ in range(1, num_layers-1):
            layers += [
                (GCNConv(in_channels=hidden_dim, out_channels=hidden_dim), "x, edge_index -> x"),
                (torch.nn.ReLU(), "x -> x"),
                (torch.nn.Dropout(p=dropout), "x -> x")
            ]

        # last layer: hidden -> classes, no activation or dropout
        layers += [
            (GCNConv(in_channels=hidden_dim, out_channels=num_classes), "x, edge_index -> x"),
        ]
        return Sequential("x, edge_index", layers)


    def reset_parameters(self):
        """Re-initialise the parameters of the convolution stack."""
        self.conv_layers.reset_parameters()


    def forward(self, x, edge_index):
        """Return ``(h, log_softmax(h))`` where ``h`` is the last layer's output."""
        h = self.conv_layers(x, edge_index)
        return h, self.log_softmax(h)

## Experiments

### GCN

In [None]:
# Hyper-parameters of the GCN experiment below.
NUM_LAYERS=8  # passed to GCN as num_layers
HIDDEN_DIM=16  # passed to GCN as hidden_dim
NUM_FEATURES=cora_dataset.num_features
NUM_CLASSES=cora_dataset.num_classes
EPOCHS=200  # maximum number of epochs per run
LR=0.01  # Adam learning rate
WEIGHT_DECAY=5e-4  # Adam weight decay
EARLY_STOPPING=0  # 0 disables early stopping in run()
RUNS=1  # number of independent training runs

In [None]:
# Build the GCN with the hyper-parameters above and report its trainable size.
model = GCN(num_layers=NUM_LAYERS, hidden_dim=HIDDEN_DIM, num_features=NUM_FEATURES, num_classes=NUM_CLASSES)
count_parameters(model)

In [None]:
# Train/evaluate on the Cora graph; as the last expression of the cell, the
# returned (train_losses, test_accs, MADs) histories are displayed.
run(data=G, model=model,
    edge_index=G.edge_index,
    runs=RUNS, epochs=EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    early_stopping=EARLY_STOPPING)