In [3]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import pandas as pd
import scipy.sparse as sps
import yaml
from itertools import chain, product, combinations
import torch

from time import time as tt
from tqdm import tqdm
sys.path.append("../../../")
from gnn4itk_cf.stages.data_reading.models.trackml_utils import *

from gnn4itk_cf.stages.data_reading.data_reading_stage import EventReader
from gnn4itk_cf.stages.data_reading.models.trackml_reader import TrackMLReader

from gnn4itk_cf.stages.graph_construction.models.metric_learning import MetricLearning
from gnn4itk_cf.stages.edge_classifier.models.filter import Filter
from gnn4itk_cf.stages.edge_classifier import InteractionGNN

from gnn4itk_cf.stages.graph_construction.utils import handle_weighting
from gnn4itk_cf.stages.graph_construction.models.utils import graph_intersection, build_edges
from gnn4itk_cf.stages.graph_construction.utils import *

from gnn4itk_cf.stages.track_building import utils 
from torch_geometric.utils import to_scipy_sparse_matrix

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Build each pipeline stage from its own training config. Every config keeps a
# distinct name so that re-running only part of this cell cannot hand a model
# (or the evaluation below) another stage's settings.
# NOTE(review): the models are constructed from configs only; no checkpoint
# loading is visible in this notebook -- confirm trained weights are restored
# somewhere (constructor / setup) before trusting the metrics printed further down.
# NOTE(review): yaml.FullLoader is kept for the training configs to preserve the
# existing parsing behaviour; only load files you trust with it.
with open("examples/Example_3/metric_learning_train.yaml", "r") as f:
    ml_config = yaml.load(f, Loader=yaml.FullLoader)
model_ML = MetricLearning(ml_config)

with open("examples/Example_3/filter_train.yaml", "r") as f:
    filter_config = yaml.load(f, Loader=yaml.FullLoader)
model_filter = Filter(filter_config)

with open("examples/Example_3/gnn_train.yaml", "r") as f:
    gnn_config = yaml.load(f, Loader=yaml.FullLoader)
model_gnn = InteractionGNN(gnn_config)

# `config` is the name the evaluation cells read: it holds the track-building
# evaluation settings. Opened through a context manager so the file handle is
# closed deterministically (the previous bare open() leaked it).
with open("examples/Example_3/track_building_eval.yaml", "r") as f:
    config = yaml.safe_load(f)

In [7]:
# Prepare all three stages for inference by calling setup() with the "predict"
# stage; the recorded cell output reports how many events were loaded.
model_ML.setup(stage="predict")
# Only the metric-learning stage is asked for dataloaders; the inference loop
# further down iterates over dataloaders[2].
# NOTE(review): presumably index 2 is the test split -- confirm against
# predict_dataloader().
dataloaders = model_ML.predict_dataloader()
model_filter.setup('predict')
model_gnn.setup('predict')

Loaded 80 training events, 10 validation events and 10 testing events




Defining figures of merit
Defining figures of merit


In [8]:
def load_reconstruction_df(graph):
    """Build a per-hit table of reconstructed track labels and true particle ids.

    Args:
        graph: Event graph exposing ``hit_id``, ``labels``, ``track_edges``
            (two rows of hit indices) and ``particle_id`` (one entry per
            track edge).

    Returns:
        pandas.DataFrame with the columns ``hit_id``, ``track_id`` and
        ``particle_id``, one row per hit. Hits that are not an endpoint of
        any track edge keep particle id 0.
    """
    n_hits = graph.hit_id.shape[0]
    hit_pids = torch.zeros(n_hits, dtype=torch.int64)

    # Stamp each edge's particle id onto both of its endpoints
    # (first the source row, then the target row).
    for endpoints in (graph.track_edges[0], graph.track_edges[1]):
        hit_pids[endpoints] = graph.particle_id

    columns = {
        "hit_id": graph.hit_id,
        "track_id": graph.labels,
        "particle_id": hit_pids,
    }
    return pd.DataFrame(columns)

def load_particles_df(graph):
    """Build a table with one row per distinct particle id.

    Args:
        graph: Event graph exposing ``particle_id`` and ``pt`` sequences of
            equal length.

    Returns:
        pandas.DataFrame with the columns ``particle_id`` and ``pt``, with
        repeated ``particle_id`` rows dropped.
    """
    per_entry = {"particle_id": graph.particle_id, "pt": graph.pt}

    # The inputs repeat a particle once per entry; collapse to unique ids.
    return pd.DataFrame(per_entry).drop_duplicates(subset=["particle_id"])

