# Self-Pruning Graph Neural Network

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# Add the parent directory to the import path (presumably where the
# `LightningModules` package imported further down lives -- confirm).
sys.path.append("..")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Load the hyperparameter dictionary consumed by the model classes below.
# NOTE(review): FullLoader can construct more than plain YAML scalars/containers;
# if this config only holds plain values, `yaml.safe_load(f)` is the safer choice.
with open("selfpruning_gnn.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

### Dataset

In [8]:
%%time
# Load the train/val/test datasets onto the model (see GNNBase.setup).
# NOTE(review): `model` is only instantiated in the "Self-pruning" section further
# down, so this cell depends on out-of-order execution and fails on a fresh
# Restart & Run All.
model.setup(stage="fit")

CPU times: user 19.9 s, sys: 2.22 s, total: 22.2 s
Wall time: 16.4 s


In [9]:
# Take the first training graph for inspection.
sample = model.trainset[0]

In [63]:
# Sum of the edge labels divided by the number of columns in `modulewise_true_edges`
# (presumably the fraction of true edges present in the graph -- confirm).
sample.y.sum() / sample.modulewise_true_edges.shape[1]

tensor(0.9523)

In [64]:
# Edge list of the sample graph; the recorded shape below is [2, num_edges].
edges = sample.edge_index

In [65]:
# Per-node `pid` values (presumably particle IDs -- confirm).
pid = sample.pid

In [66]:
# Number of nodes with a strictly positive pid
# (presumably pid == 0 marks noise hits -- confirm).
pid[pid > 0].shape

torch.Size([154102])

In [67]:
# Total number of nodes in the sample.
pid.shape

torch.Size([341982])

In [68]:
# Shape of the edge list.
edges.shape

torch.Size([2, 457686])

In [69]:
# Number of edges whose two endpoints carry the same pid.
(sample.pid[edges[0]] == sample.pid[edges[1]]).sum()

tensor(136479)

### Memory Test

In [8]:
%%time
# Load the datasets again for the memory test.
# NOTE(review): as in the "Dataset" section, `model` must already exist in the
# kernel; it is only created further down the notebook.
model.setup(stage="fit")

CPU times: user 26.2 s, sys: 2.62 s, total: 28.8 s
Wall time: 18.6 s


In [10]:
# Move the first training graph onto the selected device.
sample = model.trainset[0].to(device)

In [11]:
# Move the model onto the same device as the sample.
model = model.to(device)

In [12]:
# Reset the CUDA peak-memory counter so the next cell reports only this forward pass.
torch.cuda.reset_peak_memory_stats()
# Edge truth: an edge is true when both endpoints share the same pid.
y = sample.pid[sample.edge_index[0]] == sample.pid[sample.edge_index[1]]
# `sample` was already moved to `device` in an earlier cell, so these `.to(device)`
# calls are redundant but harmless.
output = model(sample.x.to(device), sample.edge_index.to(device), y.to(device))

In [13]:
# Peak allocated CUDA memory since the reset above. Dividing bytes by 1024**3
# gives GiB, although the printed label reads "Gb".
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

1.9535369873046875 Gb


### Train GNN

In [6]:
# Train for up to 10 epochs on one GPU, logging to Weights & Biases.
# NOTE(review): relies on a `model` already present in the kernel; the recorded
# run below was interrupted manually (KeyboardInterrupt).
logger = WandbLogger(project="ITk_1GeV_GNN", group="InitialTest")
trainer = Trainer(gpus=1, max_epochs=10, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name          | Type        | Params
----------------------------------------------
0 | input_network | Sequential  | 2.4 K 
1 | edge_network  | EdgeNetwork | 4.4 K 
2 | node_network  | NodeNetwork | 4.3 K 
----------------------------------------------
11.2 K    Trainable params
0         Non-trainable params
11.2 K    Total params
0.045     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  eff = torch.tensor(edge_true_positive / edge_true)
  pur = torch.tensor(edge_true_positive / edge_positive)
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


KeyboardInterrupt: 

## Self-pruning

In [4]:
from LightningModules.GNN.utils import load_dataset, random_edge_slice_v2, make_mlp
from LightningModules.GNN.Models.agnn import EdgeNetwork, NodeNetwork
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch.utils.checkpoint import checkpoint

In [5]:
class GNNBase(LightningModule):
    """Shared Lightning scaffolding for the edge-classification GNNs in this notebook.

    Provides dataset loading, dataloaders, optimiser/scheduler configuration,
    learning-rate warm-up, and the training / validation / test steps.

    Subclasses must implement ``forward(x, edge_index, y)`` and return a sequence
    of edge-logit tensors, one per graph iteration. Training concatenates every
    entry into a single loss; evaluation scores only the last entry.

    Hyperparameters read by this class: ``input_dir``, ``datatype_names``,
    ``datatype_split``, ``pt_min``, ``noise``, ``lr``, ``patience``, ``factor``,
    ``regime``, ``n_graph_iters``, ``edge_cut``, and optionally ``weight`` and
    ``warmup``.
    """

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN training regimes.

        Args:
            hparams: Mapping of hyperparameters, stored via ``save_hyperparameters``.
        """
        super().__init__()
        # Assign hyperparameters
        self.save_hyperparameters(hparams)
        self.hparams["posted_alert"] = False

    def setup(self, stage=None):
        """Load the train/val/test datasets listed in ``datatype_names``.

        Handles any leading subset of [train, val, test], assuming that ordering;
        splits that are not listed are passed to ``load_dataset`` as ``None``.

        Args:
            stage: Lightning stage name; not used by this implementation.
        """
        input_dirs = [None, None, None]
        input_dirs[: len(self.hparams["datatype_names"])] = [
            os.path.join(self.hparams["input_dir"], datatype)
            for datatype in self.hparams["datatype_names"]
        ]
        self.trainset, self.valset, self.testset = [
            load_dataset(
                input_dir,
                self.hparams["datatype_split"][i],
                self.hparams["pt_min"],
                self.hparams["noise"],
            )
            for i, input_dir in enumerate(input_dirs)
        ]

    @staticmethod
    def _make_loader(dataset):
        """Wrap ``dataset`` in a single-graph DataLoader, or return None if it is None."""
        if dataset is None:
            return None
        return DataLoader(dataset, batch_size=1, num_workers=1)

    def train_dataloader(self):
        """Return the training DataLoader, or None when no training set was loaded."""
        return self._make_loader(self.trainset)

    def val_dataloader(self):
        """Return the validation DataLoader, or None when no validation set was loaded."""
        return self._make_loader(self.valset)

    def test_dataloader(self):
        """Return the test DataLoader, or None when no test set was loaded."""
        return self._make_loader(self.testset)

    def configure_optimizers(self):
        """Build an AdamW optimiser (amsgrad) with an epoch-wise StepLR schedule.

        Returns:
            A ``(optimizers, schedulers)`` pair of single-element lists.
        """
        optimizer = [
            torch.optim.AdamW(
                self.parameters(),
                lr=(self.hparams["lr"]),
                betas=(0.9, 0.999),
                eps=1e-08,
                amsgrad=True,
            )
        ]
        scheduler = [
            {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer[0],
                    step_size=self.hparams["patience"],
                    gamma=self.hparams["factor"],
                ),
                "interval": "epoch",
                "frequency": 1,
            }
        ]
        return optimizer, scheduler

    def _pos_weight(self, batch):
        """Return the positive-class weight passed to the BCE loss.

        Uses the ``weight`` hyperparameter when present; otherwise the ratio of
        false to true entries in ``batch.y_pid``.
        """
        if "weight" in self.hparams:
            return torch.tensor(
                self.hparams["weight"], dtype=torch.float, device=self.device
            )
        # Plain tensor arithmetic: wrapping an existing tensor in torch.tensor()
        # copy-constructs it and emits a UserWarning.
        n_true = batch.y_pid.sum().float()
        n_false = (~batch.y_pid.bool()).sum().float()
        return n_false / n_true

    def _edge_truth(self, batch):
        """Return edge labels: same-pid endpoints in the "pid" regime, else ``batch.y``."""
        if "pid" in self.hparams["regime"]:
            return batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]
        return batch.y

    def _node_features(self, batch):
        """Return node inputs: ``cell_data`` is prepended to ``x`` in the "ci" regime."""
        if "ci" in self.hparams["regime"]:
            return torch.cat([batch.cell_data, batch.x], dim=-1)
        return batch.x

    def training_step(self, batch, batch_idx):
        """Compute the iteration-weighted BCE loss over every graph iteration's output.

        The logits from all iterations are concatenated and compared against the
        truth tiled once per iteration. Iteration ``k`` (1-based) is weighted by
        ``k / n_graph_iters``, so later iterations contribute more.

        Returns:
            The scalar training loss (also logged as ``train_loss``).
        """
        n_iters = self.hparams["n_graph_iters"]
        pos_weight = self._pos_weight(batch)
        y = self._edge_truth(batch)

        # forward() needs the truth for its precut logging and returns a list of
        # logits, so every regime goes through the same three-argument call.
        output = torch.cat(self(self._node_features(batch), batch.edge_index, y))

        # Build the ramp in floating point so it never relies on integer division.
        iteration_weights = torch.repeat_interleave(
            torch.arange(1, n_iters + 1, dtype=torch.float, device=output.device)
            / n_iters,
            len(y),
        )

        # NOTE(review): the "weighting" regime (batch.weights) is not applied to the
        # training loss here, only in shared_evaluation -- confirm this is intended.
        loss = F.binary_cross_entropy_with_logits(
            output,
            y.float().repeat(n_iters),
            weight=iteration_weights,
            pos_weight=pos_weight,
        )

        self.log("train_loss", loss)

        return loss

    def shared_evaluation(self, batch, batch_idx):
        """Score the final iteration's output; shared by validation and test.

        Logs ``val_loss``, edge efficiency (``eff``), purity (``pur``) and the
        current learning rate.

        Returns:
            Dict with the ``loss`` tensor and numpy arrays ``preds`` (boolean cut
            decisions at ``edge_cut``) and ``truth``.
        """
        regime = self.hparams["regime"]
        pos_weight = self._pos_weight(batch)
        y = self._edge_truth(batch)

        # Only the last graph iteration is evaluated.
        output = self(self._node_features(batch), batch.edge_index, y)[-1]

        truth = y.float() if "pid" in regime else y
        manual_weights = batch.weights if "weighting" in regime else None

        loss = F.binary_cross_entropy_with_logits(
            output, truth.float(), weight=manual_weights, pos_weight=pos_weight
        )

        # Edge filter performance
        preds = torch.sigmoid(output) > self.hparams["edge_cut"]
        edge_positive = preds.sum().float()

        edge_true = truth.sum().float()
        edge_true_positive = (truth.bool() & preds).sum().float()

        # These are already tensors; no torch.tensor() re-wrapping needed.
        eff = edge_true_positive / edge_true
        pur = edge_true_positive / edge_positive

        current_lr = self.optimizers().param_groups[0]["lr"]
        self.log_dict(
            {"val_loss": loss, "eff": eff, "pur": pur, "current_lr": current_lr}
        )

        return {
            "loss": loss,
            "preds": preds.cpu().numpy(),
            "truth": truth.cpu().numpy(),
        }

    def validation_step(self, batch, batch_idx):
        """Run the shared evaluation and return only the loss."""
        outputs = self.shared_evaluation(batch, batch_idx)

        return outputs["loss"]

    def test_step(self, batch, batch_idx):
        """Run the shared evaluation and return the full result dict."""
        outputs = self.shared_evaluation(batch, batch_idx)

        return outputs

    def test_step_end(self, output_results):
        """Print the step results and pass them through for ``test_epoch_end``."""
        print("Step:", output_results)

        # Returning None here would drop the per-step results before the epoch hook.
        return output_results

    def test_epoch_end(self, outputs):
        """Print the collected test-step outputs."""
        print("Epoch:", outputs)

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure=None,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Apply linear learning-rate warm-up, then step the optimiser.

        While ``global_step`` is below the ``warmup`` hyperparameter, the learning
        rate of every param group is set to ``lr * (global_step + 1) / warmup``.
        Warm-up is skipped when ``warmup`` is missing or None.
        """
        # warm up lr
        warmup = self.hparams.get("warmup")
        if (warmup is not None) and (self.trainer.global_step < warmup):
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / warmup)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams["lr"]

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

In [6]:
class PrunedAGNN(GNNBase):
    """Iterative GNN with a separate edge-output head for every graph iteration.

    Each iteration applies the shared edge network (as sigmoid attention) and the
    shared node network with a residual connection, then scores the edges with
    that iteration's own output head.
    """

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN training regimes.

        Args:
            hparams: Mapping with ``in_channels``, ``hidden``, ``nb_node_layer``,
                ``nb_edge_layer``, ``hidden_activation``, ``layernorm`` and
                ``n_graph_iters`` (plus whatever ``GNNBase`` reads).
        """
        super().__init__(hparams)

        # Setup input network
        self.input_network = make_mlp(
            hparams["in_channels"],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        # Setup the edge network
        self.edge_network = EdgeNetwork(
            hparams["hidden"],
            hparams["hidden"],
            hparams["nb_edge_layer"],
            hparams["hidden_activation"],
            hparams["layernorm"],
        )

        # One output head per graph iteration. They must live in a ModuleList:
        # modules held in a plain Python list are not registered, so their
        # parameters never reach self.parameters() (and hence the optimiser),
        # are not saved in checkpoints, and are not moved by .to(device).
        self.output_network = torch.nn.ModuleList(
            [
                EdgeNetwork(
                    hparams["hidden"],
                    hparams["hidden"],
                    hparams["nb_edge_layer"],
                    hparams["hidden_activation"],
                    hparams["layernorm"],
                )
                for _ in range(hparams["n_graph_iters"])
            ]
        )

        # Setup the node layers
        self.node_network = NodeNetwork(
            hparams["hidden"],
            hparams["hidden"],
            hparams["nb_node_layer"],
            hparams["hidden_activation"],
            hparams["layernorm"],
        )

    def forward(self, x, edge_index, y):
        """Run ``n_graph_iters`` message-passing iterations and score edges after each.

        Args:
            x: Node input features.
            edge_index: Edge list indexing into ``x``.
            y: Edge truth labels, used only to log per-iteration precut
                efficiency and purity.

        Returns:
            List with one edge-logit tensor per graph iteration.
        """
        x = self.input_network(x)

        # Bitwise & below requires boolean operands; labels may arrive as float.
        truth = y.bool()
        output_list = []

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            x_initial = x

            # Apply edge network (checkpointed: activations are recomputed in backward)
            e_attention = torch.sigmoid(checkpoint(self.edge_network, x, edge_index))

            # Apply node network
            x = checkpoint(self.node_network, x, e_attention, edge_index)

            # Residual connection
            x = x_initial + x

            e_output = checkpoint(self.output_network[i], x, edge_index)
            output_list.append(e_output)

            # The precut mask is only used for monitoring; no edges are removed here.
            precut_mask = torch.sigmoid(e_output) > self.hparams["precut"]
            true_positive = (precut_mask & truth).sum()

            self.log_dict(
                {
                    f"precut_eff_{i}": true_positive / truth.sum(),
                    f"precut_pur_{i}": true_positive / precut_mask.sum(),
                }
            )

        return output_list

In [6]:
class MultistepAGNN(GNNBase):
    """Iterative GNN that scores edges after every iteration with one shared output head.

    Each iteration applies the shared edge network (as sigmoid attention) and the
    shared node network with a residual connection, then scores the edges with
    the single ``output_network``.
    """

    def __init__(self, hparams):
        """Initialise the Lightning Module that can scan over different GNN training regimes.

        Args:
            hparams: Mapping with ``in_channels``, ``hidden``, ``nb_node_layer``,
                ``nb_edge_layer``, ``hidden_activation`` and ``layernorm``
                (plus whatever ``GNNBase`` reads).
        """
        super().__init__(hparams)

        # Setup input network
        self.input_network = make_mlp(
            hparams["in_channels"],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        # Setup the edge network
        self.edge_network = EdgeNetwork(
            hparams["hidden"],
            hparams["hidden"],
            hparams["nb_edge_layer"],
            hparams["hidden_activation"],
            hparams["layernorm"],
        )

        # Single output head reused at every graph iteration
        self.output_network = EdgeNetwork(
            hparams["hidden"],
            hparams["hidden"],
            hparams["nb_edge_layer"],
            hparams["hidden_activation"],
            hparams["layernorm"],
        )

        # Setup the node layers
        self.node_network = NodeNetwork(
            hparams["hidden"],
            hparams["hidden"],
            hparams["nb_node_layer"],
            hparams["hidden_activation"],
            hparams["layernorm"],
        )

    def forward(self, x, edge_index, y):
        """Run ``n_graph_iters`` message-passing iterations and score edges after each.

        Args:
            x: Node input features.
            edge_index: Edge list indexing into ``x``.
            y: Edge truth labels, used only to log per-iteration precut
                efficiency and purity.

        Returns:
            List with one edge-logit tensor per graph iteration.
        """
        x = self.input_network(x)

        # Bitwise & below requires boolean operands; labels may arrive as float.
        truth = y.bool()
        output_list = []

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            x_initial = x

            # Apply edge network (checkpointed: activations are recomputed in backward)
            e_attention = torch.sigmoid(checkpoint(self.edge_network, x, edge_index))

            # Apply node network
            x = checkpoint(self.node_network, x, e_attention, edge_index)

            # Residual connection
            x = x_initial + x

            e_output = checkpoint(self.output_network, x, edge_index)
            output_list.append(e_output)

            # The precut mask is only used for monitoring; no edges are removed here.
            precut_mask = torch.sigmoid(e_output) > self.hparams["precut"]
            true_positive = (precut_mask & truth).sum()

            self.log_dict(
                {
                    f"precut_eff_{i}": true_positive / truth.sum(),
                    f"precut_pur_{i}": true_positive / precut_mask.sum(),
                }
            )

        return output_list

In [7]:
# Instantiate the shared-output-head variant with the hyperparameters loaded from YAML.
model = MultistepAGNN(hparams)

In [8]:
# Train for up to 20 epochs on one GPU, logging to Weights & Biases.
# The recorded run below was stopped manually (KeyboardInterrupt).
logger = WandbLogger(project="Selfpruning_ITk_1GeV_GNN", group="RecurrentMultistep")
trainer = Trainer(gpus=1, max_epochs=20, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name           | Type        | Params
-----------------------------------------------
0 | input_network  | Sequential  | 9.0 K 
1 | edge_network   | EdgeNetwork | 17.0 K
2 | output_network | EdgeNetwork | 17.0 K
3 | node_network   | NodeNetwork | 16.8 K
-----------------------------------------------
59.8 K    Trainable params
0         Non-trainable params
59.8 K    Total params
0.239     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
