# Filtering MLP Model 

In [1]:
import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

from typing import Optional, Tuple


import torch
import torch.nn as nn
from torch.linalg import norm 
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor

# torch.set_default_dtype(torch.float64)

from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
    torch_sparse,
)

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, is_undirected , to_undirected, contains_self_loops , to_networkx
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph
from torch_geometric.datasets import QM9
# from torch_scatter import scatter
# from torch_cluster import knn

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# import uproot
import vector
vector.register_awkward()
import awkward as ak

from IPython.display import HTML

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.6.0
PyG version 2.6.1


In [2]:
from TrackML.Models.utils import buildMLP
from TrackML.Embedding.dataset import PointCloudData
from TrackML.Embedding.base import EmbeddingBase

In [3]:
seed = 5 
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

In [4]:
class PostEmbeddingGraph(data.Dataset): 
    
    # initialize the dataset class : 
    def __init__(
        self,dataset_path:str,
        detector_path:str,embd_model_path:str,
        min_nhits=3,margin:float=0.1, 
        max_num_neighbors:int=64
        )->None:
        '''
        dataset_path : path to the dataset with the events. 
        detector_path : path to the detector.csv file 
        min_nhits : upper bound on the number of  hits to keep. 
        margin : margin around which to create the radius graph 
        embd_model_path : path to the saved embedding model. 
        returns : For each event, returs the PyG graph data that 
            is constructed post embedding. 
        '''
        super().__init__()
        
        self.margin = margin 
        self.max_num_neighbors = max_num_neighbors
        
        # get the point cloud dataset: 
        self.dataset = PointCloudData(
            dataset_path=dataset_path, 
            detector_path=detector_path, 
            min_nhits=min_nhits
        )
        
        # save the embedding model 
        self.embd_model = EmbeddingBase.load_from_checkpoint(embd_model_path)
        self.embd_model.eval()
        
        
    def __len__(self)->int: 
        return len( self.dataset )
    
    def __getitem__(self, index)->Data:
        
        # get the raw node featuers and labels 
        node_feats , labels = self.dataset[index]
        # get the latent space feat of the nodes : 
        latent_feats = self.embd_model(node_feats)
        
        # get the radius graph edge index : 
        radius_graph_edge_index = torch_geometric.nn.pool.radius_graph(
            latent_feats, 
            r = self.margin , 
            loop = False , 
            max_num_neighbors=self.max_num_neighbors
        )
        
        # get the edges labels, 0 if they belong to the same track, 1 otherwise : 
        row , col = radius_graph_edge_index 
        edge_attr = (labels[row] != labels[col]).float().unsqueeze_(dim=1)
        
        # edge_purity ; 
        edge_purity = 1 - torch.sum( edge_attr )/radius_graph_edge_index.shape[1]
        
        # create the graph data structure : 
        graph_data = Data(
            x = node_feats, 
            edge_index=radius_graph_edge_index, 
            edge_attr=edge_attr, 
            y = labels , 
            edge_purity = edge_purity
        )
        
        return graph_data

In [5]:
dataset_path = '../data/train_100_events/'
detector_path= '../data/detectors.csv'
embd_model_path =  '../data/models/Embedding-Model-v7.ckpt'

In [6]:
filter_dataset = PostEmbeddingGraph(
    dataset_path , 
    detector_path , 
    embd_model_path
)

filter_data_instance = filter_dataset[10]
filter_data_instance

Data(x=[96872, 15], edge_index=[2, 6199808], edge_attr=[6199808, 1], y=[96872], edge_purity=0.0737377405166626)

In [7]:
filter_data_instance.edge_purity

tensor(0.0737)

In [8]:
import networkx
from networkx import connected_components

In [9]:
def get_disconnected_components(data:Data):
    # Convert PyG graph to NetworkX graph
    G = to_networkx(data, to_undirected=True)
    
    # Get connected components
    component_list = list(connected_components(G))
    
    # Convert node indices back to PyTorch tensors
    components = [torch.tensor(list(component)) for component in component_list]
    
    return components

example_components = get_disconnected_components(filter_data_instance)
len(example_components)

592