def get_matching_df(reconstruction_df, min_track_length=1, min_particle_length=1):
    """Count shared hits for every (track candidate, particle) pair.

    Args:
        reconstruction_df: DataFrame with the columns ``hit_id``, ``track_id``
            and ``particle_id`` (one row per hit).
        min_track_length: Minimum number of hits a track candidate needs to be
            flagged ``is_matchable``.
        min_particle_length: Minimum number of distinct hits a particle needs
            to be flagged ``is_reconstructable``.

    Returns:
        DataFrame with one row per (track_id, particle_id) pair and the columns
        ``track_id``, ``particle_id``, ``n_shared``, ``n_reco_hits``,
        ``n_true_hits``, ``is_matchable`` and ``is_reconstructable``.
    """
    # Hits per track candidate. rename_axis + reset_index(name=...) names both
    # columns explicitly, so the result does not depend on how the installed
    # pandas version labels value_counts() output (this changed in pandas 2.0
    # and broke the previous rename-based approach).
    candidate_lengths = (
        reconstruction_df.track_id.value_counts(sort=False)
        .rename_axis("track_id")
        .reset_index(name="n_reco_hits")
    )

    # Hits per true particle, counting every hit_id only once.
    particle_lengths = (
        reconstruction_df.drop_duplicates(subset=["hit_id"])
        .particle_id.value_counts(sort=False)
        .rename_axis("particle_id")
        .reset_index(name="n_true_hits")
    )

    # Hits shared by each (track candidate, particle) pair.
    spacepoint_matching = (
        reconstruction_df.groupby(["track_id", "particle_id"])
        .size()
        .reset_index(name="n_shared")
    )

    spacepoint_matching = spacepoint_matching.merge(candidate_lengths, on=["track_id"], how="left")
    spacepoint_matching = spacepoint_matching.merge(particle_lengths, on=["particle_id"], how="left")

    # Flag (rather than drop) pairs whose track or particle is too short, so
    # callers can decide how to treat them.
    spacepoint_matching["is_matchable"] = spacepoint_matching.n_reco_hits >= min_track_length
    spacepoint_matching["is_reconstructable"] = spacepoint_matching.n_true_hits >= min_particle_length

    return spacepoint_matching

def calculate_matching_fraction(spacepoint_matching_df):
    """Attach purity and efficiency ratios to a matching table.

    Args:
        spacepoint_matching_df: DataFrame with the columns ``n_shared``,
            ``n_reco_hits`` and ``n_true_hits``.

    Returns:
        A new DataFrame (the input is not modified) with two extra columns:
        ``purity_reco`` = n_shared / n_reco_hits and
        ``eff_true`` = n_shared / n_true_hits.
    """
    shared_hits = spacepoint_matching_df.n_shared
    purity = np.true_divide(shared_hits, spacepoint_matching_df.n_reco_hits)
    efficiency = np.true_divide(shared_hits, spacepoint_matching_df.n_true_hits)

    return spacepoint_matching_df.assign(purity_reco=purity, eff_true=efficiency)

