# Embedding Model and The Dataset Class

In this notebook we develop the next step of our GNN model. To apply either the _meta-layer_ approach or the _ParticleNet_ approach, we first need a graph structure.

Constructing the graph directly from the processed event data is not practical, given the large number of hits (nodes) per event ($\approx 10^5$). Direct methods, such as comparing distances and angles between every pair of hits, scale quadratically with the number of hits and are computationally expensive. Instead, we train a multilayer perceptron (MLP) so that hits belonging to the same track lie close together in a latent space. This lets us use an efficient nearest-neighbour search to construct the graph structure.
This intermediate ML model is referred to as the embedding model [1].

[1] : Ju, Xiangyang, et al. "Performance of a geometric deep learning pipeline for HL-LHC particle tracking." The European Physical Journal C 81 (2021): 1-14.

In [1]:
import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

from typing import Optional, Tuple


import torch
import torch.nn as nn
from torch.linalg import norm 
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor

# torch.set_default_dtype(torch.float64)

from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
    torch_sparse,
)

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, is_undirected , to_undirected, contains_self_loops , to_networkx
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph
from torch_geometric.datasets import QM9
# from torch_scatter import scatter
# from torch_cluster import knn

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# import uproot
import vector
vector.register_awkward()
import awkward as ak

from IPython.display import HTML

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.6.0
PyG version 2.6.1


In [2]:
seed = 5 
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

In [3]:
from tqdm import trange, tqdm

In [4]:
from TrackML import Preprocessing
from TrackML.Models.utils import buildMLP 

In [5]:
# set default params for matplotlib : 
plt.rcParams['font.size'] = 14
plt.rcParams['lines.linewidth'] = 2
plt.rcParams["figure.figsize"] = (10,7)

In [6]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: mps


In [7]:
from sklearn.preprocessing import StandardScaler

### 1. Define Point Cloud Class for Processed Data 
We define a point-cloud dataset class to train the embedding model with. For each event it returns the processed hit features and the particle label of every hit.

In [8]:
detector_path = '../data/detectors.csv'
dataset_path = '../data/train_100_events/'

In [9]:
class PointCloudData(data.Dataset): 
    
    # initialize the dataset class : 
    def __init__(self,dataset_path:str,detector_path:str,min_nhits=3)->None:
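        '''
        dataset_path : path to the folder containing the event files. 
        detector_path : path to the detector geometry file (detectors.csv). 
        min_nhits : minimum number of hits per particle, passed on to 
            Preprocessing.process_particle_labels. 
        '''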
        super().__init__()
        
        self.dataset_path = dataset_path
        self.min_nhits = min_nhits 
        
        
        self.detector = Preprocessing.load_detector_data(detector_path) 
        
        # get the list of event ids from the dataset folder : 
        eventids = [ code[:-9] for code in os.listdir(dataset_path) if code.endswith('-hits.csv') ]
        self.eventids = eventids
        
        
    def __len__(self)->int: 
        return len( self.eventids )
    
    def __getitem__(self, index):
        return (
            Preprocessing.process_event_data(
                train_path=self.dataset_path, 
                eventid=self.eventids[index], 
                detector=self.detector
            ), 
            Preprocessing.process_particle_labels(
                train_path=self.dataset_path, 
                eventid=self.eventids[index], 
                min_nhits=self.min_nhits 
            )
        )

In [10]:
point_cloud_dataset = PointCloudData(dataset_path,detector_path)
len(point_cloud_dataset)

100

In [11]:
def collate_function(data_list): 
    return data_list[0] 

In [12]:
# initialize the data loader object ! 
point_clout_data_loder = data.DataLoader(
    dataset=point_cloud_dataset, 
    batch_size=1, 
    shuffle=True ,
    collate_fn=collate_function
)

In [13]:
node_feats , labels  = next(iter(point_clout_data_loder))
node_feats.shape , labels.shape 

(torch.Size([120246, 15]), torch.Size([120246]))

### 2. Create the Base Embedding Model.

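The model architecture is a plain MLP with ReLU activations for the hidden layers; `buildMLP` also inserts batch-normalization layers, including one on the output (see the printed model below).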

In [14]:
class EmbeddingModel(nn.Module):
    
    def __init__(
        self, 
        in_features:int, 
        hidden_features:list, 
        out_features:int 
    )->None:
        
        super().__init__()
        
        # define the MLP layer : 
        self.MLP = buildMLP(
            insize=in_features, 
            outsize=out_features, 
            features=hidden_features,
            add_bnorm=True,
            add_activation=nn.BatchNorm1d(out_features)
        )
        
    def forward(self,x:Tensor)->Tensor: 
        return self.MLP(x)

In [15]:
example_embedding_model = EmbeddingModel(
    in_features=15, 
    hidden_features=[20,25,15,10],
    out_features=5
)
example_embedding_model

EmbeddingModel(
  (MLP): Sequential(
    (0): Linear(in_features=15, out_features=20, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=20, out_features=25, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=25, out_features=15, bias=True)
    (7): ReLU()
    (8): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Linear(in_features=15, out_features=10, bias=True)
    (10): ReLU()
    (11): Linear(in_features=10, out_features=5, bias=True)
    (12): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
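`buildMLP` is a project helper whose source is not shown here. As a sketch, the printed layer stack can be rebuilt in plain PyTorch; the helper name `build_embedding_mlp` is hypothetical, and we only mirror the printout above (BatchNorm after every hidden ReLU except the last, plus BatchNorm on the output), not any other behavior of `buildMLP`'s `add_bnorm` / `add_activation` flags.

```python
import torch
import torch.nn as nn

def build_embedding_mlp(in_features: int, hidden_features: list, out_features: int) -> nn.Sequential:
    """Hypothetical plain-PyTorch stand-in for buildMLP (not the project's code)."""
    layers = []
    sizes = [in_features] + list(hidden_features)
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        # the printed model has no BatchNorm after the last hidden ReLU
        if i < len(hidden_features) - 1:
            layers.append(nn.BatchNorm1d(n_out))
    # output projection followed by BatchNorm on the embedding coordinates
    layers += [nn.Linear(sizes[-1], out_features), nn.BatchNorm1d(out_features)]
    return nn.Sequential(*layers)

model = build_embedding_mlp(15, [20, 25, 15, 10], 5)
print(model(torch.randn(64, 15)).shape)  # torch.Size([64, 5])
```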

In [16]:
example_embedding_model.parameters()

<generator object Module.parameters at 0x13c008820>

In [17]:
node_feats.mean(dim=0 , keepdim=True).shape 

torch.Size([1, 15])

In [18]:
(node_feats/node_feats.std(dim=0 , keepdim=True)).shape

torch.Size([120246, 15])

In [19]:
example_model_out = example_embedding_model(node_feats)
example_model_out.shape 

torch.Size([120246, 5])

### 3. Latent Space Analysis 

Here we build a graph structure by applying a _radius\_graph_ algorithm to the latent-space features: two hits are connected if their latent-space distance is smaller than the radius $r$. We set the radius equal to the margin of our pairwise hinge loss function (more on that later).
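To make the idea concrete, here is a brute-force sketch of what a radius graph contains, for small inputs only. The function name is hypothetical; PyG's `radius_graph` uses a spatial search instead and additionally caps each node at `max_num_neighbors`, so we do not claim identical output. The dense distance matrix is exactly what makes this approach infeasible at scale: for $N = 120246$ hits it would hold about $1.4 \times 10^{10}$ entries (roughly 58 GB in float32).

```python
import torch

def brute_force_radius_graph(z: torch.Tensor, r: float) -> torch.Tensor:
    """Hypothetical helper: [2, E] edge index joining every pair closer than r."""
    # all pairwise Euclidean distances: an [N, N] matrix, i.e. O(N^2) memory
    dist = torch.cdist(z, z)
    # keep pairs within the radius, dropping self-loops (like loop=False)
    no_self = ~torch.eye(z.size(0), dtype=torch.bool, device=z.device)
    adj = (dist < r) & no_self
    src, dst = adj.nonzero(as_tuple=True)
    return torch.stack([src, dst])

z = torch.tensor([[0.0, 0.0], [0.005, 0.0], [1.0, 1.0]])
print(brute_force_radius_graph(z, r=0.01))
# tensor([[0, 1],
#         [1, 0]])
```

Without a neighbour cap the relation is symmetric, so every edge appears in both directions.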

In [20]:
margin = .01
max_num_neighbors=12

radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
    example_model_out, 
    r = margin , 
    loop = False , 
    max_num_neighbors=max_num_neighbors
)
radius_graph_edge_index
radius_graph_edge_index.shape 

torch.Size([2, 402721])

In [21]:
example_model_out.shape

torch.Size([120246, 5])

In [22]:
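# check if the graph is undirected (max_num_neighbors caps the neighbours
# per node, so some edges may be kept in only one direction)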
is_undirected( radius_graph_edge_index )

False

In [23]:
# make the graph undirected by adding the missing reverse edges
radius_graph_edge_index = to_undirected( radius_graph_edge_index )
radius_graph_edge_index.shape 

torch.Size([2, 461106])

In [24]:
# Now that the graph is undirected, every pair appears twice 
# in the index list, once in each direction. 
row , col = radius_graph_edge_index

# Create a mask that keeps pairs with different labels 
# and avoids repetitions by choosing row < col
mask = ( row < col ) & ( labels[row] != labels[col] )

negetive_pair_indices = radius_graph_edge_index[:,mask]
del mask , row , col 

In [25]:
negetive_pair_indices.shape

torch.Size([2, 223168])

### 4. Create Custom Loss Function: 

We train with a pairwise hinge loss over pairs of hits, following the definition of `torch.nn.HingeEmbeddingLoss` (see [documentation](https://pytorch.org/docs/stable/generated/torch.nn.HingeEmbeddingLoss.html#torch.nn.HingeEmbeddingLoss)). For a pair $n$, $x_n$ is the Euclidean distance between the two hits in the latent space, and $y_n = 1$ if both hits belong to the same particle and $y_n = -1$ otherwise.

The loss function is given as: 
$$
l_n = \begin{cases}
    x_n, &\text{ if } y_n = 1 \\
    \max\{0,\text{margin}-x_n\}, &\text{ if } y_n = -1 
\end{cases}
$$
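A quick numerical check of the formula against PyTorch's built-in loss, on three made-up pair distances (the values are illustrative, not from the dataset). The third pair is a negative pair outside the margin, so it contributes zero.

```python
import torch

margin = 0.01
# x_n: latent-space distances of three hit pairs; y_n: +1 same particle, -1 different
x = torch.tensor([0.2, 0.005, 0.03])
y = torch.tensor([1.0, -1.0, -1.0])

# PyTorch's built-in hinge embedding loss, summed over pairs
builtin = torch.nn.HingeEmbeddingLoss(margin=margin, reduction="sum")(x, y)

# the same formula written out by hand
manual = torch.where(y == 1, x, torch.clamp(margin - x, min=0)).sum()
print(builtin.item(), manual.item())  # both ≈ 0.205
```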

Note that we cannot afford to compute the latent-space distance for every hit pair: an event with $N \approx 1.2 \times 10^5$ hits has $N(N-1)/2 \approx 7.2 \times 10^9$ pairs. Instead, we exploit the structure of the loss to reduce the calculation.

We compute the loss separately for the correct (same-particle) hit pairs and for the incorrect hit pairs found by the radius-graph search. A negative pair whose distance exceeds the margin contributes $\max\{0, \text{margin} - x_n\} = 0$, so only the negative pairs inside the margin ball need to be evaluated. This avoids the many redundant zeros that a direct pairwise loss calculation would produce.
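The shortcut can be checked on a toy example: summing the hinge loss over explicit positive pairs plus the negative pairs inside the margin should equal the dense all-pairs loss. The helper `sparse_hinge_loss` is a hypothetical standalone sketch, and here the close negatives are found by brute-force masking instead of a radius graph.

```python
import torch

def sparse_hinge_loss(z, pos_idx, neg_idx, margin):
    """Hypothetical helper: summed hinge loss from [2, P] positive / negative pair lists."""
    # positive pairs (same particle): the loss is the distance itself
    d_pos = torch.linalg.norm(z[pos_idx[0]] - z[pos_idx[1]], dim=-1)
    # negative pairs (different particles) found inside the margin ball
    d_neg = torch.linalg.norm(z[neg_idx[0]] - z[neg_idx[1]], dim=-1)
    # the clamp only matters if a negative pair outside the margin slips in
    return d_pos.sum() + torch.clamp(margin - d_neg, min=0).sum()

# toy check against a dense all-pairs evaluation
torch.manual_seed(0)
z = torch.randn(30, 3)
labels = torch.randint(0, 5, (30,))
margin = 1.0
i, j = torch.triu_indices(30, 30, offset=1)   # every unordered pair once
d = torch.linalg.norm(z[i] - z[j], dim=-1)
same = labels[i] == labels[j]
pos_idx = torch.stack([i[same], j[same]])
close_neg = ~same & (d < margin)              # only these can be non-zero
neg_idx = torch.stack([i[close_neg], j[close_neg]])

y = same.float() * 2 - 1                      # +1 / -1 targets
dense = torch.nn.HingeEmbeddingLoss(margin=margin, reduction="sum")(d, y)
sparse = sparse_hinge_loss(z, pos_idx, neg_idx, margin)
print(torch.allclose(dense, sparse))  # True
```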

In [26]:
positive_pair_indices = Preprocessing.get_track_index_pairs(labels)

In [27]:
positive_pair_indices.shape , negetive_pair_indices.shape 

(torch.Size([2, 538549]), torch.Size([2, 223168]))
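`Preprocessing.get_track_index_pairs` is not shown in this notebook. One plausible way to build such positive pairs is sketched below; the name `same_label_pairs` is hypothetical, and we assume unordered pairs with $i < j$ and that label 0 marks noise hits to be skipped, which the project's helper may handle differently. A particle with $n_p$ hits yields $n_p(n_p-1)/2$ pairs, e.g. 66 pairs for 12 hits.

```python
import torch

def same_label_pairs(labels: torch.Tensor, noise_label: int = 0) -> torch.Tensor:
    """Hypothetical stand-in: [2, P] index pairs (i < j) of hits sharing a label."""
    pairs = []
    for pid in torch.unique(labels):
        if pid.item() == noise_label:
            continue  # assumption: noise hits form no positive pairs
        idx = (labels == pid).nonzero(as_tuple=True)[0]
        if idx.numel() < 2:
            continue
        # all unordered pairs of hits on this track, shape [2, n(n-1)/2]
        pairs.append(torch.combinations(idx, r=2).T)
    if not pairs:
        return torch.empty(2, 0, dtype=torch.long)
    return torch.cat(pairs, dim=1)

labels = torch.tensor([5, 5, 0, 7, 7, 7, 9])
print(same_label_pairs(labels))
# tensor([[0, 3, 3, 4],
#         [1, 4, 5, 5]])
```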

In [28]:
prow , pcol = positive_pair_indices  
nrow , ncol = negetive_pair_indices

# distances are evaluated only on the sparse positive and negative 
# pair lists, rather than on all pairwise combinations of hits: 

loss_plus = norm(
    example_model_out[prow , : ] - example_model_out[pcol, :] , 
    ord = 2 , 
    dim = -1 
)

loss_minus = norm(
    example_model_out[ncol,:] - example_model_out[nrow,:], 
    ord = 2 , 
    dim = -1 
)

print( loss_plus , loss_minus )

loss = (margin - loss_minus).sum() + loss_plus.sum()

loss 

tensor([0.4786, 1.1246, 1.7474,  ..., 0.2403, 0.2376, 0.0094],
       grad_fn=<LinalgVectorNormBackward0>) tensor([0.0087, 0.0034, 0.0098,  ..., 0.0034, 0.0060, 0.0039],
       grad_fn=<LinalgVectorNormBackward0>)


tensor(989658.8125, grad_fn=<AddBackward0>)

In [None]:
def EmbeddingLossFunction(
        x:Tensor,
        labels:Tensor,
        radius_graph_edge_index:Tensor,
        margin:float=.01
    )->Tensor: 
    '''
    x : model output in the latent space. shape = [nhits, out_feats]
    labels : particle id of each hit. shape = [nhits]
    radius_graph_edge_index : graph formed with the radius-ball algorithm 
        in the latent space. shape = [2, num_edges]
    margin : hinge-loss margin (same as the radius of the ball). 
    returns : the logarithm of the summed pairwise hinge loss. 

    Internally, the positive pairs are the index pairs of hits sharing the 
    same particle id, and the negative pairs are the hit pairs inside the 
    margin ball that carry different particle ids. 
    '''
    
    radius_graph_edge_index = to_undirected( radius_graph_edge_index )
    
    # get the positive index pairs (hits from the same particle) : 
    positive_idx = Preprocessing.get_track_index_pairs(labels)
    
    # get the negative index pairs that lie within the margin ball : 
    row , col = radius_graph_edge_index

    # Create a mask that keeps pairs with different labels 
    # and avoids repetitions by choosing row < col
    mask = ( row < col ) & ( labels[row] != labels[col] )
    negetive_idx = radius_graph_edge_index[:,mask]
    
    # get the positive row and col idx : 
    prow , pcol = positive_idx 
    
    # get the negative row and col idx : 
    nrow , ncol = negetive_idx
    
    # delete variables no longer in use : 
    del positive_idx , negetive_idx , radius_graph_edge_index , row , col , mask
    
    # distances between positive pairs (pulled together) 
    loss_plus = norm( x[prow, :] - x[pcol, :], ord = 2, dim = -1 )

    # distances between negative pairs inside the margin (pushed apart) 
    loss_minus = norm( x[nrow, :] - x[ncol, :], ord = 2, dim = -1 )
    
    del prow , pcol , nrow , ncol 
    
    # hinge terms; the clamp is only a safeguard, since radius-graph 
    # neighbours already lie within the margin 
    negative_term = (margin - loss_minus).clamp(min=0).sum()
    positive_term = loss_plus.sum()
    
    # final loss: logarithm of the total hinge loss, written as a single log 
    # so that it does not divide by zero when no negative pairs are found 
    loss = torch.log(negative_term + positive_term)

    return loss     

### 5. Define Graph Structure

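A k-nearest-neighbour (kNN) search in the latent space is one strategy for creating the graph structure (see the commented-out `knn_graph` code below). Here we instead reuse the radius graph built in Section 3 and wrap it in a PyG `Data` object.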
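For small inputs, a kNN graph can be sketched with a dense distance matrix and `torch.topk`. The function name is hypothetical; PyG's `knn_graph` computes the same kind of graph with an efficient search. In this sketch row 1 holds each query hit and row 0 its `k` nearest neighbours; check PyG's `flow` argument before assuming the same convention there.

```python
import torch

def brute_force_knn_graph(z: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: [2, N*k] edges from each hit's k nearest neighbours."""
    dist = torch.cdist(z, z)
    dist.fill_diagonal_(float("inf"))                  # exclude self-loops
    nbr = dist.topk(k, dim=1, largest=False).indices   # [N, k] closest points
    dst = torch.arange(z.size(0), device=z.device).repeat_interleave(k)
    src = nbr.reshape(-1)
    return torch.stack([src, dst])

z = torch.tensor([[0.0], [1.0], [3.0], [10.0]])
print(brute_force_knn_graph(z, k=1))
# tensor([[1, 0, 1, 2],
#         [0, 1, 2, 3]])
```

Unlike a radius graph, every hit gets exactly `k` edges, even an isolated noise hit far from everything else.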

In [None]:
# def create_graph_structure(latent_space_hits:Tensor,k:int): 
#     return torch_geometric.nn.pool.knn_graph(x = latent_space_hits, k = k )

# k=5 
# example_edge_indeces = create_graph_structure(example_model_out,k=k)
# example_edge_indeces.shape 

# get the PyG graph data structure for the radius-graph edges. 
example_graph_data = Data(
    x = node_feats, 
    edge_index=radius_graph_edge_index, 
    y = labels
)

In [31]:
import networkx
from networkx import connected_components

In [None]:
def get_disconnected_components(data:Data):
    # Convert PyG graph to NetworkX graph
    G = to_networkx(data, to_undirected=True)
    
    # Get connected components
    component_list = list(connected_components(G))
    
    # Convert node indices back to PyTorch tensors
    components = [torch.tensor(list(component)) for component in component_list]
    
    return components

example_components = get_disconnected_components(example_graph_data)

In [33]:
len( example_components ),torch.unique(labels).shape 

(60676, torch.Size([9388]))
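For an event with roughly $1.2 \times 10^5$ hits, building a NetworkX graph only to find its components can be slow. A sketch of an alternative with SciPy's `scipy.sparse.csgraph.connected_components`; the helper name is hypothetical. It returns the same partition of the nodes, with isolated hits as single-node components, although the order of the components may differ from NetworkX.

```python
import numpy as np
import torch
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

def components_from_edge_index(edge_index: torch.Tensor, num_nodes: int) -> list:
    """Hypothetical helper: split nodes 0..num_nodes-1 into connected components."""
    row, col = edge_index.cpu().numpy()
    adj = coo_matrix((np.ones(row.shape[0]), (row, col)), shape=(num_nodes, num_nodes)).tocsr()
    # directed=False: treat every edge as undirected
    n_comp, comp_id = connected_components(adj, directed=False)
    # group node indices by component id without a Python loop over all nodes
    order = np.argsort(comp_id, kind="stable")
    counts = np.bincount(comp_id, minlength=n_comp)
    return [torch.from_numpy(part) for part in np.split(order, np.cumsum(counts)[:-1])]

edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
comps = components_from_edge_index(edge_index, num_nodes=6)
print(sorted(c.tolist() for c in comps))  # [[0, 1, 2], [3, 4], [5]]
```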

In [34]:
# get one of these components for metric analysis 
example_component = example_components[0]
print(example_component.shape)

# find the majority particle id among the hits of this component 
# (torch.mode returns the most frequent value and the index of one of its occurrences, not a count): 
component_labels = labels[example_component]
mode , mode_index = torch.mode( component_labels )
mode , mode_index 

torch.Size([1])


(tensor(288251816628453376), tensor(0))

In [35]:
# associated (majority) particle for each disconnected component of the graph : 
accociated_particle = [ torch.mode(labels[component])[0] for component in example_components]
len(accociated_particle) , len(example_components)

(60676, 60676)

### 6. Define Track Metrics

For details of the metrics used, see [1], p. 5.
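The function below computes two averages over the reconstructed tracks, i.e. the connected components $C$ of the graph. Writing (in our own notation) $p_C$ for the most frequent particle id among the hits of $C$, $n_C$ for the number of hits of $p_C$ in $C$, $|C|$ for the number of hits in $C$, and $N_{p_C}$ for the total number of hits of $p_C$ in the event, the code evaluates

$$
\text{track purity} = \frac{1}{|\mathcal{C}|}\sum_{C\in\mathcal{C}} \frac{n_C}{|C|},
\qquad
\text{particle purity} = \frac{1}{|\mathcal{C}|}\sum_{C\in\mathcal{C}} \frac{n_C}{N_{p_C}},
$$

where $\mathcal{C}$ is the set of all components. A track counts as matched when $p_C \neq 0$, $2n_C \ge |C|$ and $2n_C \ge N_{p_C}$. For example, a component with $|C| = 4$ hits, 3 of which come from a particle that left $N_{p_C} = 5$ hits in total, contributes $3/4 = 0.75$ to the track-purity average and $3/5 = 0.6$ to the particle-purity average, and it is matched because $6 \ge 4$ and $6 \ge 5$.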

In [None]:
def event_reconstruction_metrics( graph_data:Data , labels:Tensor )->tuple: 
    '''
        graph_data : graph data object to compute the metrics for. 
        labels : particle ids for each of the nodes (hits). 
        returns : the track purity and particle purity of the graph. 
    '''
    
    # get the disconnected components (reconstructed tracks) from the graph structure : 
    disconnected_components = get_disconnected_components(graph_data)
    
    # del graph_data 
    ## 1. get the matched particles for each of the disconnected components : 
    
    # particle with the most occurrences in each of the disconnected components : 
    max_track_particle = torch.tensor([ torch.mode(labels[component])[0] for component in disconnected_components])
    # print( max_track_particle )
    # frequency of that particle in each of the disconnected components : 
    max_track_particle_freq = torch.tensor([ torch.sum(labels[track] == particle) for particle,track in zip(max_track_particle,disconnected_components)])
    # print( max_track_particle_freq )
    # number of hits that belong to each reconstructed track : 
    num_hits_tracks = torch.tensor( [component.shape[0] for component in disconnected_components] )
    # print( num_hits_tracks )
    # mark each reconstructed track in which the particle with the most 
    # occurrences accounts for at least 50% of all the hits. 
    # We also exclude particles with id 0. 
    
    mask = ( 2*max_track_particle_freq >= num_hits_tracks ) & ( max_track_particle != 0 )
    
    # total number of true hits left by the underlying max_track_particle : 
    max_track_particle_num_true_hits = torch.tensor([torch.sum(labels==particle) for particle in max_track_particle])
    
    # update the mask, now selecting the tracks that contain 
    # at least 50% of the true hits of their majority particle. 
    mask = mask & ( 2*max_track_particle_freq >= max_track_particle_num_true_hits )
    
    # print(f'Total Number of Matched Particles : {torch.sum(mask)}' )
    
    # with this we can now get the matched reconstructed 
    # tracks and the corresponding matched particles 
    matched_tracks = [ tracks for tracks,matched in zip(disconnected_components,mask) if matched ]
    matched_particles = max_track_particle[mask]


    # 2. Get the Track Purity : 
    track_purity  = torch.mean(max_track_particle_freq/num_hits_tracks)
    # print( max_track_particle_freq/num_hits_tracks ) 
    # 3. Get the Particle Purity : 
    particle_purity = torch.mean(max_track_particle_freq/max_track_particle_num_true_hits)
    # print( max_track_particle_freq/max_track_particle_num_true_hits )
    
    # # # 2.  tracking efficiency metric
    
    # # get the total number of unique particles / true tracks : 
    # num_particles = torch.sum( torch.unique(labels) != 0 )
    # # get the tracking efficiency : 
    # tracking_efficiency = torch.sum( torch.unique(matched_particles)  != 0 )/num_particles
    
    # # # 3. tracking purity metric 
    
    # # get the total number of reconstructed tracks : 
    # num_reconstucted_tracks = torch.sum( max_track_particle != 0 )
    # # get the tracking purity : 
    # if num_reconstucted_tracks == 0 : 
    #     tracking_purity = torch.tensor([0.0])
    # else : 
    #     tracking_purity = torch.sum(mask)/num_reconstucted_tracks
    
    # # return the metrics : 
    
    return  track_purity , particle_purity 

In [37]:
event_reconstruction_metrics( example_graph_data , labels )

(tensor(0.9091), tensor(0.0749))
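The list comprehensions above call `torch.sum(labels==particle)` once per component, which for about 60676 components and about $1.2 \times 10^5$ hits is on the order of $7 \times 10^9$ comparisons per event. A vectorized sketch of the same two purities, assuming a per-hit component id numbered $0,\dots,C-1$ (for example the labels returned by `scipy.sparse.csgraph.connected_components`). The function name is hypothetical, and on ties this picks the smaller particle id, which may not match `torch.mode`.

```python
import torch

def vectorized_purities(comp_id: torch.Tensor, labels: torch.Tensor):
    """Hypothetical helper: track and particle purity from per-hit component ids."""
    # hits per (component, particle) combination; columns sorted by component, then label
    pairs, counts = torch.unique(torch.stack([comp_id, labels]), dim=1, return_counts=True)
    comp, pid = pairs
    # order by component, and within a component by decreasing count
    # (the stable sort keeps the smaller particle id first on ties)
    top = counts.max()
    key = comp * (top + 1) + (top - counts)
    order = torch.argsort(key, stable=True)
    comp, pid, counts = comp[order], pid[order], counts[order]
    # the first entry of every component is its majority particle
    first = torch.ones_like(comp, dtype=torch.bool)
    first[1:] = comp[1:] != comp[:-1]
    maj_pid, maj_count = pid[first], counts[first]
    # hits per reconstructed track, and total hits of each majority particle
    hits_per_track = torch.bincount(comp_id)
    uniq, totals = torch.unique(labels, return_counts=True)
    true_hits = totals[torch.searchsorted(uniq, maj_pid)]
    track_purity = (maj_count / hits_per_track).mean()
    particle_purity = (maj_count / true_hits).mean()
    return track_purity, particle_purity

comp_id = torch.tensor([0, 0, 0, 1, 1, 2])
labels = torch.tensor([7, 7, 8, 8, 9, 9])
print(vectorized_purities(comp_id, labels))  # (tensor(0.7222), tensor(0.6667))
```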

### 7. Define the Train, Validation, and Test Data Loaders 

In [38]:
def train_test_split(
    dataset:PointCloudData,
    valid_size:float,
    test_size:float,
    num_works:int=4
):
    '''
    valid_size : fraction of the data to reserve for validation (between 0 and 1). 
    test_size : fraction of the data to reserve for testing (between 0 and 1). 
    Returns : train/validation/test data loaders. 
    '''
    
    train_size=1-test_size-valid_size
    
    if not ( 0. <= valid_size <= 1. and 0. <= test_size <= 1. and 0. <= train_size <= 1. ) : 
        raise ValueError('Improper valid/test size encountered.')
    
    # total number of events : 
    num_events = len(dataset)
    
    # get shuffled indices 
    indices = list(range(num_events))
    np.random.shuffle(indices)
    train_split = int(np.floor(train_size * num_events))
    valid_split = int(np.floor(valid_size * num_events))
    
    train_index, valid_index, test_index = indices[0:train_split], indices[train_split:train_split + valid_split], indices[train_split + valid_split:]
    
    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)
    test_sampler = SubsetRandomSampler(test_index)
    
    # define data loaders : 
    train_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = train_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    valid_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = valid_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    test_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = test_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    # return data loaders : 
    return train_loder,valid_loder,test_loder
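A reproducible variant of the index split can be sketched with `torch.utils.data.random_split` and a seeded generator; the resulting index lists could then be passed to `SubsetRandomSampler` as above. The helper name is hypothetical, and unlike `train_test_split`, which assigns the rounding remainder to the test set, this sketch gives it to the training set.

```python
import torch
from torch.utils.data import random_split

def split_event_indices(num_events: int, valid_size: float, test_size: float, seed: int = 5):
    """Hypothetical helper: reproducible train/valid/test index lists."""
    n_valid = int(valid_size * num_events)
    n_test = int(test_size * num_events)
    n_train = num_events - n_valid - n_test  # remainder goes to training
    if n_train < 0:
        raise ValueError("valid_size + test_size must not exceed 1")
    generator = torch.Generator().manual_seed(seed)
    train, valid, test = random_split(range(num_events), [n_train, n_valid, n_test], generator=generator)
    return train.indices, valid.indices, test.indices

train_idx, valid_idx, test_idx = split_event_indices(100, 0.1, 0.1)
print(len(train_idx), len(valid_idx), len(test_idx))  # 80 10 10
```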

### 8. Training and Testing 

Here we define the training and testing loops of the algorithm.

In [39]:
def TrainEmbedding(
    model:EmbeddingModel,
    train_loder:DataLoader, 
    lr:float=0.01,
    margin:float=0.01,
    max_num_neighbors:int=100
): 
    
    # put the model in training mode (TestEmbedding switches it to eval mode) : 
    model.train()
    
    # initialize optimizer : 
    optimizer = torch.optim.SGD(model.parameters(),lr = lr)
    
    # initialize train loss
    train_loss = 0.0 
    
    # total number of events : 
    num_events = len(train_loder)
    
    # loop  over the training dataset 
    for i,(event_data,labels) in tqdm(enumerate(train_loder), bar_format='{l_bar}{bar}| Event {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]' , total = len(train_loder) , ncols = 75) : 

        optimizer.zero_grad()
        
        output = model(event_data.squeeze_(dim=0))
        
        radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
            output, 
            r = margin , 
            loop = False , 
            max_num_neighbors=max_num_neighbors
        )
        
        loss = EmbeddingLossFunction(
            x = output,
            labels=labels.squeeze_(dim=0),
            radius_graph_edge_index=radius_graph_edge_index, 
            margin=margin
        )
        loss.backward()
        
        optimizer.step()
        
        train_loss += loss.item()
        
    model.to('cpu')
        
    return train_loss/num_events


@torch.no_grad()
def TestEmbedding(
    model:EmbeddingModel, 
    test_loder:DataLoader, 
    margin:float=0.01, 
    max_num_neighbors:int=100
):
    # initialize the loss, track purity and particle purity accumulators 
    # (the two metrics returned by event_reconstruction_metrics) 
    test_loss , test_track_efficiency , test_track_purity = 0.0 , 0.0 , 0.0 
    
    # get the number of events : 
    num_events = len(test_loder)
    
    model.eval()
    
    # loop over the validation / test dataset 
    for i,(event_data,labels) in tqdm(enumerate(test_loder), bar_format='{l_bar}{bar}| Event {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]' , total = len(test_loder) , ncols = 75 ) :  
        
        output = model(event_data.squeeze_(dim=0))
        
        # radius graph edges 
        radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
            output, 
            r = margin , 
            loop = False , 
            max_num_neighbors=max_num_neighbors
        )
        
        # get the loss : 
        loss = EmbeddingLossFunction(
            x = output, 
            labels = labels.squeeze_(dim=0),
            radius_graph_edge_index=radius_graph_edge_index, 
            margin=margin
        )
        test_loss += loss.item()
        
        # create the graph data structure : 
        event_graph_data = Data(
            x = event_data, 
            edge_index=radius_graph_edge_index, 
            y = labels
        )
        
        # get the track performance metrics : 
        efficiency , purity = event_reconstruction_metrics(event_graph_data, labels )
        test_track_efficiency += efficiency.item() 
        test_track_purity += purity.item()
        
    return test_loss/num_events , test_track_efficiency/num_events , test_track_purity/num_events

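Define a training loop that runs for a given number of epochs, validates after each epoch, and saves the model whenever the validation loss reaches a new minimum.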

In [40]:
def TrainEmbeddingModel(
    model:EmbeddingModel, 
    train_loder:DataLoader, 
    valid_loder:DataLoader, 
    lr:float=0.01,
    margin:float=0.01, 
    max_num_neighbors:int=100,
    n_epochs:int=15, 
    save_model:bool=True, 
    save_model_path:str=None
): 
    if save_model and save_model_path is None: 
        raise ValueError('Must pass a valid path to save the model.')
    
    # initialize tracker for minimum validation loss
    valid_loss_min = np.inf 
    # initialize trackers for the validation metrics and losses 
    train_loss , valid_loss , valid_track_efficiency , valid_track_purity = [] , [] , [] , []
    
    # loop through the training process n_epochs times 
    for epoch in range(1,n_epochs+1):
        
        print(f'----------------------Starting Epoch {epoch}----------------------')
        
        print('Begin Training: ')
        # training step : 
        train_loss.append(
            TrainEmbedding(
                model=model, 
                train_loder=train_loder, 
                lr = lr , 
                margin = margin , 
                max_num_neighbors=max_num_neighbors
            )
        )
        
        print('Begin Validation: ')
        # validation step : 
        (
            epoch_valid_loss, 
            epoch_valid_track_efficiency, 
            epoch_valid_track_purity
        ) = TestEmbedding(
            model = model , 
            test_loder= valid_loder, 
            margin = margin , 
            max_num_neighbors=max_num_neighbors
        )
        valid_loss.append(epoch_valid_loss)
        valid_track_efficiency.append(epoch_valid_track_efficiency)
        valid_track_purity.append(epoch_valid_track_purity)
        
        print('Training Loss: {:.6f} \nValidation Loss: {:.6f} \nTrack Purity: {:.6f} \nParticle Purity: {:.6f}'.format(
            train_loss[-1],
            valid_loss[-1],
            valid_track_efficiency[-1], 
            valid_track_purity[-1]
        ))
    
        # save model if validation loss has decreased
        if valid_loss[-1] <= valid_loss_min:
            if save_model: 
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                    valid_loss_min,
                    valid_loss[-1]
                ))
                torch.save(model.state_dict(), save_model_path)
            valid_loss_min = valid_loss[-1]  
    
    return (
        list(range(1,n_epochs+1)),
        train_loss, 
        valid_loss, 
        valid_track_efficiency, 
        valid_track_purity
    )

### 9. Example Run of the Model: 

In [41]:
torch.nn.BatchNorm1d(5)( example_model_out).shape 

torch.Size([120246, 5])

In [42]:
# get the point cloud data 
event_dataset = PointCloudData(
    dataset_path='../data/train_100_events/',
    detector_path='../data/detectors.csv'
)

# get the data loders : 
(
    train_loder, 
    valid_loder, 
    test_loder
) = train_test_split(
    dataset=event_dataset,
    valid_size=0.1, 
    test_size=0.1,
    num_works=0
)

# initialize the model : 
embedding_model = EmbeddingModel(
    in_features=15, 
    hidden_features=[1024,512,256,128,64,32,16],
    out_features=5
)

save_model_path = '../data/models/embedding_model.pt'

# # train the model : 
# (
#     epochs , train_loss , 
#     valid_loss , valid_efficiency, 
#     valid_purity 
# ) = TrainEmbeddingModel(
#     model = embedding_model, 
#     train_loder=train_loder, 
#     valid_loder=valid_loder, 
#     save_model_path=save_model_path, 
#     margin = 0.01 
# )

In [43]:
train_loder

<torch_geometric.loader.dataloader.DataLoader at 0x138fcf4f0>

In [44]:
inputs, classes = next(iter(train_loder))

In [45]:
inputs.shape 

torch.Size([1, 106942, 15])

In [46]:
# out = embedding_model(inputs)

### 10. Convert the PyTorch Model into a PyTorch Lightning Model 

In [47]:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torchmetrics import Metric

In [48]:
import yaml 

In [49]:
with open('03-Embedding-Hyperparameters.yml' , 'r' ) as f : 
    hparams = yaml.safe_load(f)

In [50]:
# from TrackML.Embedding.utils import EmbeddingModel,EmbeddingLossFunction,event_reconstruction_metrics,collate_function

In [51]:
from TrackML.Embedding.utils import PointCloudData

In [52]:
type( hparams )

dict

In [53]:
hparams

{'detector_path': '../data/detectors.csv',
 'dataset_path': '../data/train_100_events/',
 'min_hits': 3,
 'shuffle': True,
 'valid_size': 0.1,
 'test_size': 0.1,
 'num_works': 8,
 'in_featuers': 15,
 'hidden_featuers': [1024, 512, 256, 128, 64, 32, 16],
 'out_featuers': 5,
 'margin': 0.1,
 'max_num_neighbours': 100,
 'lr': 0.01,
 'save_model': True,
 'save_model_path': '../data/models/embedding_model.pt'}

In [54]:
class EmbeddingDataset(pl.LightningDataModule): 
    
    # initialize the class : 
    def __init__(self,hparams)->None: 
        super().__init__() 
        self.save_hyperparameters(hparams)
        
    # Lightning runs setup() in every process, while prepare_data() runs in a 
    # single process and should not assign state, so the loaders are built here. 
    def setup(self, stage=None)->None: 
        self.detector = Preprocessing.load_detector_data(self.hparams['detector_path'])
        # get the list of event ids from the dataset folder : 
        self.eventids = [ code[:-9] for code in os.listdir(self.hparams['dataset_path']) if code.endswith('-hits.csv') ]
        dataset = PointCloudData(dataset_path=self.hparams['dataset_path'] , detector_path=self.hparams['detector_path'] , min_nhits=self.hparams['min_hits'] )
        self.train_ds , self.val_ds , self.test_ds = train_test_split(
            dataset=dataset, valid_size=self.hparams['valid_size'], 
            test_size=self.hparams['test_size'], num_works=self.hparams['num_works']
        )
    
    def train_dataloader(self): 
        return self.train_ds 
    def val_dataloader(self): 
        return self.val_ds 
    def test_dataloader(self): 
        return self.test_ds

In [55]:
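# define a custom metric that accumulates log(sum of per-event losses) : 
def _custom_dist_reduce_fn(x): 
    # x stacks the per-process log-sums; combine them as log(sum(exp(x)))
    return torch.logsumexp( x , dim = 0 )

class LogSumLoss(Metric):
    def __init__(self):
        super().__init__()
        # log(0) = -inf represents an empty sum
        self.add_state("log_loss_sum", default=torch.tensor(float('-inf')), dist_reduce_fx=_custom_dist_reduce_fn)

    def update(self, log_loss):
        # log_loss is the per-event log-loss returned by EmbeddingLossFunction; 
        # log(S + exp(log_loss)) = logaddexp(log S, log_loss)
        self.log_loss_sum = torch.logaddexp(self.log_loss_sum, log_loss)

    def compute(self):
        return self.log_loss_sum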
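As a standalone illustration of accumulating a sum in log space: the running value starts from $\log 0 = -\infty$ and is updated with `torch.logaddexp`, while partial log-sums from several processes combine with `torch.logsumexp`. Adding the log values directly would instead give the log of the *product* of the partial sums. The numbers below are made up.

```python
import torch

losses = torch.tensor([3.0, 5.0, 2.0])

# running log-sum: start from log(0) = -inf and fold in one loss at a time
log_sum = torch.tensor(float("-inf"))
for loss in losses:
    log_sum = torch.logaddexp(log_sum, torch.log(loss))

# two partial log-sums (e.g. from two processes) merge with logsumexp
partials = torch.stack([torch.log(losses[:2].sum()), torch.log(losses[2:].sum())])
merged = torch.logsumexp(partials, dim=0)

print(log_sum.exp().item(), merged.exp().item())  # both ≈ 10.0
```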

In [56]:
class Purity(Metric):
    def __init__(self): 
        super().__init__()
        self.add_state(
            'intersections', 
            default = torch.tensor([]), 
            dist_reduce_fx = 'cat'
        )
        self.add_state(
            'num_hits', 
            default = torch.tensor([]),
            dist_reduce_fx = 'cat'
        )
        
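    def update(self,intersections,num_hits):
        # append this event's values to the running state (overwriting would 
        # keep only the last event) 
        self.intersections = torch.cat([self.intersections, intersections.float()])
        self.num_hits = torch.cat([self.num_hits, num_hits.float()])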
    
    def compute(self): 
        return torch.mean( self.intersections/self.num_hits )
    

In [57]:
def event_reconstruction_metrics( graph_data:Data , labels:Tensor )->tuple: 
    '''
        graph_data : graph data object to compute the metrics for. 
        labels : particle ids for each of the nodes (hits). 
        returns : for each reconstructed track, the number of hits of its majority 
            particle, the number of hits in the track, and the total number of true 
            hits of the majority particle (the inputs needed for the purity metrics). 
    '''
    
    # get the disconnected components (reconstructed tracks) from the graph structure : 
    disconnected_components = get_disconnected_components(graph_data)
    
    
    ## 1. get the majority particle for each of the disconnected components : 
    
    # particle with the most occurrences in each of the disconnected components : 
    max_track_particle = torch.tensor([ torch.mode(labels[component])[0] for component in disconnected_components])
    # print( max_track_particle )
    # frequency of that particle in each of the disconnected components : 
    max_track_particle_freq = torch.tensor([ torch.sum(labels[track] == particle) for particle,track in zip(max_track_particle,disconnected_components)])
    # print( max_track_particle_freq )
    # number of hits that belong to each reconstructed track : 
    num_hits_tracks = torch.tensor( [component.shape[0] for component in disconnected_components] )
    # total number of true hits left by the underlying max_track_particle : 
    max_track_particle_num_true_hits = torch.tensor([torch.sum(labels==particle) for particle in max_track_particle])
    
    return  max_track_particle_freq , num_hits_tracks , max_track_particle_num_true_hits

In [58]:
class EmbeddingBase(LightningModule):
    
    # initialize the Embedding Class 
    def __init__(self , hparams ): 
        super().__init__()
        
        # save the hyperparameters : 
        self.save_hyperparameters(hparams)
        
        # loss accumulation metric : 
        self.log_sum_loss = LogSumLoss()
        # track and particle purity : 
        self.particle_purity = Purity() 
        self.track_purity = Purity()
        
        # save the model descriptions : 
        self.model = EmbeddingModel(
            in_features=hparams['in_featuers'] , 
            hidden_features=hparams['hidden_featuers'], 
            out_features=hparams['out_featuers']
        )
    
    
    # forward function 
    def forward(self , inputs ): 
        return self.model( inputs )
    
    
    # training logic : 
    def training_step(self , batch , batch_idx ): 
        
        event_data , labels = batch 
        event_data.squeeze_(dim=0)
        labels.squeeze_(dim=0)
        
        output = self( event_data )
        
        # build the radius graph on the CPU (graph construction needs no gradients), 
        # then move the edge indices back to the model's device : 
        radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
            output.detach().cpu(), 
            r = self.hparams['margin'] , 
            loop = False , 
            max_num_neighbors=self.hparams['max_num_neighbours']
        )
        radius_graph_edge_index = radius_graph_edge_index.to( device = self.device )
        loss = EmbeddingLossFunction(
            x = output,
            labels=labels,
            radius_graph_edge_index=radius_graph_edge_index, 
            margin=self.hparams['margin']
        )
        self.log(
            'Loss' , loss, 
            prog_bar = True , on_step = True , on_epoch = False 
        )
        
        self.log_sum_loss.update(loss)
        self.log(
            'Log Loss' , self.log_sum_loss.compute(), on_step = False , 
            on_epoch = True , prog_bar = True , reduce_fx = 'max' 
        )
        
        self.log(
            'Memory Allocated' , torch.mps.current_allocated_memory()/(1024**3), 
            prog_bar=True , on_step = True , on_epoch=True, 
            reduce_fx='max'
        )
        
        return loss 
    
    # common logic for test and validation 
    def _test_val_common_step_(self , batch , batch_idx ): 
        
        with torch.no_grad(): 
            event_data , labels = batch 
            event_data.squeeze_(dim=0)
            labels.squeeze_(dim=0)
            
            output = self( event_data )
            
            # build the radius graph on the CPU, then move the edge indices back to the model's device : 
            radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
                output.cpu(), 
                r = self.hparams['margin'] , 
                loop = False , 
                max_num_neighbors=self.hparams['max_num_neighbours']
            )
            radius_graph_edge_index = radius_graph_edge_index.to( device = self.device )
            loss = EmbeddingLossFunction(
                x = output,
                labels=labels,
                radius_graph_edge_index=radius_graph_edge_index, 
                margin=self.hparams['margin']
            )
            self.log(
                'Loss',loss, 
                prog_bar = True , 
                on_step = True , 
                on_epoch = False 
            )
            self.log_sum_loss.update(loss)
            self.log(
                'Log Loss' , self.log_sum_loss.compute(), 
                on_step = False , on_epoch = True , 
                prog_bar = True , reduce_fx = 'max' 
            )
            # create the graph data structure : 
            event_graph_data = Data(
                x = event_data, 
                edge_index=radius_graph_edge_index, 
                y = labels
            )
            intersections , num_track , num_particle = event_reconstruction_metrics(event_graph_data,labels)
            self.track_purity.update( intersections = intersections , num_hits = num_track )
            self.particle_purity.update( intersections = intersections , num_hits = num_particle )
            self.log_dict(
                {'track purity' : self.track_purity.compute() , 'particle purity' : self.particle_purity.compute() }, 
                on_step = False , on_epoch = True , prog_bar = True , reduce_fx = 'mean'
            )
            self.log(
                'Memory Allocated' , torch.mps.current_allocated_memory()/(1024**3), 
                prog_bar=True , on_step = True , on_epoch=True, 
                reduce_fx='max'
            )
        return loss
    
    def validation_step(self,batch,batch_idx):
        return self._test_val_common_step_(batch,batch_idx)
    
    def test_step(self,batch,batch_idx): 
        return self._test_val_common_step_(batch,batch_idx)
    
    # set the optimizer  :
    def configure_optimizers(self): 
        return torch.optim.SGD(self.model.parameters(),lr = self.hparams['lr'])
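`EmbeddingLossFunction` is not reproduced in this section. Below is a minimal NumPy sketch of a hinge embedding loss in the spirit of Ju et al. [1], under assumptions that may not match the notebook: same-particle pairs are pulled together by their squared distance, different-particle pairs found by the radius graph are pushed apart with $\max(0, \text{margin}^2 - d^2)$, every same-particle pair is added to the candidate set, and the mean is taken over unique directed pairs. `radius_edges` is a hypothetical brute-force stand-in for `radius_graph(..., loop=False)`; it keeps the nearest neighbours when `max_num_neighbors` caps the count, whereas PyG's `radius_graph` does not guarantee which neighbours survive the cap.

```python
import numpy as np


def radius_edges(x, r, max_num_neighbors=None):
    """Brute-force radius graph without self-loops, as a [2, E] (source, target) array.

    Hypothetical O(N^2) stand-in for torch_geometric's radius_graph(loop=False).
    """
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    src, dst = [], []
    for i in range(len(x)):
        # neighbours of hit i within distance r, nearest first, excluding i itself
        nbrs = [int(j) for j in np.argsort(dist[i]) if j != i and dist[i, j] < r]
        if max_num_neighbors is not None:
            nbrs = nbrs[:max_num_neighbors]
        src.extend(nbrs)
        dst.extend([i] * len(nbrs))
    return np.array([src, dst], dtype=np.int64).reshape(2, -1)


def hinge_embedding_loss(x, labels, edge_index, margin):
    """Mean hinge loss over radius-graph pairs plus all same-particle pairs (assumed form)."""
    same = labels[:, None] == labels[None, :]
    np.fill_diagonal(same, False)
    true_src, true_dst = np.nonzero(same)
    src = np.concatenate([edge_index[0], true_src])
    dst = np.concatenate([edge_index[1], true_dst])
    if src.size == 0:
        return 0.0
    pairs = np.unique(np.stack([src, dst], axis=1), axis=0)  # drop duplicate pairs
    d2 = np.sum((x[pairs[:, 0]] - x[pairs[:, 1]]) ** 2, axis=1)
    is_true = labels[pairs[:, 0]] == labels[pairs[:, 1]]
    # attract true pairs, repel fake pairs that sit inside the margin
    per_pair = np.where(is_true, d2, np.maximum(0.0, margin**2 - d2))
    return float(per_pair.mean())


# toy latent space: hit 3 belongs to particle 0 but starts far from hits 0 and 1
x = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.5], [3.0, 3.0]])
labels = np.array([0, 0, 1, 0])
edge_index = radius_edges(x, r=1.0)
loss = hinge_embedding_loss(x, labels, edge_index, margin=1.0)
```

Adding the same-particle pairs matters in this sketch: hit 3 never appears in the radius graph, so without them it would receive no attractive term at all. Whether the notebook's `EmbeddingLossFunction` does something equivalent is not visible in this section.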

In [59]:
from pytorch_lightning import Trainer

In [60]:
model = EmbeddingBase(hparams)
ds = EmbeddingDataset(hparams)

from pytorch_lightning.callbacks import DeviceStatsMonitor
device_stats = DeviceStatsMonitor()

trainer = Trainer(
    accelerator = "cpu", 
    devices = "auto",
    enable_checkpointing=True, 
    fast_dev_run = 3, 
    callbacks=[device_stats]
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enabl

In [61]:
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [62]:
trainer.fit(model , ds )


  | Name            | Type           | Params | Mode 
-----------------------------------------------------------
0 | log_sum_loss    | LogSumLoss     | 0      | train
1 | particle_purity | Purity         | 0      | train
2 | track_purity    | Purity         | 0      | train
3 | model           | EmbeddingModel | 720 K  | train
-----------------------------------------------------------
720 K     Trainable params
0         Non-trainable params
720 K     Total params
2.882     Total estimated model params size (MB)
27        Modules in train mode
0         Modules in eval mode

`Trainer.fit` stopped: `max_steps=3` reached.
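As a quick consistency check on the model summary above: Lightning's size estimate assumes 32-bit (4-byte) parameters, so the reported size and parameter count agree:

$$
\frac{2.882 \times 10^{6}\ \text{bytes}}{4\ \text{bytes/parameter}} \approx 7.205 \times 10^{5}\ \text{parameters} \approx 720\ \text{K}
$$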
