# Pyramid Graph Neural Network

In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

sys.path.append('..')
device = "cuda" if torch.cuda.is_available() else "cpu"

## Model Definition

In [17]:
from LightningModules.GNN.gnn_base import GNNBase
from LightningModules.GNN.utils import load_dataset, random_edge_slice_v2, make_mlp
from torch.utils.checkpoint import checkpoint
from torch_scatter import scatter_add

In [4]:
from LightningModules.GNN.Models.checkpoint_pyramid import CheckpointedPyramid

In [18]:
with open("example_gnn.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)
        
hparams["first_layer"] = 128
hparams["second_layer"] = 64

In [19]:
class CheckpointedResAGNN(GNNBase):
    def __init__(self, hparams):
        super().__init__(hparams)
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """

        # Setup input network
        self.node_encoder = make_mlp(
            hparams["in_channels"],
            [hparams["first_layer"], hparams["second_layer"]],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(
            2 * hparams["second_layer"],
            [hparams["first_layer"], hparams["second_layer"], 1],
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

        # The node network computes new node features
        self.node_network = make_mlp(
            hparams["second_layer"],
            [hparams["first_layer"], hparams["second_layer"]],
            layer_norm=hparams["layernorm"],
            output_activation=hparams["hidden_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

    def forward(self, x, edge_index):

        # Encode the graph features into the hidden space
#         input_x = x
        x = self.node_encoder(x)

        start, end = edge_index

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            # Previous hidden state
            x0 = x

            # Compute new edge score
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = checkpoint(self.edge_network, edge_inputs)
            e = torch.sigmoid(e)

            # Sum weighted node features coming into each node
            #             weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
            #             weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

            weighted_messages = scatter_add(
                e * x[start], end, dim=0, dim_size=x.shape[0]
            ) + scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

            # Compute new node features
            #             node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
            node_inputs = torch.cat([x, weighted_messages], dim=1)
            print(node_inputs.shape)
            x = checkpoint(self.node_network, node_inputs)

            # Residual connection
            x = x + x0

        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return checkpoint(self.edge_network, clf_inputs).squeeze(-1)

In [20]:
model = CheckpointedResAGNN(hparams)

In [6]:
model = CheckpointedPyramid(hparams)

In [21]:
def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)

### Memory Test

In [22]:
%%time
model.setup(stage="fit")

CPU times: user 2.23 s, sys: 189 ms, total: 2.42 s
Wall time: 1.47 s


In [23]:
sample = model.trainset[0].to(device)

In [24]:
model = model.to(device)

In [25]:
torch.cuda.reset_peak_memory_stats()
output = model(sample.x.to(device), sample.edge_index.to(device))

torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])
torch.Size([341982, 128])


In [13]:
print(f"{hparams['first_layer']}:{hparams['second_layer']} - {torch.cuda.max_memory_allocated()/1024**3} - {count_parameters(model)}")

96:96 - 8.90728759765625 - 57409


In [12]:
print(f"{hparams['first_layer']}:{hparams['second_layer']} - {torch.cuda.max_memory_allocated()/1024**3} - {count_parameters(model)}")

128:64 - 6.419201374053955 - 51329


In [13]:
print(f"{hparams['hidden']}:{hparams['hidden']/2} - {torch.cuda.max_memory_allocated()/1024**3} - {count_parameters(model)}")

128:64.0 - 6.419201374053955 - 51329


### Train GNN

In [6]:
logger = WandbLogger(project="ITk_1GeV_GNN", group="InitialTest")
trainer = Trainer(gpus=1, max_epochs=10, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name          | Type        | Params
----------------------------------------------
0 | input_network | Sequential  | 2.4 K 
1 | edge_network  | EdgeNetwork | 4.4 K 
2 | node_network  | NodeNetwork | 4.3 K 
----------------------------------------------
11.2 K    Trainable params
0         Non-trainable params
11.2 K    Total params
0.045     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  eff = torch.tensor(edge_true_positive / edge_true)
  pur = torch.tensor(edge_true_positive / edge_positive)
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


KeyboardInterrupt: 