def evaluate_labelled_graph(graph, matching_fraction=0.5, matching_style="ATLAS", min_track_length=1, min_particle_length=1):
    """Match the track candidates of one labelled event graph to true particles.

    Args:
        graph: Labelled event graph; it is read through
            ``load_reconstruction_df`` and must also expose ``event_id``.
        matching_fraction: Threshold on purity / efficiency; must be >= 0.5.
        matching_style: One of ``"ATLAS"`` (purity only, used for both flags),
            ``"one_way"`` (purity decides ``is_matched``, efficiency decides
            ``is_reconstructed``) or ``"two_way"`` (both must pass for both flags).
        min_track_length: Forwarded to ``get_matching_df``.
        min_particle_length: Forwarded to ``get_matching_df``.

    Returns:
        The matching DataFrame of ``get_matching_df`` extended with
        ``event_id``, ``purity_reco``, ``eff_true``, ``is_matched`` and
        ``is_reconstructed``.

    Raises:
        ValueError: If ``matching_fraction`` is below 0.5 or ``matching_style``
            is not one of the supported styles.
    """
    if matching_fraction < 0.5:
        raise ValueError("Matching fraction must be >= 0.5")

    # Reject unknown styles up front; otherwise the result would silently lack
    # the is_matched / is_reconstructed columns that callers rely on.
    if matching_style not in ("ATLAS", "one_way", "two_way"):
        raise ValueError(
            f"Unknown matching style {matching_style!r}; expected 'ATLAS', 'one_way' or 'two_way'"
        )

    if matching_fraction == 0.5:
        # Nudge the threshold just above one half to avoid double-matched tracks.
        matching_fraction += 0.00001

    # Load the labelled graph as a per-hit reconstruction table.
    reconstruction_df = load_reconstruction_df(graph)

    # get_matching_df takes only the reconstruction table plus the two length
    # cuts; passing an extra positional frame collided with min_track_length
    # and raised a TypeError.
    matching_df = get_matching_df(
        reconstruction_df,
        min_track_length=min_track_length,
        min_particle_length=min_particle_length,
    )
    matching_df["event_id"] = int(graph.event_id)

    # Add purity_reco / eff_true.
    matching_df = calculate_matching_fraction(matching_df)

    passes_purity = matching_df.purity_reco >= matching_fraction
    passes_efficiency = matching_df.eff_true >= matching_fraction

    if matching_style == "ATLAS":
        is_matched = is_reconstructed = passes_purity
    elif matching_style == "one_way":
        is_matched = passes_purity
        is_reconstructed = passes_efficiency
    else:  # "two_way"
        is_matched = is_reconstructed = passes_purity & passes_efficiency

    matching_df["is_matched"] = is_matched
    matching_df["is_reconstructed"] = is_reconstructed

    return matching_df

In [9]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def evaluate_labelled_graphs(graphset, config):
    """Print and log tracking efficiency, fake rate and duplication rate.

    Args:
        graphset: Iterable of labelled event graphs.
        config: Mapping providing ``matching_fraction``, ``matching_style``,
            ``min_track_length`` and ``min_particle_length``.

    Returns:
        None. The counts and the three rates are printed; the rates are also
        sent to ``logging.info``. A rate whose denominator is zero is reported
        as ``nan`` instead of raising ZeroDivisionError.
    """
    # Local import so the function does not depend on `logging` having been
    # brought into the notebook namespace by one of the star imports.
    import logging

    # NOTE(review): this uses the library implementation in `utils`, not the
    # evaluate_labelled_graph defined in this notebook -- confirm that is intended.
    evaluated_events = pd.concat(
        [
            utils.evaluate_labelled_graph(
                event,
                matching_fraction=config["matching_fraction"],
                matching_style=config["matching_style"],
                min_track_length=config["min_track_length"],
                min_particle_length=config["min_particle_length"],
            )
            for event in tqdm(graphset)
        ]
    )

    particles = evaluated_events[evaluated_events["is_reconstructable"]]
    reconstructed_particles = particles[particles["is_reconstructed"] & particles["is_matchable"]]
    tracks = evaluated_events[evaluated_events["is_matchable"]]
    matched_tracks = tracks[tracks["is_matched"]]

    # Count each particle / track once per event.
    n_particles = len(particles.drop_duplicates(subset=['event_id', 'particle_id']))
    n_reconstructed_particles = len(reconstructed_particles.drop_duplicates(subset=['event_id', 'particle_id']))

    n_tracks = len(tracks.drop_duplicates(subset=['event_id', 'track_id']))
    n_matched_tracks = len(matched_tracks.drop_duplicates(subset=['event_id', 'track_id']))

    # Rows beyond the first for an already-reconstructed particle are duplicates.
    n_dup_reconstructed_particles = len(reconstructed_particles) - n_reconstructed_particles

    print(f"Number of reconstructed particles: {n_reconstructed_particles}")
    print(f"Number of particles: {n_particles}")
    print(f"Number of matched tracks: {n_matched_tracks}")
    print(f"Number of tracks: {n_tracks}")
    print(f"Number of duplicate reconstructed particles: {n_dup_reconstructed_particles}")   

    # Summary rates; any of the denominators can legitimately be zero
    # (e.g. an event in which nothing was reconstructed).
    eff = _safe_ratio(n_reconstructed_particles, n_particles)
    fake_rate = 1 - _safe_ratio(n_matched_tracks, n_tracks)
    dup_rate = _safe_ratio(n_dup_reconstructed_particles, n_reconstructed_particles)

    logging.info(f"Efficiency: {eff:.3f}")
    logging.info(f"Fake rate: {fake_rate:.3f}")
    logging.info(f"Duplication rate: {dup_rate:.3f}")
    print(f"Efficiency: {eff:.3f}")
    print(f"Fake rate: {fake_rate:.3f}")
    print(f"Duplication rate: {dup_rate:.3f}")

