# Pipeline Testing

In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed so that edits to the imported source files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
import importlib
import os, sys
from time import time as tt

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
from numpy.random import choice, shuffle
from pytorch_lightning.loggers import WandbLogger

# The project imports below are resolved through the parent directory.
sys.path.append('..')
from LightningModules.Embedding.layerless_embedding import LayerlessEmbedding, EmbeddingInferenceCallback
from LightningModules.Embedding.utils import get_best_run, build_edges, res, graph_intersection
from LightningModules.Filter.utils import stringlist_to_classes
from LightningModules.Filter.vanilla_filter import VanillaFilter, FilterInferenceCallback
from LightningModules.Processing.feature_construction import FeatureStore

# Preprocessing

## Data Loading

In [6]:
# Read the preprocessing (feature store) settings into `config`.
config_path = "LightningModules/Processing/prepare_feature_store.yaml"
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [7]:
# Build the preprocessing stage from the feature-store configuration loaded above.
preprocess_dm = FeatureStore(config)

In [8]:
# Run the preprocessing. The recorded output below shows events being prepared
# one after another; that run was interrupted by hand after a few events.
preprocess_dm.prepare_data()

Loading detector...
Detector loaded.
Writing outputs to /global/cscratch1/sd/danieltm/ExaTrkX/trackml/feature_store_endcaps
/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000
Preparing 1000
Layerless truth graph built for /global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000 with size (2, 123429)
Cell features for 1000
Loading event /global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000 with a 0 pT cut
/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001001
Preparing 1001
Layerless truth graph built for /global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001001 with size (2, 91386)
Cell features for 1001
Loading event /global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001001 with a 0 pT cut
/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001002
Preparing 1002
Layerless truth graph built for /global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001002 with size (2, 1289

KeyboardInterrupt: 

In [11]:
# Load one preprocessed event (file "1000") from the feature-store output directory.
data = torch.load("/global/cscratch1/sd/danieltm/ExaTrkX/trackml/feature_store_endcaps/1000")

In [13]:
# Display the loaded event: its attributes and their shapes.
data

Data(cell_data=[103305, 9], event_file=/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000, hid=[103305], layerless_true_edges=[2, 123429], layers=[103305], pid=[103305], x=[103305, 3])

# Embedding

## Model Loading

In [3]:
# Read the embedding training settings into `config` (replaces the preprocessing config).
config_path = "LightningModules/Embedding/train_embedding.yaml"
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [4]:
# Instantiate the embedding stage from the configuration loaded above.
model = LayerlessEmbedding(config)

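Optionally, set up the Weights & Biases logger. The trainer cells below pass `wandb_logger`, so run this cell first or change their `logger` argument. This cell also defines `wandb_dir`, which the checkpoint lookup further down needs.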

In [None]:
# Weights & Biases logger for the embedding study; run data is stored under wandb_dir.
wandb_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/wandb_data"
wandb_logger = WandbLogger(
    project="EmbeddingStudy",
    group="LayerlessEndcaps",
    log_model=True,
    save_dir=wandb_dir,
)

In [6]:
# Single-GPU trainer whose callbacks are the classes named in the config's "callbacks" entry.
embedding_callbacks = stringlist_to_classes(config["callbacks"])
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    gpus=1,
    logger=wandb_logger,
    callbacks=embedding_callbacks,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [7]:
# Train the embedding model. No dataloaders are passed here, so they are
# expected to come from the model itself.
trainer.fit(model)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.10.2
[34m[1mwandb[0m: Run data is saved locally in /global/cscratch1/sd/danieltm/ExaTrkX/wandb_data/wandb/run-20201013_111443-hd6lqvip
[34m[1mwandb[0m: Syncing run [33mworthy-frog-58[0m


Set SLURM handle signals.

  | Name      | Type       | Params
-----------------------------------------
0 | layers    | ModuleList | 1 M   
1 | emb_layer | Linear     | 4 K   
2 | norm      | LayerNorm  | 1 K   
3 | act       | Tanh       | 0     





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





KeyboardInterrupt: 

## Callback Testing

Add any data-manipulation callbacks to the callback list. For example, `EmbeddingInferenceCallback` automatically builds the training, validation and testing sets for the next stage of the pipeline once training has finished.

In [29]:
# Callbacks handed to the trainer constructed in the next cell.
callback_list = [EmbeddingInferenceCallback()]

In [6]:
# Same trainer set-up as before, but with the explicit callback list instead of
# the callbacks named in the config.
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    gpus=1,
    logger=wandb_logger,
    callbacks=callback_list,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


## Model Load and Test

In [5]:
# Identifier of the embedding training run logged above (it appears in the
# W&B run directory name "run-20201013_111443-hd6lqvip").
run_label = "hd6lqvip"

In [6]:
# Path of the checkpoint to evaluate; it is passed to torch.load and
# load_from_checkpoint below.
# NOTE(review): presumably get_best_run locates the best checkpoint of the run
# inside wandb_dir — confirm in LightningModules/Embedding/utils.py.
best_run_path = get_best_run(run_label,wandb_dir)

In [8]:
# Raw checkpoint dictionary; the cells below only read its "hyper_parameters"
# entry. map_location="cpu" keeps the stored tensors off the GPU (the model
# itself is loaded separately) and lets this cell run on a CPU-only machine,
# in line with the CPU fallback used when choosing `device`.
chkpnt = torch.load(best_run_path, map_location="cpu")

In [9]:
# load_from_checkpoint is a classmethod, so call it on the class. Calling it on
# the existing `model` object only works while that name still refers to a
# LayerlessEmbedding (the Filter section rebinds `model` to a VanillaFilter).
model = LayerlessEmbedding.load_from_checkpoint(best_run_path)

In [14]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU,
# then move the restored model onto that device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

In [32]:
# Edge-construction efficiency and purity of the embedding, accumulated over
# the whole validation set.
model.eval()

cluster_total_positive = 0
cluster_total_true = 0
cluster_total_true_positive = 0

with torch.no_grad():
    for i, batch in enumerate(model.val_dataloader()):
        data = batch.to(device)

        # When 'ci' appears in the checkpoint's regime hyperparameter, the
        # cell features are concatenated to x before embedding.
        uses_cell_info = 'ci' in chkpnt["hyper_parameters"]['regime']
        if uses_cell_info:
            model_input = torch.cat([data.cell_data, data.x], axis=-1)
        else:
            model_input = data.x
        spatial = model(model_input)

        # Candidate edges built from the embedded hits.
        e_spatial = build_edges(spatial, 1.7, 500, res)

        # Truth edges in both directions: the stored edge list followed by
        # the same list with source and target rows swapped.
        true_edges = batch.layerless_true_edges
        reversed_edges = torch.stack([true_edges[1], true_edges[0]], axis=1).T
        e_bidir = torch.cat([true_edges.to(device), reversed_edges.to(device)], axis=-1)

        e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)

        # Cluster performance (truth count is doubled to match the bidirectional list)
        cluster_true = 2 * len(true_edges[0])
        cluster_true_positive = y_cluster.sum()
        cluster_positive = len(e_spatial[0])

        cluster_total_true_positive += cluster_true_positive
        cluster_total_positive += cluster_positive
        cluster_total_true += cluster_true

        if i % 5 == 0:
            print(i, "validated")

# max(..., 1) guards against division by zero when nothing was counted.
cluster_eff = cluster_total_true_positive / max(cluster_total_true, 1)
cluster_pur = cluster_total_true_positive / max(cluster_total_positive, 1)
print("Eff:", cluster_eff, "Pur:", cluster_pur)

0 validated
5 validated
10 validated
15 validated
20 validated
25 validated
30 validated
35 validated
40 validated
45 validated
Eff: 0.9750146118157859 Pur: 0.013130086934014518


## Build Filter Set

In [38]:
# Inspect the x features of the last batch left over from the validation loop.
batch.x

tensor([[ 0.0315,  0.2140,  0.3112],
        [ 0.0324, -0.6945,  0.1345],
        [ 0.1155,  0.0766, -0.1548],
        ...,
        [ 0.5474,  0.5934,  1.5045],
        [ 0.8606, -0.7056,  2.9445],
        [ 0.8391,  0.6329,  2.9445]])

In [39]:
# Inspect the dataset object behind the model's training dataloader.
model.train_dataloader().dataset

<torch.utils.data.dataset.Subset at 0x2aab7eece050>

In [55]:
# Build the input set for the filter stage: embed every training event, build
# candidate edges, label them against the truth graph and save one file per event.
# (np, choice, shuffle and tt are imported in the notebook's top import cell.)

save_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/embedding_processed/0_pt_cut_endcaps/train"
# train=True subsamples fake edges down to ratio:1 fake:true; with train=False
# every candidate edge is kept.
# NOTE(review): train is False although save_dir ends in "train" — confirm this is intended.
train, ratio = False, 8

# Make sure the output directory exists before the first save.
os.makedirs(save_dir, exist_ok=True)

model.eval()
with torch.no_grad():
    for i, batch in enumerate(model.train_dataloader().dataset):
        tic = tt()

        # The output file is named after the last four characters of the event path.
        event_path = os.path.join(save_dir, batch.event_file[-4:])
        if os.path.exists(event_path):
            print(i, "already built")
            continue

        data = batch.to(device)
        if 'ci' in chkpnt["hyper_parameters"]['regime']:
            spatial = model(torch.cat([data.cell_data, data.x], axis=-1))
        else:
            spatial = model(data.x)
        e_spatial = build_edges(spatial, 1.7, 500, res)
        e_bidir = torch.cat([batch.layerless_true_edges.to(device),
                             torch.stack([batch.layerless_true_edges[1], batch.layerless_true_edges[0]], axis=1).T.to(device)], axis=-1)
        e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)

        # Remove duplicate edges by distance from vertex: keep an edge only
        # when its first hit has the smaller R_dist, so each pair survives in
        # a single direction.
        R_dist = torch.sqrt(batch.x[:,0]**2 + batch.x[:,2]**2)
        e_spatial = e_spatial[:, (R_dist[e_spatial[0]] < R_dist[e_spatial[1]])]

        # Recompute the truth labels for the reduced edge list.
        e_spatial, y = graph_intersection(e_spatial, e_bidir)

        # Re-introduce random direction, to avoid training bias
        random_flip = torch.randint(2, (e_spatial.shape[1],)).bool()
        e_spatial[0, random_flip], e_spatial[1, random_flip] = e_spatial[1, random_flip], e_spatial[0, random_flip]

        batch.embedding = spatial.cpu().detach()

        if train and (ratio != 0): # Sample only ratio:1 fake:true edges, to keep trainset manageable

            num_true = y.sum()
            fake_indices = choice(np.where(~y)[0], int(num_true*ratio), replace=True)
            true_indices = np.where(y)[0]
            combined_indices = np.concatenate([true_indices, fake_indices])
            shuffle(combined_indices)

            batch.e_radius = e_spatial[:,combined_indices].cpu()
            batch.y = torch.from_numpy(y[combined_indices]).float()

        else:
            batch.e_radius = e_spatial.cpu()
            batch.y = torch.from_numpy(y).float()

        # Write to a temporary file and rename it into place. The skip check
        # above only tests for existence, so a save interrupted half-way must
        # never leave a truncated file under the final name.
        tmp_path = event_path + ".tmp"
        try:
            with open(tmp_path, 'wb') as pickle_file:
                torch.save(batch, pickle_file)
            os.replace(tmp_path, event_path)
        except BaseException:
            # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        print(i, "saved in time", tt()-tic, "with efficiency", (batch.y.sum()/batch.layerless_true_edges.shape[1]).item(), "and purity", (batch.y.sum()/batch.e_radius.shape[1]).item())

0 already built
1 already built
2 already built
3 already built
4 already built
5 already built
6 already built
7 already built
8 already built
9 already built
10 already built
11 already built
12 already built
13 already built
14 already built
15 already built
16 already built
17 already built
18 already built
19 already built
20 already built
21 already built
22 already built
23 already built
24 already built
25 already built
26 already built
27 already built
28 already built
29 already built
30 already built
31 already built
32 already built
33 already built
34 already built
35 already built
36 already built
37 already built
38 already built
39 already built
40 already built
41 already built
42 already built
43 already built
44 already built
45 already built
46 already built
47 already built
48 already built
49 already built
50 already built
51 already built
52 already built
53 already built
54 already built
55 already built
56 already built
57 already built
58 already built
59 save

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/global/homes/d/danieltm/.local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-55-89672923ace6>", line 29, in <module>
    e_spatial, y = graph_intersection(e_spatial, e_bidir)
  File "/global/u2/d/danieltm/ExaTrkX/Tracking-ML-Exa.TrkX/src/Pipelines/Examples/LightningModules/Embedding/utils.py", line 14, in graph_intersection
    l1 = pred_graph.cpu().numpy()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/global/homes/d/danieltm/.local/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2044, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  

KeyboardInterrupt: 

# Filter

## Model Loading

In [20]:
# Read the filter training settings into `config` (replaces the embedding config).
config_path = "LightningModules/Filter/train_filter.yaml"
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [4]:
# Instantiate the filter stage from the configuration loaded above.
# Note that this rebinds `model`, which previously held the embedding model.
model = VanillaFilter(config)

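Optionally, set up the Weights & Biases logger. The trainer cell below passes `wandb_logger`, so run this cell first or change its `logger` argument.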

In [None]:
# Weights & Biases logger for the filtering study; run data is stored under wandb_dir.
wandb_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/wandb_data"
wandb_logger = WandbLogger(
    project="FilteringStudy",
    group="LayerlessEndcaps",
    log_model=True,
    save_dir=wandb_dir,
)

## Callback Testing

In [6]:
# Single-GPU trainer for the filter stage, with the callbacks named in the filter config.
filter_callbacks = stringlist_to_classes(config["callbacks"])
trainer = pl.Trainer(
    max_epochs=config['max_epochs'],
    gpus=1,
    logger=wandb_logger,
    callbacks=filter_callbacks,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Train the filter model. No dataloaders are passed here, so they are
# expected to come from the model itself.
trainer.fit(model)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.10.2
[34m[1mwandb[0m: Run data is saved locally in /global/cscratch1/sd/danieltm/ExaTrkX/wandb_data/wandb/run-20201013_160852-4uovj4x7
[34m[1mwandb[0m: Syncing run [33mglorious-bird-26[0m


Set SLURM handle signals.

  | Name         | Type        | Params
---------------------------------------------
0 | input_layer  | Linear      | 12 K  
1 | layers       | ModuleList  | 525 K 
2 | output_layer | Linear      | 513   
3 | layernorm    | LayerNorm   | 1 K   
4 | batchnorm    | BatchNorm1d | 1 K   
5 | act          | Tanh        | 0     





HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…