# Testing Construction of Triplets from GNN output

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import scipy.sparse.csgraph as scigraph
import scipy.sparse as sp
import cupy as cp

import wandb

import warnings
# Silence every Python warning for the rest of the session. This keeps cell
# output short but also hides deprecation notices.
warnings.filterwarnings('ignore')
# Make the directory three levels up importable — presumably the repository
# root that holds the LightningModules package imported below; TODO confirm.
sys.path.append('../../../')
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Roadmap

1. Load in a good 0.5GeV model
2. Run some initial statistics (eff, pur)
3. Segment builder
4. Get some segment statistics (tracking eff, pur)

## Infrastructure

- GraphScore(Model, graph) --> graph (with scores)
- SegmentBuild(graph) --> labelled graph


## Load in Model

In [2]:
from LightningModules.GNN.Models.interaction_gnn import InteractionGNN
from LightningModules.GNN.Models.checkpoint_pyramid import CheckpointedPyramid

In [3]:
# Trained GNN checkpoint on the shared filesystem.
checkpoint_path = "/global/cfs/cdirs/m3443/data/lightning_models/lightning_checkpoints/ITk_1GeVSignal_Barrel_GNN/11eo3iqk/checkpoints/epoch=21-step=10449.ckpt"
# map_location keeps the load working when `device` fell back to "cpu":
# without it, tensors saved from a GPU cannot be deserialized on a CPU-only node.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Rebuild the model from the same checkpoint and place it on the chosen device.
model = InteractionGNN.load_from_checkpoint(checkpoint_path, map_location=device).to(device)

### Load in Data

In [4]:
# Override the split stored in the checkpoint's hyperparameters before the
# datasets are built. Presumably [train, val, test] event counts — TODO confirm.
# NOTE(review): this writes to the private `_hparams` attribute.
model._hparams["datatype_split"]=[50, 50, 10]

In [5]:
# Build the datasets for the "fit" stage; the log below shows three rounds of
# event loading. Later cells read `model.valset`.
model.setup(stage="fit")

Setting up dataset
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!
Events processed!


In [None]:
# Display the validation set (presumably populated by model.setup above).
model.valset

## Load/Score Graphs

### Load Pre-scored

In [17]:
# Directory of pre-scored graphs stored as .npz files.
graph_dir = "/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y"
# Build full paths from graph_dir itself. The previous `segment_dir` is never
# defined in this notebook, so the cell raised NameError on a fresh kernel.
all_files = [os.path.join(graph_dir, file) for file in os.listdir(graph_dir)]

In [33]:
# Show the collected file paths (order is whatever os.listdir returned).
all_files

['/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0200.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0178.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0069.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0101.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0188.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0079.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0111.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0002.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0089.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_processed/0.5GeV_barrel_y/0121.npz',
 '/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/gnn_pro

In [18]:
# Open the first pre-scored graph archive.
# NOTE: `graph` is rebound to a validation-set event further down the notebook.
graph = np.load(all_files[0])

In [23]:
# Stack the "senders" and "receivers" arrays along a new leading axis, giving
# one row per edge endpoint (assumes both arrays are 1-D and equal length).
edge_list = np.stack([graph["senders"], graph["receivers"]])

### Score

In [6]:
def graphScore(model, events=None, output_dir=None):
    """Run the model over each event and attach per-edge scores in place.

    Args:
        model: Trained model exposing `valset`, `get_input_data(batch)` and a
            forward call `model(input_data, edge_index)`.
        events: Iterable of graph events to score. Defaults to `model.valset`.
        output_dir: Reserved for saving the scored events; currently unused.

    Returns:
        None. Every scored event gains a `scores` attribute holding one
        sigmoid score per edge in `event.edge_index`.

    Note:
        Relies on the notebook-level `device` variable.
    """
    if events is None:
        events = model.valset

    # TODO: If output_dir is given, save the scored events there; otherwise return events
    with torch.no_grad():
        # Iterate the requested events. The earlier version always looped over
        # model.valset, silently ignoring the `events` argument.
        for event in events:
            batch = event.to(device)

            input_data = model.get_input_data(batch)
            # Feed every edge in both directions to the model.
            input_edges = torch.cat([batch.edge_index, batch.edge_index.flip(0)], dim=-1)
            # Keep only the outputs for the original (non-flipped) edges.
            scores = torch.sigmoid(model(input_data, input_edges))[:event.edge_index.shape[1]]

            event.scores = scores

In [7]:
%%time
# Score with the default event set (events=None falls back to model.valset).
graphScore(model)

CPU times: user 3.21 s, sys: 29.6 ms, total: 3.24 s
Wall time: 3.31 s


## Load/Build Triplet

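First we cut the `edge_index` based on an `edge_cut`. Then we convert the passing edges into sparse node-by-edge matrices. Then we multiply those matrices to find every pair of edges where one ends at the hit where the other starts; each such pair is one triplet.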

### Simple Test

In [8]:
# Toy graph with four one-way edges: 0->1, 1->2, 1->3, 3->4.
passing_edges = torch.tensor([[0, 1, 1, 3], [1, 2, 3, 4]])
# Append the reversed copy of every edge (rows swapped) so each edge is
# present in both directions.
passing_edges = torch.cat((passing_edges, passing_edges[[1, 0]]), dim=1)

In [9]:
# Inspect the bidirectional toy edge list.
passing_edges

tensor([[0, 1, 1, 3, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 1, 1, 3]])

In [10]:
%%time
# Hand the edge list to CuPy and cast it to float32.
# NOTE(review): these values are node indices; float32 represents integers
# exactly only up to 2**24 — confirm graphs stay below that size.
passing_edges_cp = cp.asarray(passing_edges).astype('float32')
# Number of columns, i.e. edges counted in both directions.
num_edges = passing_edges.shape[1]

CPU times: user 150 ms, sys: 27.1 ms, total: 177 ms
Wall time: 619 ms


In [11]:
%%time
# One unit entry per edge. Allocated directly instead of building a Python
# list of length num_edges and converting it.
e_ones = cp.ones(num_edges, dtype='float32')
# Column index of each edge: 0 .. num_edges-1.
e_arange = cp.arange(num_edges).astype('float32')
# Largest node index; sets the number of rows of the sparse matrices.
e_max = passing_edges.max().item()

CPU times: user 3.48 ms, sys: 0 ns, total: 3.48 ms
Wall time: 8.01 ms


In [12]:
%%time
# Node-by-edge matrices with a single 1 per column (shape: e_max+1 x num_edges).
# "in" places the 1 at the row given by the first row of the edge list,
# "out" at the row given by its second row.
passing_edges_csr_in = cp.sparse.coo_matrix((e_ones, (passing_edges_cp[0], e_arange)), shape=(e_max+1, num_edges)).tocsr()
passing_edges_csr_out = cp.sparse.coo_matrix((e_ones, (passing_edges_cp[1], e_arange)), shape=(e_max+1, num_edges)).tocsr()

CPU times: user 222 ms, sys: 96.4 ms, total: 318 ms
Wall time: 1.04 s


In [13]:
%%time
# (num_edges x nodes) times (nodes x num_edges): entry (a, b) should be non-zero
# when the second endpoint of edge a equals the first endpoint of edge b,
# assuming `*` is the sparse matrix product here — TODO confirm for this CuPy version.
triplet_edges = passing_edges_csr_out.T * passing_edges_csr_in

CPU times: user 5.21 ms, sys: 560 µs, total: 5.77 ms
Wall time: 7.45 ms


In [14]:
# Switch to COO format so the non-zero (row, col) pairs can be read via .row / .col.
triplet_edges = triplet_edges.tocoo()

In [15]:
# Stack the row/col indices into a 2-row array and hand it back to PyTorch on `device`.
triplet_edges = torch.as_tensor(cp.stack([triplet_edges.row, triplet_edges.col]), device=device)

In [16]:
# Edge k and edge k + num_edges/2 are the two directions of the same edge, so
# both map to id k. Integer division keeps arange in integer arithmetic; the
# previous float argument produced float32 values, inexact above 2**24.
directed_map = torch.arange(num_edges // 2).repeat(2).int()

In [17]:
# directed_map lives on the CPU while triplet_edges is on `device`; current
# PyTorch refuses to index a CPU tensor with a CUDA index, so move the index first.
directed_map[triplet_edges.long().cpu()].shape

torch.Size([2, 16])

In [18]:
# Replace each bidirectional edge id with its single-direction id.
# The index is moved to the CPU because directed_map is a CPU tensor; indexing
# it with a CUDA tensor fails on current PyTorch.
directed_triplet_edges = directed_map[triplet_edges.long().cpu()]

In [19]:
# Drop pairs that join an edge to itself (an edge followed by its own reverse).
trimmed_triplet_edges = directed_triplet_edges[:, directed_triplet_edges[0] != directed_triplet_edges[1]]

In [20]:
# Inspect the pairs left after removing self-loops; each pair still appears in both orders.
trimmed_triplet_edges

tensor([[0, 0, 2, 1, 1, 2, 2, 3],
        [1, 2, 3, 2, 0, 1, 0, 2]], dtype=torch.int32)

In [21]:
# Keep only pairs with first id < second id, so each unordered pair appears once.
trimmed_triplet_edges[:, trimmed_triplet_edges[0] < trimmed_triplet_edges[1]]

tensor([[0, 0, 2, 1],
        [1, 2, 3, 2]], dtype=torch.int32)

### Build Method

In [23]:
def buildTriplets(graph, edge_cut=0.5, directed=True):
    """Build the triplet graph (pairs of connected edges) from a scored graph.

    Edges with `graph.scores > edge_cut` are kept and duplicated in both
    directions. Two passing edges are linked when the end of one is the start
    of the other, so each link corresponds to a chain of three hits.

    Args:
        graph: Event with `edge_index` (2 x E), `scores` (E), `y` (E) and
            `y_pid` (E) tensors.
        edge_cut: Score threshold an edge must exceed to be kept.
        directed: If True, ids refer to the P passing edges and every
            unordered pair appears once (first id < second id). If False, ids
            refer to the 2P bidirectional edges and all ordered pairs are kept.

    Returns:
        (triplet_edges, triplet_y, triplet_y_pid): a 2 x T index tensor and two
        length-T boolean tensors that are True where both linked edges are
        true in `y` / `y_pid` respectively.

    Note:
        Relies on the notebook-level `device` variable and on CuPy. Also prints
        the edge efficiency and purity of the cut.
    """
    score_mask = graph.scores > edge_cut

    # Columns [0, E) are the original edges, columns [E, 2E) their reversed copies.
    undir_graph = torch.cat([graph.edge_index, graph.edge_index.flip(0)], dim=-1)

    # apply cut
    passing_edges = undir_graph[:, score_mask.repeat(2)]
    passing_y_truth = graph.y[score_mask]
    passing_y_pid_truth = graph.y_pid[score_mask]

    num_true_passing = passing_y_truth.sum()
    print("Eff:", num_true_passing / graph.y.sum(), "Pur:", num_true_passing / score_mask.sum())

    num_edges = passing_edges.shape[1]
    if num_edges == 0:
        # Nothing survived the cut; .max() below would raise on an empty tensor.
        no_triplets = torch.empty((2, 0), dtype=torch.long, device=passing_edges.device)
        return no_triplets, passing_y_truth[no_triplets].all(0), passing_y_pid_truth[no_triplets].all(0)

    # convert to cupy
    # NOTE(review): node indices are cast to float32, exact only up to 2**24.
    passing_edges_cp = cp.asarray(passing_edges).astype('float32')

    # make some utility objects
    e_ones = cp.ones(num_edges, dtype='float32')
    e_arange = cp.arange(num_edges).astype('float32')
    e_max = passing_edges.max().item()

    # build sparse node-by-edge arrays: one unit entry per edge, at its first
    # endpoint ("in") or its second endpoint ("out")
    passing_edges_csr_in = cp.sparse.coo_matrix((e_ones, (passing_edges_cp[0], e_arange)), shape=(e_max+1, num_edges)).tocsr()
    passing_edges_csr_out = cp.sparse.coo_matrix((e_ones, (passing_edges_cp[1], e_arange)), shape=(e_max+1, num_edges)).tocsr()

    # convert to triplets: (a, b) is set when edge a ends where edge b starts
    triplet_edges = passing_edges_csr_out.T * passing_edges_csr_in
    triplet_edges = triplet_edges.tocoo()

    # convert back to pytorch
    undirected_triplet_edges = torch.as_tensor(cp.stack([triplet_edges.row, triplet_edges.col]), device=device)

    # convert back to a single-direction edge list
    if directed:
        # Bidirectional ids k and k + num_edges/2 are the same passing edge k.
        # Integer division avoids the float32 arange of the earlier version.
        directed_map = torch.arange(num_edges // 2).repeat(2)
        # directed_map is a CPU tensor, so the index must be on the CPU too;
        # current PyTorch rejects a CUDA index on a CPU tensor.
        directed_triplet_edges = directed_map[undirected_triplet_edges.long().cpu()]
        # A strict "<" removes self-loops and the mirrored duplicate of each pair.
        directed_triplet_edges = directed_triplet_edges[:, directed_triplet_edges[0] < directed_triplet_edges[1]]

        return directed_triplet_edges, passing_y_truth[directed_triplet_edges].all(0), passing_y_pid_truth[directed_triplet_edges].all(0)

    else:
        # Ids here run over all 2P bidirectional edges, so the length-P truth
        # arrays are repeated to match; indexing them directly ran out of range.
        bidirectional_ids = undirected_triplet_edges.long()
        return undirected_triplet_edges, passing_y_truth.repeat(2)[bidirectional_ids].all(0), passing_y_pid_truth.repeat(2)[bidirectional_ids].all(0)

In [25]:
# Take the first validation event. This rebinds `graph`, which earlier held the .npz archive.
graph = model.valset[0]

In [42]:
%%time
# Build single-direction triplets with a loose cut of 0.2; the printed Eff/Pur describe that cut.
triplet_edges, triplet_y_truth, triplet_y_pid_truth = buildTriplets(graph, edge_cut=0.2, directed=True)

Eff: tensor(0.9568, device='cuda:0') Pur: tensor(0.1516, device='cuda:0')
CPU times: user 51.4 ms, sys: 924 µs, total: 52.4 ms
Wall time: 38 ms


In [43]:
# Number of triplet links found (second dimension).
triplet_edges.shape

torch.Size([2, 107308])

In [44]:
# Node features of the event (three values per node in the output below).
graph.x

tensor([[ 0.0422,  0.0689, -0.2260],
        [ 0.0398,  0.0228, -0.2445],
        [ 0.0383, -0.0310, -0.2346],
        ...,
        [ 1.0001, -0.0086,  1.1771],
        [ 0.9999, -0.0090,  1.2255],
        [ 0.9995, -0.0095,  1.2255]], device='cuda:0')

In [45]:
# Inspect the triplet index pairs; the first row is always smaller than the second.
triplet_edges

tensor([[    0,     0,     0,  ..., 35451, 35453, 35456],
        [    1,     2, 17342,  ..., 35454, 35454, 35457]])

In [56]:
# Short aliases for the cells below: node features, cell data, edge list.
x = graph.x
c = graph.cell_data
e = graph.edge_index

In [47]:
# Per-edge midpoint: average of the features of the edge's two endpoint nodes.
x_mean = (x[e[0]] + x[e[1]])/2

In [49]:
# Per-edge absolute difference between the features of the two endpoint nodes.
x_diff = (x[e[0]] - x[e[1]]).abs()

In [50]:
# Inspect the per-edge absolute differences.
x_diff

tensor([[1.0668e-01, 1.0719e-02, 6.3056e-01],
        [1.4100e-01, 5.1535e-03, 8.0265e-01],
        [1.3625e-01, 9.0008e-03, 8.0471e-01],
        ...,
        [1.5900e-02, 1.1801e-03, 3.6003e-02],
        [1.5670e-02, 1.3558e-03, 4.7430e-02],
        [1.6271e-02, 7.1364e-04, 2.1740e-03]], device='cuda:0')

In [51]:
# One feature vector per edge: midpoint followed by absolute difference.
# Computed over every edge in graph.edge_index, not only those passing a score cut.
triplet_x = torch.cat([x_mean, x_diff], dim=-1)

In [53]:
# Expect (number of edges, 6): three midpoint and three difference features.
triplet_x.shape

torch.Size([310705, 6])