In [10]:
# End-to-end inference over dataloaders[2]: embed hits, build a radius graph,
# prune edges with the filter, score the survivors with the GNN, label tracks
# by connected components and print the matching metrics for every event.
device = 'cuda'

# Neighbourhood radius and neighbour cap for graph construction in embedding
# space (previously inline magic numbers).
EMBEDDING_R_MAX = 0.1
EMBEDDING_K_MAX = 10

# Mixed precision is only meaningful on CUDA. Passing enabled=False turns the
# autocast context into a no-op, so `embedding` / `out` are always assigned
# instead of being silently skipped when `device` is not 'cuda'.
use_amp = device == 'cuda'

# Move every stage to the selected device and switch it to evaluation mode so
# that layers such as dropout / batch norm behave deterministically.
model_ML = model_ML.to(device).eval()
model_filter = model_filter.to(device).eval()
model_gnn = model_gnn.to(device).eval()

for batch in dataloaders[2]:
    batch = batch.to(device)

    # 1) Metric learning: embed the hits.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        embedding = model_ML.apply_embedding(batch)

    # 2) Graph construction: connect hits that are close in embedding space.
    batch.edge_index = build_edges(
        query=embedding,
        database=embedding,
        indices=None,
        r_max=EMBEDDING_R_MAX,
        k_max=EMBEDDING_K_MAX,
        backend="FRNN",
    )

    # Orient every edge from the endpoint with smaller r^2 + z^2 to the one
    # with larger r^2 + z^2 by swapping the endpoints of edges that point the
    # other way.
    R = batch.r**2 + batch.z**2
    flip_edge_mask = R[batch.edge_index[0]] > R[batch.edge_index[1]]
    batch.edge_index[:, flip_edge_mask] = batch.edge_index[:, flip_edge_mask].flip(0)

    # 3) Filter: keep only edges scored above the filter's own edge cut.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        out = model_filter(batch)
    preds = torch.sigmoid(out)
    batch.edge_index = batch.edge_index[:, preds > model_filter.hparams['edge_cut']]

    # 4) GNN: score the surviving edges.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        out = model_gnn(batch)
    batch.scores = torch.sigmoid(out)

    edge_mask = batch.scores > model_gnn.hparams['edge_cut']

    # Number of nodes, taken from the first attribute that is available.
    if hasattr(batch, "num_nodes"):
        num_nodes = batch.num_nodes
    elif hasattr(batch, "x"):
        num_nodes = batch.x.size(0)
    elif hasattr(batch, "x_x"):
        num_nodes = batch.x_x.size(0)
    else:
        num_nodes = batch.edge_index.max().item() + 1

    # 5) Track building: every connected component of the thresholded graph
    # becomes one track candidate.
    sparse_edges = to_scipy_sparse_matrix(
        batch.edge_index[:, edge_mask], num_nodes=num_nodes
    )
    _, component_ids = sps.csgraph.connected_components(
        sparse_edges, directed=False, return_labels=True
    )
    batch.labels = torch.from_numpy(component_ids).long()

    # Record the evaluation settings on the batch, then evaluate on the CPU.
    batch.config.append(config)
    print(batch)
    evaluate_labelled_graphs([batch.to('cpu')], config)


DataBatch(cell_val=[15682], geta=[15682], weight=[15682], region=[15682], lx=[15682], lphi=[15682], module_index=[15682], x=[15682], r=[15682], gphi=[15682], hit_id=[15682], lz=[15682], z=[15682], cell_count=[15682], phi=[15682], y=[15682], ly=[15682], leta=[15682], eta=[15682], track_edges=[2, 14278], particle_id=[14278], radius=[14278], nhits=[14278], pt=[14278], config=[2], event_id=[1], num_nodes=15682, batch=[15682], ptr=[2], edge_index=[2, 111209], scores=[111209], labels=[15682])


100%|██████████| 1/1 [00:00<00:00, 17.73it/s]

Number of reconstructed particles: 1
Number of particles: 1386
Number of matched tracks: 1
Number of tracks: 98
Number of duplicate reconstructed particles: 0
Efficiency: 0.001
Fake rate: 0.990
Duplication rate: 0.000