In [10]:
def event_reconstruction_metrics( graph_data:Data  )->tuple: 
    '''
        data : graph dataset to get the metrics for : 
        labels : particle_ids for each of the nodes (hits)
        returns : the trajecteory and particle purity of the graph dataset. 
    '''
    
    labels = graph_data.y 
    
    # get the disconnected compoenents (reconstructed tracks) 
    # form the graph steructure : 
    disconnected_components = get_disconnected_components(graph_data)
    
    #  partices with the max occourence for each of the disconnected components : 
    max_track_particle = torch.tensor([ torch.mode(labels[component])[0] for component in disconnected_components])
    # frequency of the particle with the most occourence for each of the disconnected component : 
    max_track_particle_freq = torch.tensor([ torch.sum(labels[track] == particle) for particle,track in zip(max_track_particle,disconnected_components)])
    # get the number of hits that belong to each reconstructed track : 
    num_hits_tracks = torch.tensor( [component.shape[0] for component in disconnected_components] )
    # total number of true hits left by the underlying max_track_particle : 
    max_track_particle_num_true_hits = torch.tensor([torch.sum(labels==particle) for particle in max_track_particle])
    
    # 2. Get the Track Purity : 
    track_purity  = torch.mean(max_track_particle_freq/num_hits_tracks)
    # 3. Get Particle Purity : 
    particle_purity = torch.mean(max_track_particle_freq/max_track_particle_num_true_hits)
    
    return  track_purity , particle_purity 

In [11]:
event_reconstruction_metrics( filter_data_instance )

(tensor(0.2034), tensor(0.0157))

In [None]:
class FilteringModel(nn.Module): 
    
    def __init__(
        self, 
        in_features:int, 
        hidden_features:list, 
    )->None:
        
        super().__init__()
        
        # the acctual in featuers would be twice the the 
        # number of acctual node features since a edge feature 
        # is a concatination of two node featuers 
        
        in_features *= 2 
        # define the MLP layer : 
        self.MLP = buildMLP(
            insize=in_features, 
            outsize=1, 
            features=hidden_features,
            add_bnorm=True,
            add_activation=None
        )
        
    def forward(self,data:Data)->Tensor: 
        row , col = data.edge_index 
        edge_feats = torch.cat((data.x[row] , data.x[col]) , dim  = 1 )
        return self.MLP( edge_feats )

In [13]:
filter_model_instance = FilteringModel(in_features=15,hidden_features= [20,10,5])
filter_model_out = filter_model_instance(filter_data_instance)
filter_model_out.shape 

torch.Size([6199808, 1])

In [14]:
def train_test_split(
    dataset:PostEmbeddingGraph,
    valid_size:float,
    test_size:float,
    num_works:int=4
):
    '''
    valid_size : amount of data to reserve for validation (normalized to 1 )
    test_size : amount of data to reserve for testing (normalized to 1 )
    Returns : train/validation/test data loders. 
    '''
    
    train_size=1-test_size-valid_size
    
    if not ( (train_size <= 1.) & (valid_size <= 1.) & (test_size <= 1. )) : 
        raise ValueError('Improper valid/train size encountered.')
    
    # total number of events : 
    num_events = len(dataset)
    
    # get shuffeled indices 
    indices = list(range(num_events))
    np.random.shuffle(indices)
    train_split = int(np.floor(train_size * num_events))
    valid_split = int(np.floor(valid_size * num_events))
    
    train_index, valid_index, test_index = indices[0:train_split], indices[train_split:train_split + valid_split], indices[train_split + valid_split:]
    
    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)
    test_sampler = SubsetRandomSampler(test_index)
    
    # define data loaders : 
    train_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = train_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    valid_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = valid_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    test_loder = DataLoader(
        dataset=dataset, 
        batch_size=1, 
        num_workers=num_works, 
        sampler = test_sampler,
        persistent_workers=True if num_works > 0 else False
    )
    
    # return data loders : 
    return train_loder,valid_loder,test_loder

In [15]:
train_loder_instance , val_loder_instance , _   = train_test_split(
    filter_dataset , 
    0.2 , 0.1 , 0 
)
len( train_loder_instance )

70

In [16]:
y = next(iter(train_loder_instance))
y

DataBatch(x=[104514, 15], edge_index=[2, 6688896], edge_attr=[6688896, 1], y=[104514], edge_purity=[1], batch=[104514], ptr=[2])

In [17]:
from tqdm import tqdm 

In [18]:
def TrainFiltering(
    model:FilteringModel,
    train_loder:DataLoader, 
    lr:float=0.01
): 
    
    # initialize optimizer : 
    optimizer = torch.optim.SGD(model.parameters(),lr = lr)
    # initialize train loss
    train_loss = 0.0 
    # total number of events : 
    num_events = len(train_loder)
    
    # loop  over the training dataset 
    for i,graph_data in tqdm(enumerate(train_loder), bar_format='{l_bar}{bar}| Event {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]' , total = len(train_loder) , ncols = 75) : 

        optimizer.zero_grad()
        output = model(graph_data)
        
        loss_fn = nn.BCEWithLogitsLoss(
            reduction='sum' , pos_weight=graph_data.edge_purity/( 1-graph_data.edge_purity)
        )
        loss = loss_fn(output,graph_data.edge_attr)
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        
        if( i == 2 ) : 
            break 
        
    return train_loss/num_events


