# Explore PyG Autoscale

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from torch_sparse import SparseTensor
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import ToSparseTensor

# Presumably makes repository-level packages (e.g. LightningModules) importable.
sys.path.append("../..")

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [73]:
from torch_geometric_autoscale.models import GAT, ScalableGNN
from torch_geometric_autoscale import metis, permute, SubgraphLoader
from torch_geometric_autoscale import get_data, compute_micro_f1
from LightningModules.GNN.utils import make_mlp

## Roadmap

1. Get a simple memory test running (e.g. one event with a GAT model applied).
2. Apply the same model without autoscale and compare.
3. Look at memory scaling (i.e. how large can the GAT model be made?).
4. Compare a ~32-channel model, a ~64-channel model, a ~64-channel model on randomly phi-sampled subgraphs, and a ~64-channel autoscaled model.
5. Goal: show that true 64 >= autoscaled 64 >> random-sampled 64 >> true 32, where "true" means run on the full, unsampled graph.
6. We may need to tune the hidden width to demonstrate this, i.e. true Nx32 >= autoscaled Nx32 >> random-sampled Nx32 >> true 32.

# Simple Memory Test

1. Load an embedded graph
2. Understand how to sample the graph
3. Define a GAS model and a non-GAS model (e.g. GAT)
4. Apply both!

## Dummy Data

1. Load in configs
2. Partition graph (i.e. there should only be one adjacency matrix)

### Testing Partition

In [3]:
# Dummy graph for the partition demo: uniform random node features and edges.
# Seed the global torch RNG so this graph, and later torch RNG draws (e.g. loader
# shuffling, weight initialisation), are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

num_nodes = 100_000
num_edges = 1_000_000
x = torch.rand(num_nodes, 3)
# Endpoints drawn uniformly from [0, num_nodes); self-loops and duplicates are possible.
edge_index = torch.randint(num_nodes, (2, num_edges))

In [4]:
# Message passing consumes the *transposed* adjacency: after `.t()` the rows are
# edge targets (edge_index[1]) and the columns are edge sources (edge_index[0]).
adj = SparseTensor(
    row=edge_index[0], col=edge_index[1], sparse_sizes=(num_nodes, num_nodes)
)
adj_t = adj.t()
# Only `adj_t` is stored. NOTE(review): PyGAS `permute` appears to re-index only
# node-level tensors and SparseTensors, so a [2, E] `edge_index` kept here would
# still refer to the pre-permutation node ids afterwards. Nothing below reads
# `data.edge_index`.
data = Data(x=x, adj_t=adj_t)

In [5]:
# Split the dummy graph into 10 METIS clusters. `perm` re-orders the nodes so that
# each cluster is contiguous; `ptr` holds the cluster boundaries into `perm`.
num_parts = 10
perm, ptr = metis(
    adj_t,
    num_parts=num_parts,
    log=True,
)

Computing METIS partitioning with 10 parts... Done! [4.36s]


`perm` lists the node indices re-ordered so that the nodes of each cluster form one contiguous slice

In [6]:
# Display the METIS node permutation (node ids grouped cluster by cluster)
perm

tensor([31466,  9792, 31517,  ..., 30058, 82892, 30029])

`ptr` holds the slice boundaries: the nodes of cluster `i` are `perm[ptr[i]:ptr[i+1]]`

In [7]:
# Display the cluster boundaries into `perm` (first entry 0, last entry num_nodes)
ptr

tensor([     0,  10002,  20000,  30003,  40005,  50001,  60002,  70004,  80004,
         90005, 100000])

We simply re-order the data according to the `perm` lookup, so that each cluster's nodes become contiguous

In [8]:
# Re-order `data` by the METIS permutation so every cluster is a contiguous node
# range. NOTE: not idempotent -- re-running this cell permutes the data again.
data = permute(
    data,
    perm,
    log=True,
)

Permuting data... Done! [0.29s]


### Dataloader

In [9]:
# One cluster (plus its halo nodes) per mini-batch, visited in random order.
loader = SubgraphLoader(
    data,
    ptr,
    batch_size=1,
    shuffle=True,
)