DataBatch(cell_val=[14296], geta=[14296], weight=[14296], region=[14296], lx=[14296], lphi=[14296], module_index=[14296], x=[14296], r=[14296], gphi=[14296], hit_id=[14296], lz=[14296], z=[14296], cell_count=[14296], phi=[14296], y=[14296], ly=[14296], leta=[14296], eta=[14296], track_edges=[2, 13029], particle_id=[13029], radius=[13029], nhits=[13029], pt=[13029], config=[2], event_id=[1], num_nodes=14296, batch=[14296], ptr=[2], edge_index=[2, 102476], scores=[102476], labels=[14296])


100%|██████████| 1/1 [00:00<00:00, 34.67it/s]

Number of reconstructed particles: 2
Number of particles: 1244
Number of matched tracks: 2
Number of tracks: 98
Number of duplicate reconstructed particles: 0
Efficiency: 0.002
Fake rate: 0.980
Duplication rate: 0.000





DataBatch(cell_val=[14783], geta=[14783], weight=[14783], region=[14783], lx=[14783], lphi=[14783], module_index=[14783], x=[14783], r=[14783], gphi=[14783], hit_id=[14783], lz=[14783], z=[14783], cell_count=[14783], phi=[14783], y=[14783], ly=[14783], leta=[14783], eta=[14783], track_edges=[2, 13474], particle_id=[13474], radius=[13474], nhits=[13474], pt=[13474], config=[2], event_id=[1], num_nodes=14783, batch=[14783], ptr=[2], edge_index=[2, 106243], scores=[106243], labels=[14783])


100%|██████████| 1/1 [00:00<00:00, 35.43it/s]

Number of reconstructed particles: 2
Number of particles: 1288
Number of matched tracks: 2
Number of tracks: 102
Number of duplicate reconstructed particles: 0
Efficiency: 0.002
Fake rate: 0.980
Duplication rate: 0.000





DataBatch(cell_val=[14890], geta=[14890], weight=[14890], region=[14890], lx=[14890], lphi=[14890], module_index=[14890], x=[14890], r=[14890], gphi=[14890], hit_id=[14890], lz=[14890], z=[14890], cell_count=[14890], phi=[14890], y=[14890], ly=[14890], leta=[14890], eta=[14890], track_edges=[2, 13576], particle_id=[13576], radius=[13576], nhits=[13576], pt=[13576], config=[2], event_id=[1], num_nodes=14890, batch=[14890], ptr=[2], edge_index=[2, 106951], scores=[106951], labels=[14890])


100%|██████████| 1/1 [00:00<00:00, 32.33it/s]

Number of reconstructed particles: 1
Number of particles: 1292
Number of matched tracks: 1
Number of tracks: 93
Number of duplicate reconstructed particles: 0
Efficiency: 0.001
Fake rate: 0.989
Duplication rate: 0.000





DataBatch(cell_val=[13415], geta=[13415], weight=[13415], region=[13415], lx=[13415], lphi=[13415], module_index=[13415], x=[13415], r=[13415], gphi=[13415], hit_id=[13415], lz=[13415], z=[13415], cell_count=[13415], phi=[13415], y=[13415], ly=[13415], leta=[13415], eta=[13415], track_edges=[2, 12206], particle_id=[12206], radius=[12206], nhits=[12206], pt=[12206], config=[2], event_id=[1], num_nodes=13415, batch=[13415], ptr=[2], edge_index=[2, 94924], scores=[94924], labels=[13415])


100%|██████████| 1/1 [00:00<00:00, 36.71it/s]

Number of reconstructed particles: 3
Number of particles: 1190
Number of matched tracks: 3
Number of tracks: 92
Number of duplicate reconstructed particles: 0
Efficiency: 0.003
Fake rate: 0.967
Duplication rate: 0.000





DataBatch(cell_val=[11113], geta=[11113], weight=[11113], region=[11113], lx=[11113], lphi=[11113], module_index=[11113], x=[11113], r=[11113], gphi=[11113], hit_id=[11113], lz=[11113], z=[11113], cell_count=[11113], phi=[11113], y=[11113], ly=[11113], leta=[11113], eta=[11113], track_edges=[2, 10146], particle_id=[10146], radius=[10146], nhits=[10146], pt=[10146], config=[2], event_id=[1], num_nodes=11113, batch=[11113], ptr=[2], edge_index=[2, 78081], scores=[78081], labels=[11113])


100%|██████████| 1/1 [00:00<00:00, 38.24it/s]

Number of reconstructed particles: 5
Number of particles: 955
Number of matched tracks: 5
Number of tracks: 100
Number of duplicate reconstructed particles: 0
Efficiency: 0.005
Fake rate: 0.950
Duplication rate: 0.000