def TestEmbedding(
    model:PostEmbeddingGraph, 
    test_loder:DataLoader
):
    
    train_loss = 0.0 
    
    num_events = len(test_loder)
    model.eval()
    
    # loop  over the training dataset 
    for i,graph_data in tqdm(enumerate(test_loder), bar_format='{l_bar}{bar}| Event {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]' , total = len(test_loder) , ncols = 75 ) :  

        output = model(graph_data)
        
        loss_fn = nn.BCEWithLogitsLoss(
            reduction='sum' , pos_weight=graph_data.edge_purity/( 1-graph_data.edge_purity)
        )
        loss = loss_fn(output,graph_data.edge_attr)
        train_loss += loss.item()
        
        if( i == 2 ) : 
            break
        
    return train_loss/num_events

In [19]:
# TrainFiltering(filter_model_instance,train_loder_instance)

In [20]:
# TestEmbedding(filter_model_instance,val_loder_instance)

In [96]:
from TrackML.Filtering.utils import PostEmbeddingGraph

In [97]:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torchmetrics import Metric
import yaml 

In [98]:
with open('04-Filtering-Hyperparameters.yml' , 'r' ) as f : 
    hparams = yaml.safe_load(f)

In [99]:
class FilteringDataset(pl.LightningDataModule): 
    def __init__(self,hparams:dict)->None: 
        super().__init__()
        self.save_hyperparameters(hparams)
        
    def setup(self,stage=None): 
        dataset = PostEmbeddingGraph(
            dataset_path=self.hparams['dataset_path'], 
            detector_path=self.hparams['detector_path'], 
            embd_model_path=self.hparams['embd_model_path'], 
            min_nhits = self.hparams['min_nhits'], 
            margin = self.hparams['margin'] , 
            max_num_neighbors= self.hparams['max_num_neighbours']
        )
        self.train_ds , self.val_ds , self.test_ds = train_test_split(
            dataset=dataset, 
            valid_size=self.hparams['valid_size'], 
            test_size=self.hparams['test_size'] , 
            num_works=self.hparams['num_works']
        )
    
    def train_dataloader(self): 
        return self.train_ds 
    def val_dataloader(self): 
        return self.val_ds 
    def test_dataloader(self): 
        return self.test_ds

In [100]:
from torchmetrics.classification import BinaryPrecision, BinaryRecall, BinaryF1Score