Pre-processing subgraphs... Done! [0.13s]


### Model

In [11]:
# Small GAT used for the dummy-graph smoke test.
model_config = dict(
    hidden_channels=16,
    hidden_heads=1,
    out_heads=1,
    num_layers=3,
)

In [13]:
# GAS-enabled GAT. `device="cpu"` is forwarded to PyGAS (presumably the placement
# of its history buffers), while `.to(device)` moves the model itself.
model = GAT(
    num_nodes=data.num_nodes,
    # Take the input width from the permuted `data` (as the ITk cell does) rather
    # than from the free-standing `x` tensor it was built from.
    in_channels=data.x.shape[1],
    out_channels=1,
    device="cpu",
    **model_config
).to(device)

### Train

In [16]:
# Smoke test: iterate over the sub-graphs without training. Each item unpacks to
# (sub-graph, number of in-cluster nodes, global node ids, _, _); the last two are
# presumably PyGAS' `offset`/`count` and are unused here.
for batch, batch_size, n_id, _, _ in loader:
    print(batch, batch_size, n_id)

Data(adj_t=[10001, 62203, nnz=99742], x=[62203, 3]) 10001 tensor([50001, 50002, 50003,  ..., 28045, 41056, 74991])
Data(adj_t=[10000, 61945, nnz=99700], x=[61945, 3]) 10000 tensor([70004, 70005, 70006,  ..., 57197,  9335, 24587])
Data(adj_t=[9998, 61914, nnz=99846], x=[61914, 3]) 9998 tensor([10002, 10003, 10004,  ..., 22262, 56843, 96027])
Data(adj_t=[9995, 62105, nnz=100437], x=[62105, 3]) 9995 tensor([90005, 90006, 90007,  ..., 42591, 87234,  8122])
Data(adj_t=[10002, 61959, nnz=99672], x=[61959, 3]) 10002 tensor([    0,     1,     2,  ..., 26451, 33989, 69148])
Data(adj_t=[10002, 62303, nnz=100466], x=[62303, 3]) 10002 tensor([30003, 30004, 30005,  ...,  2041, 13943, 22381])
Data(adj_t=[10001, 62024, nnz=99442], x=[62024, 3]) 10001 tensor([80004, 80005, 80006,  ..., 20870, 32888, 94726])
Data(adj_t=[10003, 62295, nnz=100612], x=[62295, 3]) 10003 tensor([20000, 20001, 20002,  ...,  9214, 43816, 59682])
Data(adj_t=[10002, 62017, nnz=100169], x=[62017, 3]) 10002 tensor([60002, 60003, 

## Expanding to ITk Dataset

In [13]:
def convert_node_features(data):
    """Reshape every 1-D tensor attribute of ``data`` into a column vector.

    A per-node scalar stored as ``[N]`` becomes ``[N, 1]``, matching the 2-D
    layout of the other node attributes (e.g. ``x=[N, 3]``, ``pt=[N, 1]``).
    Tensors of any other rank and non-tensor attributes are left untouched.

    NOTE(review): *every* 1-D tensor is reshaped, including any edge-level
    attribute still attached to ``data`` -- drop those first if that matters.

    Args:
        data: object exposing ``to_dict()`` and item assignment (e.g. a PyG
            ``Data``). Modified in place.

    Returns:
        The same ``data`` object, so the call can be chained.
    """
    # to_dict() is a snapshot, so assigning into `data` while iterating is safe
    for key, item in data.to_dict().items():
        # isinstance (rather than an exact type check) also accepts Tensor subclasses
        if isinstance(item, torch.Tensor) and item.dim() == 1:
            data[key] = item.unsqueeze(1)
    return data

In [14]:
def load_datafile(
    event,
    drop_keys=("event_file", "y", "modulewise_true_edges", "signal_true_edges"),
):
    """Load one processed event and convert it to the sparse layout used here.

    Args:
        event: path to a ``torch.save``-d event.
        drop_keys: attributes cleared (set to ``None``) before conversion. The
            default reproduces the previously hard-coded list.

    Returns:
        The event after ``ToSparseTensor`` (which yields the ``adj_t``
        adjacency) with every 1-D tensor attribute reshaped to ``[N, 1]``.
    """
    # SECURITY: torch.load unpickles arbitrary objects -- only use on trusted
    # files (presumably full unpickling is needed to rebuild the Data object).
    data = torch.load(event, map_location="cpu")
    for key in drop_keys:
        setattr(data, key, None)

    data = ToSparseTensor()(data)
    convert_node_features(data)

    return data

In [15]:
# Directory of processed (embedding-stage) ITk events; set the ITK_EMBEDDING_DIR
# environment variable to point elsewhere when running on another machine.
input_dir = os.environ.get(
    "ITK_EMBEDDING_DIR",
    "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/embedding_processed/0.5GeV_barrel/train",
)
# os.listdir returns entries in arbitrary order; sort so that
# `all_events[:num_events]` selects the same events on every run.
all_events = [os.path.join(input_dir, event) for event in sorted(os.listdir(input_dir))]

In [16]:
# Peek at one raw (unprocessed) event. Bound to `raw_event` rather than `data`
# so it is not confused with the processed, batched `data` built further down.
sample_event = all_events[0]
raw_event = torch.load(sample_event, map_location="cpu")
raw_event

In [25]:
%%time
num_events = 2
dataset = [load_datafile(event) for event in all_events[:num_events]]

CPU times: user 340 ms, sys: 19.4 ms, total: 359 ms
Wall time: 270 ms


In [26]:
%%time
data = Batch.from_data_list(dataset)

CPU times: user 58.8 ms, sys: 2.32 ms, total: 61.1 ms
Wall time: 46.3 ms


### Testing Partition

In [19]:
# Scale the METIS cluster count with the number of events (8 clusters per event).
num_parts = num_events * 8
perm, ptr = metis(
    data.adj_t,
    num_parts=num_parts,
    log=True,
)

Computing METIS partitioning with 16 parts... Done! [0.09s]


We simply re-order the data according to the `perm` lookup, so that each cluster's nodes become contiguous

In [None]:
# Re-order `data` by the METIS permutation so every cluster is a contiguous node
# range. NOTE: not idempotent -- re-running this cell permutes the data again.
data = permute(
    data,
    perm,
    log=True,
)

### Dataloader

In [None]:
# One METIS cluster (plus halo) per mini-batch, in random order.
loader = SubgraphLoader(
    data,
    ptr,
    batch_size=1,
    shuffle=True,
)

### Model

In [None]:
# Larger GAT used to probe memory usage on the ITk events.
model_config = dict(
    hidden_channels=128,
    hidden_heads=2,
    out_heads=1,
    num_layers=8,
)

In [None]:
# GAS-enabled GAT; `device="cpu"` is forwarded to PyGAS (presumably the placement
# of its history buffers), while `.to(device)` moves the model itself.
model = GAT(
    **model_config,
    num_nodes=data.num_nodes,
    in_channels=data.x.shape[1],
    out_channels=1,
    device="cpu",
)
model = model.to(device)

### Train

In [None]:
# Default AdamW settings; the loss below is a dummy (sum of outputs), so only the
# memory footprint of a real optimisation step matters here.
optimizer = torch.optim.AdamW(params=model.parameters())

In [None]:
# Peak-memory counters are CUDA-only; calling them on a CPU-only host raises, so
# guard on availability (the notebook otherwise supports device == "cpu").
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

In [None]:
%%time
model.train()

for batch, batch_size, n_id, _, _ in loader:
    print(batch, batch_size, n_id)

    batch = batch.to(model.device)
    #     n_id = n_id.to(model.device) # This shouldn't be on device, since this is the (node ID <-> partion) map on the host

    optimizer.zero_grad()
    out = model(batch.x, batch.adj_t, batch_size, n_id)
    loss = out.sum()
    loss.backward()
    optimizer.step()

In [None]:
# Peak GPU memory since the last reset, bytes -> GiB (1024**3); not tracked on CPU.
if torch.cuda.is_available():
    peak_mem_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{peak_mem_gib:.2f} GiB allocated")
else:
    peak_mem_gib = None
    print("CUDA not available: peak GPU memory not tracked")

# Hacking for Tracking

## Phi Partition Function

In [61]:
def phi_partition(data, num_parts, phi_min=-1.0, phi_max=1.0, phi_col=1):
    """Group the nodes of a (batched) graph into per-event phi slices.

    Node ``i`` goes to cluster ``event_i * num_parts + s_i``. Here ``event_i`` is
    ``data.batch[i]`` (0 when ``data`` has no ``batch``), and ``s_i`` is the
    index of the equal-width slice of ``[phi_min, phi_max)`` that contains
    ``data.x[i, phi_col]``.

    Slices are half-open, and values outside the range are clamped into the
    first or last slice, so every node lands in exactly one cluster. A strict
    ``lo < phi < hi`` test would silently drop nodes lying exactly on a slice
    edge.

    Args:
        data: PyG ``Data``/``Batch`` with node features ``x`` and optionally ``batch``.
        num_parts: number of phi slices per event (>= 1).
        phi_min, phi_max: range split into slices. NOTE(review): assumes phi is
            stored normalised to [-1, 1] -- confirm against the preprocessing.
        phi_col: column of ``data.x`` holding phi.

    Returns:
        ``(perm, ptr)`` in the same convention as ``metis``. ``perm`` lists node
        ids cluster by cluster, keeping the original order within a cluster. The
        nodes of cluster ``c`` are ``perm[ptr[c]:ptr[c + 1]]``. Both are int64.

    Raises:
        ValueError: if ``num_parts < 1``.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")

    phi = data.x[:, phi_col]
    # Interior slice edges only. With right=True, bucketize returns s such that
    # edges[s - 1] <= phi < edges[s], so anything below the first edge lands in
    # slice 0 and anything at or above the last edge in slice num_parts - 1.
    edges = torch.linspace(
        phi_min, phi_max, num_parts + 1, dtype=phi.dtype, device=phi.device
    )[1:-1]
    if edges.numel() > 0:
        segment = torch.bucketize(phi, edges, right=True)
    else:
        segment = torch.zeros(phi.shape[0], dtype=torch.long, device=phi.device)

    event = getattr(data, "batch", None)
    if event is None:
        event = torch.zeros_like(segment)
    num_events = int(event.max()) + 1 if event.numel() > 0 else 0
    cluster = event.long() * num_parts + segment

    # A stable sort yields the "event, then slice, then node id" ordering.
    perm = torch.sort(cluster, stable=True).indices
    counts = torch.bincount(cluster, minlength=num_events * num_parts)
    # Build ptr directly in int64 (like metis); going through a float32 tensor
    # would lose precision once cumulative node counts exceed 2**24.
    ptr = torch.cat([counts.new_zeros(1), counts.cumsum(0)])
    return perm, ptr


num_parts = 4
new_perm, new_ptr = phi_partition(data, num_parts)

In [64]:
# Inspect the batched graph before re-ordering its nodes by phi slice
data

Batch(adj_t=[52533, 52533, nnz=1409928], batch=[52533], cell_data=[52533, 11], hid=[52533, 1], nhits=[52533, 1], pid=[52533, 1], primary=[52533, 1], pt=[52533, 1], ptr=[3], x=[52533, 3])

In [65]:
# Re-order `data` so every (event, phi slice) cluster is a contiguous node range.
# NOTE: not idempotent -- re-running this cell permutes the data again.
data = permute(
    data,
    new_perm,
    log=True,
)

Permuting data... Done! [0.31s]


In [66]:
# One phi-slice cluster (plus halo) per mini-batch, in random order.
loader = SubgraphLoader(
    data,
    new_ptr,
    batch_size=1,
    shuffle=True,
)

Pre-processing subgraphs... Done! [0.03s]


## Edge Classification Model

Regular ResAGNN

In [None]:
class VanillaResAGNN(GNNBase):
    """Residual attention GNN (ResAGNN) for edge classification.

    Node states are the hidden embedding concatenated with the raw node inputs
    (width ``spatial_channels + cell_channels + hidden``). Each of the
    ``n_graph_iters`` iterations proceeds in four steps:

    1. Score every edge with ``edge_network`` (a sigmoid attention weight).
    2. Sum the weighted neighbour states in both edge directions.
    3. Update the nodes with ``node_network``.
    4. Add a residual connection.

    A final ``edge_network`` pass returns one raw score per edge.

    NOTE(review): ``GNNBase`` is not imported anywhere in this notebook. It
    presumably lives in ``LightningModules.GNN`` next to ``make_mlp``; import it
    before running this cell.
    """

    def __init__(self, hparams):
        """Build the input, edge and node MLPs.

        Reads ``spatial_channels``, ``cell_channels``, ``hidden``,
        ``nb_edge_layer``, ``nb_node_layer``, ``hidden_activation`` and
        ``layernorm`` from ``hparams``. ``forward`` additionally reads
        ``n_graph_iters`` via ``self.hparams`` (presumably populated by ``GNNBase``).
        """
        super().__init__(hparams)
        input_channels = hparams["spatial_channels"] + hparams["cell_channels"]
        state_channels = input_channels + hparams["hidden"]

        self.edge_network = make_mlp(
            state_channels * 2,
            [state_channels] * hparams["nb_edge_layer"] + [1],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.node_network = make_mlp(
            state_channels * 2,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.input_network = make_mlp(
            input_channels,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

    def forward(self, x, edge_index):
        """Return one raw (no output activation) score per edge, shape ``[E, 1]``.

        Args:
            x: node inputs, presumably ``[N, spatial_channels + cell_channels]``.
            edge_index: ``[2, E]`` integer tensor of (start, end) node ids.
        """
        start, end = edge_index
        input_x = x

        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([self.input_network(x), input_x], dim=-1)

        for _ in range(self.hparams["n_graph_iters"]):
            x_initial = x

            # Edge network -> attention weight in (0, 1) per edge
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = torch.sigmoid(self.edge_network(edge_inputs))

            # Weighted sum of neighbour states in both edge directions. index_add_
            # is the native-torch equivalent of
            # torch_scatter.scatter_add(..., dim=0, dim_size=N), which was used
            # here without ever being imported.
            messages = torch.zeros_like(x)
            messages.index_add_(0, end, e * x[start])
            messages.index_add_(0, start, e * x[end])

            x = self.node_network(torch.cat([messages, x], dim=1))

            # Shortcut connect the inputs onto the hidden representation
            x = torch.cat([x, input_x], dim=-1)

            # Residual connection
            x = x_initial + x

        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_network(edge_inputs)

In [68]:
class ScalableAGNN(ScalableGNN):
    """ResAGNN edge classifier built on PyGAS' ``ScalableGNN``.

    The architecture matches the vanilla ResAGNN. Node states are the hidden
    embedding concatenated with the raw node inputs. They are refined over
    ``n_graph_iters`` attention-weighted, residual message-passing iterations,
    and a final ``edge_network`` pass returns one raw score per edge.

    GAS integration: after every iteration except the last, node states go
    through ``self.push_and_pull`` with one history per iteration. On a PyGAS
    mini-batch (``batch_size``/``n_id`` given), the out-of-batch (halo) nodes
    therefore take their states from the history buffers.

    NOTE(review): this relies on PyGAS ``push_and_pull`` leaving ``x`` unchanged
    when no batch information is given and ``x`` is not the full graph. Confirm
    against the installed torch_geometric_autoscale.
    """

    def __init__(self, num_nodes, hparams, pool_size=None, buffer_size=None, device=None):
        """
        Args:
            num_nodes: number of nodes in the full graph (sizes the histories).
            hparams: dict read for ``spatial_channels``, ``cell_channels``,
                ``hidden``, ``n_graph_iters``, ``nb_edge_layer``,
                ``nb_node_layer``, ``hidden_activation`` and ``layernorm``.
            pool_size, buffer_size, device: forwarded unchanged to ``ScalableGNN``.
        """
        input_channels = hparams["spatial_channels"] + hparams["cell_channels"]
        # Node states are [hidden | raw inputs], and that is what gets pushed to
        # the histories, so size them with the full state width, not `hidden`.
        state_channels = input_channels + hparams["hidden"]
        super().__init__(
            num_nodes, state_channels, hparams["n_graph_iters"], pool_size, buffer_size, device
        )
        # ScalableGNN is not a LightningModule, so nothing provides `self.hparams`
        # for us; `forward` reads it, so keep a reference.
        self.hparams = hparams

        self.edge_network = make_mlp(
            state_channels * 2,
            [state_channels] * hparams["nb_edge_layer"] + [1],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.node_network = make_mlp(
            state_channels * 2,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.input_network = make_mlp(
            input_channels,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

    @staticmethod
    def _edge_endpoints(edge_index):
        """Return ``(start, end)`` from a ``[2, E]`` tensor or an ``adj_t`` SparseTensor."""
        if isinstance(edge_index, SparseTensor):
            # adj_t is the transposed adjacency: rows are targets, columns sources
            end, start, _ = edge_index.coo()
            return start, end
        start, end = edge_index
        return start, end

    def message_passing(self, x, input_x, start, end):
        """Run one residual attention iteration and return the new node states.

        ``x`` holds the current ``[hidden | inputs]`` states. ``input_x`` holds the
        raw inputs that are re-attached after the node update. ``start``/``end``
        are the edge endpoints.
        """
        # Edge network -> attention weight in (0, 1) per edge
        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        e = torch.sigmoid(self.edge_network(edge_inputs))

        # Weighted sum of neighbour states in both edge directions. This is the
        # native-torch equivalent of torch_scatter.scatter_add(..., dim=0,
        # dim_size=N), which was used here without ever being imported.
        messages = torch.zeros_like(x)
        messages.index_add_(0, end, e * x[start])
        messages.index_add_(0, start, e * x[end])

        h = self.node_network(torch.cat([messages, x], dim=1))
        # Shortcut connect the inputs, then add the residual connection
        return x + torch.cat([h, input_x], dim=-1)

    def forward(self, x, edge_index, batch_size=None, n_id=None, offset=None, count=None):
        """Return one raw (no output activation) score per edge, shape ``[E, 1]``.

        Args:
            x: node inputs, presumably ``[N, spatial_channels + cell_channels]``.
                On a GAS mini-batch the first ``batch_size`` rows are the
                in-batch nodes.
            edge_index: ``[2, E]`` (start, end) ids, or the ``adj_t`` SparseTensor
                yielded by ``SubgraphLoader``. With ``adj_t``, edges are scored in
                its storage order.
            batch_size, n_id, offset, count: GAS mini-batch bookkeeping.
                NOTE(review): PyGAS' ``ScalableGNN.__call__`` appears to pass these
                positionally, which the old two-argument signature could not
                accept. Leave them as ``None`` for full-graph inference.
        """
        start, end = self._edge_endpoints(edge_index)
        input_x = x

        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([self.input_network(x), input_x], dim=-1)

        for i in range(self.hparams["n_graph_iters"]):
            x = self.message_passing(x, input_x, start, end)
            # Write in-batch states to / read halo states from history i. There is
            # one history per iteration except the last.
            if i < len(self.histories):
                x = self.push_and_pull(self.histories[i], x, batch_size, n_id, offset, count)

        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_network(edge_inputs)

### Model

In [75]:
# Model hyper-parameters from a plain YAML mapping. safe_load restricts parsing to
# standard YAML tags, which is all this config needs. FullLoader also resolves
# python/* tags and has had code-execution CVEs in older PyYAML releases.
with open("example_gnn.yaml") as f:
    model_config = yaml.safe_load(f)

In [76]:
# Build the scalable edge classifier over the full (batched) graph and move it
# to the compute device.
model = ScalableAGNN(
    num_nodes=data.num_nodes,
    hparams=model_config,
)
model = model.to(device)