DataBatch(cell_val=[7859], geta=[7859], weight=[7859], region=[7859], lx=[7859], lphi=[7859], module_index=[7859], x=[7859], r=[7859], gphi=[7859], hit_id=[7859], lz=[7859], z=[7859], cell_count=[7859], phi=[7859], y=[7859], ly=[7859], leta=[7859], eta=[7859], track_edges=[2, 7154], particle_id=[7154], radius=[7154], nhits=[7154], pt=[7154], config=[2], event_id=[1], num_nodes=7859, batch=[7859], ptr=[2], edge_index=[2, 56255], scores=[56255], labels=[7859])


100%|██████████| 1/1 [00:00<00:00, 46.79it/s]

Number of reconstructed particles: 6
Number of particles: 692
Number of matched tracks: 6
Number of tracks: 83
Number of duplicate reconstructed particles: 0
Efficiency: 0.009
Fake rate: 0.928
Duplication rate: 0.000





DataBatch(cell_val=[12086], geta=[12086], weight=[12086], region=[12086], lx=[12086], lphi=[12086], module_index=[12086], x=[12086], r=[12086], gphi=[12086], hit_id=[12086], lz=[12086], z=[12086], cell_count=[12086], phi=[12086], y=[12086], ly=[12086], leta=[12086], eta=[12086], track_edges=[2, 11006], particle_id=[11006], radius=[11006], nhits=[11006], pt=[11006], config=[2], event_id=[1], num_nodes=12086, batch=[12086], ptr=[2], edge_index=[2, 85080], scores=[85080], labels=[12086])


100%|██████████| 1/1 [00:00<00:00, 42.13it/s]

Number of reconstructed particles: 3
Number of particles: 1069
Number of matched tracks: 3
Number of tracks: 101
Number of duplicate reconstructed particles: 0
Efficiency: 0.003
Fake rate: 0.970
Duplication rate: 0.000





DataBatch(cell_val=[14300], geta=[14300], weight=[14300], region=[14300], lx=[14300], lphi=[14300], module_index=[14300], x=[14300], r=[14300], gphi=[14300], hit_id=[14300], lz=[14300], z=[14300], cell_count=[14300], phi=[14300], y=[14300], ly=[14300], leta=[14300], eta=[14300], track_edges=[2, 13026], particle_id=[13026], radius=[13026], nhits=[13026], pt=[13026], config=[2], event_id=[1], num_nodes=14300, batch=[14300], ptr=[2], edge_index=[2, 101029], scores=[101029], labels=[14300])


100%|██████████| 1/1 [00:00<00:00, 37.28it/s]

Number of reconstructed particles: 3
Number of particles: 1248
Number of matched tracks: 3
Number of tracks: 100
Number of duplicate reconstructed particles: 0
Efficiency: 0.002
Fake rate: 0.970
Duplication rate: 0.000





DataBatch(cell_val=[11896], geta=[11896], weight=[11896], region=[11896], lx=[11896], lphi=[11896], module_index=[11896], x=[11896], r=[11896], gphi=[11896], hit_id=[11896], lz=[11896], z=[11896], cell_count=[11896], phi=[11896], y=[11896], ly=[11896], leta=[11896], eta=[11896], track_edges=[2, 10851], particle_id=[10851], radius=[10851], nhits=[10851], pt=[10851], config=[2], event_id=[1], num_nodes=11896, batch=[11896], ptr=[2], edge_index=[2, 84296], scores=[84296], labels=[11896])


100%|██████████| 1/1 [00:00<00:00, 39.68it/s]

Number of reconstructed particles: 2
Number of particles: 1030
Number of matched tracks: 2
Number of tracks: 98
Number of duplicate reconstructed particles: 0
Efficiency: 0.002
Fake rate: 0.980
Duplication rate: 0.000





In [11]:
# Attempt to export the filter stage to ONNX as "filter.onnx".
# NOTE(review): this call fails as recorded in the cell output -- `sample` was
# a `Data` object, and the exporter reports that only tuples, lists and
# Variables (plus, discouraged, dictionaries and strings) are accepted as
# inputs. `sample` is also not assigned in any cell shown in this notebook, so
# the cell depends on hidden kernel state.
# TODO(review): export through inputs of a supported type (presumably the
# individual tensors the filter reads from the batch -- confirm against
# Filter.forward) and define `sample` explicitly before this call.
export_output = torch.onnx.export(model_filter, sample,"filter.onnx")

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: Data