In [101]:
class LogSumLoss(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("log_losses", default=torch.tensor([]), dist_reduce_fx="cat")

    def update(self, loss):
        # Expect loss to be a tensor of shape (N,) or scalar
        loss = loss.flatten().detach()
        self.log_losses = torch.cat([self.log_losses, torch.log(loss)])

    def compute(self):
        # log-sum-exp: log(sum(exp(x))) = max + log(sum(exp(x - max)))
        if self.log_losses.numel() == 0:
            return torch.tensor(float("-inf"), device=self.log_losses.device)

        max_log = torch.max(self.log_losses)
        return max_log + torch.log(torch.sum(torch.exp(self.log_losses - max_log)))

In [117]:
class FilteringModelPl(LightningModule): 
    
    def __init__(self,hparams): 
        super().__init__()
        self.save_hyperparameters(hparams)
        
        # Metrics (with DDP support)
        self.train_precision = BinaryPrecision()
        self.train_recall = BinaryRecall()
        self.train_f1 = BinaryF1Score()

        self.val_precision = BinaryPrecision()
        self.val_recall = BinaryRecall()
        self.val_f1 = BinaryF1Score()
        
        self.log_sum_loss = LogSumLoss()
        
        self.model = FilteringModel(
            in_features=hparams['in_featuers'],
            hidden_features=hparams['hidden_features']
        )
        
    def forward(self , x ): 
        return self.model( x )
    
    def training_step(self, batch , batch_idx ):
        print( batch )
        output = self( batch )
        
        loss_fn = nn.BCEWithLogitsLoss(
            reduction='sum' , pos_weight=batch.edge_purity/( 1-batch.edge_purity)
        )
        loss = loss_fn(output,batch.edge_attr)
        
        with torch.no_grad() : 
            preds = torch.nn.Sigmoid()(output)
            self.train_precision.update( preds , batch.edge_attr )
            self.train_recall.update(preds, batch.edge_attr)
            self.train_f1.update(preds, batch.edge_attr)
            
        if torch.cuda.is_available(): 
                self.log(
                    'Memory Allocated' , torch.cuda.memory_allocated()/(1024**3), 
                    prog_bar=True , on_step = True , on_epoch=True, 
                    reduce_fx='max' , sync_dist=True
                )
        
        self.log('loss' , loss , on_step = True , on_epoch = False , prog_bar=True )
        self.last_idx = batch_idx
        return loss
    
    def on_train_epoch_end(self):
        self.log("train_precision", self.train_precision.compute(), prog_bar=True,sync_dist=True)
        self.log("train_recall", self.train_recall.compute(), prog_bar=True ,sync_dist=True)
        self.log("train_f1", self.train_f1.compute(), prog_bar=True,sync_dist=True)
        
        self.train_precision.reset()
        self.train_recall.reset()
        self.train_f1.reset()
        
        if self.last_idx % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect() 
    
    def  _test_val_common_step_( self , batch , batch_idx ): 
        with torch.no_grad() : 
            output = self( batch )
        
            loss_fn = nn.BCEWithLogitsLoss(
                reduction='sum' , pos_weight=batch.edge_purity/( 1-batch.edge_purity)
            )
            loss = loss_fn(output,batch.edge_attr)
            self.log_sum_loss.update( loss )
            self.log('loss' , loss , on_step = True , on_epoch = False , prog_bar=True )     
            
            if torch.cuda.is_available(): 
                self.log(
                    'Memory Allocated' , torch.cuda.memory_allocated()/(1024**3), 
                    prog_bar=True , on_step = True , on_epoch=True, 
                    reduce_fx='max' , sync_dist=True
                )       
            
            preds = torch.nn.Sigmoid()(output)
            
            self.val_precision.update(preds, batch.edge_attr)
            self.val_recall.update(preds, batch.edge_attr)
            self.val_f1.update(preds, batch.edge_attr)
            
            self.last_idx = batch_idx
            
            return loss 

    def validation_step(self,batch,batch_idx):
        return self._test_val_common_step_(batch,batch_idx)

    def test_step(self,batch,batch_idx): 
        return self._test_val_common_step_(batch,batch_idx)
        
    def on_validation_epoch_end(self):
        self.log("val_precision", self.val_precision.compute(), prog_bar=True ,sync_dist=True)
        self.log("val_recall", self.val_recall.compute(), prog_bar=True ,sync_dist=True)
        self.log("val_f1", self.val_f1.compute(), prog_bar=True ,sync_dist=True)

        self.val_precision.reset()
        self.val_recall.reset()
        self.val_f1.reset()
        self.log_sum_loss.reset()
        
        if self.last_idx % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        
    def configure_optimizers(self): 
        return torch.optim.SGD(self.model.parameters(),lr = self.hparams['lr'])

In [118]:
from pytorch_lightning import Trainer

In [119]:
model = FilteringModelPl(hparams)
ds = FilteringDataset(hparams)

from pytorch_lightning.callbacks import DeviceStatsMonitor
device_stats = DeviceStatsMonitor()

trainer = Trainer(
    accelerator = "cpu", 
    devices = "auto",
    enable_checkpointing=False, 
    fast_dev_run = 1, 
    callbacks=[device_stats]
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


In [120]:
if __name__ == "__main__":
    trainer.fit(model , ds )


  | Name            | Type            | Params | Mode 
------------------------------------------------------------
0 | train_precision | BinaryPrecision | 0      | train
1 | train_recall    | BinaryRecall    | 0      | train
2 | train_f1        | BinaryF1Score   | 0      | train
3 | val_precision   | BinaryPrecision | 0      | train
4 | val_recall      | BinaryRecall    | 0      | train
5 | val_f1          | BinaryF1Score   | 0      | train
6 | log_sum_loss    | LogSumLoss      | 0      | train
7 | model           | FilteringModel  | 951    | train
------------------------------------------------------------
951       Trainable params
0         Non-trainable params
951       Total params
0.004     Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have man

Training: |          | 0/? [00:00<?, ?it/s]

DataBatch(x=[116980, 15], edge_index=[2, 11698000], edge_attr=[11698000, 1], y=[116980], edge_purity=[1], batch=[116980], ptr=[2])


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.
