# Segment Stitching

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import scipy.sparse.csgraph as scigraph
import scipy.sparse as sp
from torch_scatter import scatter
sys.path.append('../../../')

import wandb
from LightningModules.Segmenting.utils.segmentation_utils import labelSegments, sparse_score_segments

import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

## Roadmap

1. Load in a good 0.5GeV model
2. Run some initial statistics (eff, pur)
3. Segment builder
4. Get some segment statistics (tracking eff, pur)

## Infrastructure

- GraphScore(Model, graph) --> graph (with scores)
- SegmentBuild(graph) --> labelled graph


## Load in Model

In [2]:
from LightningModules.Segmenting.Models.checkpoint_pyramid import CheckpointedPyramid
from LightningModules.Segmenting.Models.interaction_gnn import InteractionGNN

In [3]:
with open("../configs/default_Segment.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)

In [4]:
model = CheckpointedPyramid(hparams)

### Load

In [3]:
checkpoint_dir = "/global/cfs/cdirs/m3443/data/lightning_models/lightning_checkpoints/ITk_Stitcher_Testing/qzhsnlc5/checkpoints/epoch=38-step=19499.ckpt"

In [4]:
model = CheckpointedPyramid.load_from_checkpoint(checkpoint_dir)

### Load in Data

In [5]:
model._hparams["datatype_split"]=[50, 50, 10]

In [6]:
model.setup(stage="fit")

Setting up dataset
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!
Events processed!


In [8]:
graph = model.trainset[0]

In [15]:
(graph.pid_pairs[0] == graph.pid_pairs[1]).sum() / graph.pid_pairs.shape[1]

tensor(0.0002)

## Load/Build Segments

In [137]:
graph = model.valset[0]

In [138]:
cut = 0.8

positive = graph.scores > cut
tp = positive & graph.y.bool()
labelGraph(graph, edge_cut=cut)
print("Pur:", tp.sum() / positive.sum(), "Eff:", tp.sum() / graph.y.sum())

AttributeError: 'Data' object has no attribute 'scores'

## Input Graph

Examine the truth graph first:

Let's see what score the input graph could get from ground truth

In [22]:
labels = labelSegments(graph.edge_index[:, graph.y.bool()], graph)

In [51]:
def sparse_score_segments(labels, pids, signal_pids):
    
    unique_pids, new_pids = pids.unique(return_inverse=True)
    signal_segments_pids, unique_signal_segments_pids = get_unique_signal_segments(labels, new_pids, signal_pids)
    
    
    iou, segment_count, pid_count = get_jaccard_matrix(labels, new_pids, signal_segments_pids, unique_signal_segments_pids)
    
    sparse_segment_count = sp.coo_matrix((segment_count[unique_signal_segments_pids[0]].cpu(), unique_signal_segments_pids.cpu().numpy())).tocsr()
    sparse_pid_count = sp.coo_matrix((pid_count[unique_signal_segments_pids[1]].cpu(), unique_signal_segments_pids.cpu().numpy())).tocsr()
    
    segment_pur = iou.multiply(sparse_segment_count).sum() / segment_count[unique_signal_segments_pids[0]].sum()
    segment_eff = iou.multiply(sparse_pid_count).sum() / pid_count[unique_signal_segments_pids[1]].sum()
    
    segment_f1 = 2 * segment_pur * segment_eff / (segment_pur + segment_eff)
    
    return segment_pur, segment_eff, segment_f1

def get_jaccard_matrix(labels, pids, signal_segments_pids, unique_signal_segments_pids):
    
    sparse_intersection = sp.coo_matrix((np.ones(signal_segments_pids.shape[1]), signal_segments_pids.cpu().numpy())).tocsr()
    
    segment_count = labels.unique(return_counts=True)[1]
    pid_count = pids.unique(return_counts=True)[1]
    
    union_counts = segment_count[unique_signal_segments_pids[0]] + pid_count[unique_signal_segments_pids[1]]
    sparse_sum = sp.coo_matrix((union_counts.cpu(), unique_signal_segments_pids.cpu().numpy())).tocsr()
    sparse_union = sparse_sum - sparse_intersection
    sparse_union.data = 1 / sparse_union.data
    iou = sparse_intersection.multiply(sparse_union)
    
    return iou, segment_count, pid_count

def get_unique_signal_segments(labels, pids, signal_pids):
    
    labels_unique, labels_inverse, labels_counts = labels.unique(return_counts=True, return_inverse=True)
    
    segments_pids = torch.stack([labels, pids])
    is_signal = torch.isin(pids, pids[signal_pids]) & (labels_counts[labels_inverse] >= 3)
    
    signal_segments_pids = segments_pids[:, is_signal]
    unique_signal_segments_pids = signal_segments_pids.unique(dim=1)
    
    return signal_segments_pids, unique_signal_segments_pids

In [35]:
sparse_score_segments(graph.labels, graph.pid, graph.signal_true_edges)

(tensor(0.5578, device='cuda:0'),
 tensor(0.5798, device='cuda:0'),
 tensor(0.5686, device='cuda:0'))

## Label All Events

In [26]:
%%time

for event in model.valset:
    event.labels = labelSegments(event.edge_index[:, event.y.bool()], event)  

for event in model.trainset:
    event.labels = labelSegments(event.edge_index[:, event.y.bool()], event)  

CPU times: user 763 ms, sys: 28.3 ms, total: 791 ms
Wall time: 396 ms


In [37]:
graph = model.trainset[0]

In [30]:
sparse_score_segments(graph.labels, graph.pid, graph.signal_true_edges)

(tensor(0.9382), tensor(0.8978), tensor(0.9176))

## GNN Testing

In [7]:
graph = model.trainset[0].to(device)
graph

Data(cell_data=[96465, 11], edge_index=[2, 261315], event_file="/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000010001", hid=[96465], label_pairs=[2, 276396], labels=[96465], long_mask=[96465], modulewise_true_edges=[2, 37882], nhits=[96465], pid=[96465], pid_pairs=[2, 276396], primary=[96465], pt=[96465], signal_true_edges=[2, 5305], x=[96465, 3], y=[261315], y_pid=[261315])

In [10]:
with torch.no_grad():
    input_data = model.get_input_data(graph)
output = model.to(device)(input_data, graph.edge_index, graph.labels, graph.label_pairs)

In [13]:
graph.pid_pairs

tensor([[15350000114, 15350000114, 15350000114,  ..., 15430000390,
         15430000390,        1147],
        [15350000113, 15460000079, 16300000922,  ...,        1147,
                1358,        1358]], device='cuda:0')

In [11]:
output

tensor([[-0.4305],
        [-0.4687],
        [-0.4315],
        ...,
        [-0.2801],
        [-0.2987],
        [-0.2221]], device='cuda:0', grad_fn=<CheckpointFunctionBackward>)

# Training

In [5]:
logger = WandbLogger(project=hparams["project"], group="InitialTest", save_dir=hparams["artifacts"])
trainer = Trainer(gpus=1, max_epochs=hparams["max_epochs"], logger=logger)#, precision=16)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Setting up dataset
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!
Events processed!
Loading events
Events loaded!


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Events processed!


[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type       | Params
-----------------------------------------------
0 | node_encoder    | Sequential | 135 K 
1 | edge_network    | Sequential | 173 K 
2 | node_network    | Sequential | 329 K 
3 | segment_network | Sequential | 198 K 
4 | output_network  | Sequential | 264 K 
-----------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.407     Total estimated model params size (MB)


Epoch 0:  91%|█████████ | 500/550 [07:37<00:45,  1.09it/s, loss=0.0346, v_num=to9m]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/50 [00:00<?, ?it/s][A
Epoch 0:  91%|█████████▏| 502/550 [07:37<00:43,  1.10it/s, loss=0.0346, v_num=to9m]
Validating:   4%|▍         | 2/50 [00:00<00:13,  3.57it/s][A
Epoch 0:  92%|█████████▏| 504/550 [07:38<00:41,  1.10it/s, loss=0.0346, v_num=to9m]
Validating:   8%|▊         | 4/50 [00:01<00:12,  3.61it/s][A
Epoch 0:  92%|█████████▏| 506/550 [07:38<00:39,  1.10it/s, loss=0.0346, v_num=to9m]
Validating:  12%|█▏        | 6/50 [00:01<00:11,  3.93it/s][A
Epoch 0:  92%|█████████▏| 508/550 [07:39<00:37,  1.11it/s, loss=0.0346, v_num=to9m]
Validating:  16%|█▌        | 8/50 [00:02<00:10,  3.96it/s][A
Epoch 0:  93%|█████████▎| 510/550 [07:39<00:36,  1.11it/s, loss=0.0346, v_num=to9m]
Validating:  20%|██        | 10/50 [00:02<00:10,  3.83it/s][A
Epoch 0:  93%|█████████▎| 512/550 [07:40<00:34,  1.11it/s, loss=0.0346, v_num=to9m]
Validating:  

In [61]:
graph = model.valset[1].to(device)

In [62]:
with torch.no_grad():
    input_data = model.get_input_data(graph)
    input_graph = torch.cat([graph.edge_index, graph.edge_index.flip(0)], dim=-1)
    output = torch.sigmoid(model.to(device)(input_data, input_graph, graph.labels, graph.label_pairs))

In [63]:
graph

Data(cell_data=[95751, 11], edge_index=[2, 255367], event_file="/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000010010", hid=[95751], label_pairs=[2, 314028], labels=[95751], long_mask=[95751], modulewise_true_edges=[2, 38013], nhits=[95751], pid=[95751], pid_pairs=[2, 314028], primary=[95751], pt=[95751], signal_true_edges=[2, 5413], x=[95751, 3], y=[255367], y_pid=[255367])

In [78]:
preds = output.squeeze() > 0.95
truth = graph.pid_pairs[0] == graph.pid_pairs[1]

In [79]:
# Signal true & signal tp
true = truth.sum().float()
positive = preds.sum().float()
true_positive = (truth.bool() & preds).sum().float()   

# Eff, pur, auc
eff = true_positive / true
pur = true_positive / positive

print(eff, pur)

tensor(0.9600, device='cuda:0') tensor(0.9600, device='cuda:0')


In [89]:
model = model.to(device)

In [129]:
true_positive, true, positive = 0, 0, 0
unstitched_scores = []
stitched_scores = []

with torch.no_grad():
    for graph in model.valset:
        input_data = model.get_input_data(graph).to(device)
        input_graph = torch.cat([graph.edge_index, graph.edge_index.flip(0)], dim=-1).to(device)
        output = torch.sigmoid(model(input_data, input_graph, graph.labels.to(device), graph.label_pairs.to(device)))
        
        preds = output.squeeze() > 0.98
        truth = graph.pid_pairs[0] == graph.pid_pairs[1]
        
        true += truth.sum().float().cpu()
        positive += preds.sum().float().cpu()
        true_positive += (truth.bool().cpu() & preds.cpu()).sum().float()
        
        # Eff, pur, auc
        eff = true_positive / true
        pur = true_positive / positive

        print(eff, pur)
        
        labels = labelSegments(graph.edge_index[:, graph.y.bool()], len(graph.x))
        unstitched_scores.append(sparse_score_segments(labels, graph.pid, graph.signal_true_edges)[2].cpu())
        
        connected_segments = labelSegments(graph.label_pairs[:, preds], len(labels))
        relabelled = connected_segments[labels]
        stitched_scores.append(sparse_score_segments(relabelled, graph.pid, graph.signal_true_edges)[2].cpu())

tensor(0.7963) tensor(0.8776)
tensor(0.8462) tensor(0.9362)
tensor(0.8497) tensor(0.9091)
tensor(0.8390) tensor(0.8912)
tensor(0.8167) tensor(0.8991)
tensor(0.8111) tensor(0.8975)
tensor(0.8211) tensor(0.8862)
tensor(0.8045) tensor(0.8931)
tensor(0.8036) tensor(0.8949)
tensor(0.8076) tensor(0.8783)
tensor(0.8087) tensor(0.8742)
tensor(0.8058) tensor(0.8784)
tensor(0.8078) tensor(0.8774)
tensor(0.8110) tensor(0.8655)
tensor(0.8131) tensor(0.8629)
tensor(0.8042) tensor(0.8586)
tensor(0.8078) tensor(0.8591)
tensor(0.8029) tensor(0.8586)
tensor(0.7988) tensor(0.8584)
tensor(0.8026) tensor(0.8605)
tensor(0.8025) tensor(0.8600)
tensor(0.8029) tensor(0.8624)
tensor(0.8090) tensor(0.8598)
tensor(0.8071) tensor(0.8628)
tensor(0.8084) tensor(0.8620)
tensor(0.8081) tensor(0.8643)
tensor(0.8053) tensor(0.8607)
tensor(0.8018) tensor(0.8650)
tensor(0.8026) tensor(0.8611)
tensor(0.8015) tensor(0.8588)
tensor(0.8003) tensor(0.8596)
tensor(0.8005) tensor(0.8610)
tensor(0.7985) tensor(0.8613)
tensor(0.7

In [119]:
torch.mean(torch.stack(unstitched_scores)), torch.std(torch.stack(unstitched_scores))

(tensor(0.9230), tensor(0.0083))

In [120]:
torch.mean(torch.stack(stitched_scores)), torch.std(torch.stack(stitched_scores))

(tensor(0.9464), tensor(0.0074))

In [117]:
torch.mean(torch.stack(stitched_scores)), torch.std(torch.stack(stitched_scores))

(tensor(0.9498), tensor(0.0073))

### Test score improvement

In [123]:
labels = labelSegments(graph.edge_index[:, graph.y.bool()], len(graph.x))

In [124]:
sparse_score_segments(labels, graph.pid, graph.signal_true_edges)

(tensor(0.9495), tensor(0.9190), tensor(0.9340))

In [125]:
connected_segments = labelSegments(graph.label_pairs[:, preds], len(labels))
relabelled = connected_segments[labels]

In [126]:
sparse_score_segments(relabelled, graph.pid, graph.signal_true_edges)

(tensor(0.9652), tensor(0.9483), tensor(0.9567))

In [121]:
labels = labelSegments(graph.signal_true_edges, len(graph.x))

In [122]:
sparse_score_segments(labels, graph.pid, graph.signal_true_edges)

(tensor(1.), tensor(1.), tensor(1.))

In [11]:
graph.label_pairs.max()

tensor(89228, device='cuda:0')

In [64]:
graph.label_pairs[:, preds].T

tensor([[84855,   218],
        [33524,   245],
        [47413,   245],
        [  245, 47443],
        [  259, 33524],
        [47413,   259],
        [33673,   313],
        [  319, 33673],
        [33571,   324],
        [17454,   569],
        [17488,   569],
        [  569, 33350],
        [85390,   771],
        [  806, 84855],
        [85390,   806],
        [85390,   893],
        [84855,  1038],
        [85390,  1038],
        [62597,  1128],
        [48952,  1152],
        [ 1152, 49707],
        [35142,  1205],
        [49707,  1205],
        [35142,  1347],
        [62597,  1514],
        [75919,  1514],
        [ 1517, 62597],
        [ 1534, 49815],
        [ 1726, 76438],
        [ 1750, 86008],
        [49707,  1872],
        [64103,  2283],
        [ 2322, 35973],
        [ 2336, 86008],
        [ 2352, 64041],
        [ 2485, 20824],
        [64639,  2485],
        [ 2561, 64103],
        [ 2561, 64639],
        [ 2821, 64940],
        [77398,  2873],
        [ 3069, 

In [59]:
labels.shape

torch.Size([104365])

In [142]:
for new_label_pair in graph.label_pairs[:, preds].T:
    print(new_label_pair)
    print((labels == new_label_pair[1]).sum())
    
    labels[labels == new_label_pair[1]] = new_label_pair[0]

tensor([47413,   245], device='cuda:0')
tensor(4, device='cuda:0')
tensor([33524,   259], device='cuda:0')
tensor(3, device='cuda:0')
tensor([  313, 33673], device='cuda:0')
tensor(5, device='cuda:0')
tensor([  319, 33673], device='cuda:0')
tensor(0, device='cuda:0')
tensor([33571,   324], device='cuda:0')
tensor(3, device='cuda:0')
tensor([33350,   569], device='cuda:0')
tensor(3, device='cuda:0')
tensor([84855,   806], device='cuda:0')
tensor(8, device='cuda:0')
tensor([ 1152, 49707], device='cuda:0')
tensor(5, device='cuda:0')
tensor([35142,  1205], device='cuda:0')
tensor(3, device='cuda:0')
tensor([62597,  1514], device='cuda:0')
tensor(6, device='cuda:0')
tensor([75919,  1514], device='cuda:0')
tensor(0, device='cuda:0')
tensor([ 1534, 49815], device='cuda:0')
tensor(4, device='cuda:0')
tensor([86008,  1750], device='cuda:0')
tensor(7, device='cuda:0')
tensor([49707,  1872], device='cuda:0')
tensor(5, device='cuda:0')
tensor([64041,  2352], device='cuda:0')
tensor(5, device='cuda

In [116]:
labels.unique().shape

torch.Size([98778])

In [63]:
(labels == 245).sum()

tensor(4, device='cuda:0')