In [1]:
import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.linalg import norm 
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch import Tensor

# torch.set_default_dtype(torch.float64)

from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
    torch_sparse,
)

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
import torch_geometric.transforms as T
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, is_undirected , to_undirected, contains_self_loops , to_networkx , softmax 
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool, knn_graph , HypergraphConv
from torch_geometric.datasets import QM9
from torch_scatter import scatter_add
# from torch_scatter import scatter
# from torch_cluster import knn

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# import uproot
import vector
vector.register_awkward()
import awkward as ak

from IPython.display import HTML

print(f"PyTorch version {torch.__version__}")
print(f"PyG version {torch_geometric.__version__}")

PyTorch version 2.6.0
PyG version 2.6.1


In [2]:
from TrackML.Embedding.dataset import PointCloudData
from TrackML.Models.utils import buildMLP 

In [3]:
seed = 5
# Seed every RNG the notebook touches so Restart & Run All reproduces the outputs.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
def generate_hyperedges(node_features: Tensor, radius_threshold: float, batch_size: int = 32) -> Tuple[Tensor, int]:
    """
    Build radius-based hyperedges in incidence (edge_index) format for a hypergraph conv layer.

    One hyperedge is created per node: hyperedge ``i`` contains every node whose Euclidean
    distance to node ``i`` is ``<= radius_threshold`` (normally including node ``i`` itself).

    Args:
        node_features (Tensor): Node features of shape [num_nodes, num_features].
        radius_threshold (float): Inclusive distance cutoff for hyperedge membership.
        batch_size (int): Number of query rows per ``torch.cdist`` call; bounds the size of
            each distance block to ``batch_size x num_nodes``.

    Returns:
        Tuple[Tensor, int]: ``edge_index`` of shape [2, E] and dtype long, where row 0 holds
        node indices and row 1 hyperedge indices (ordered by hyperedge, then node), on the
        same device as ``node_features``; and the number of hyperedges (one per node).
        Empty input yields a [2, 0] tensor and 0.
    """
    num_nodes = node_features.size(0)
    # Distances are only compared against a threshold, so no autograd graph is needed.
    points = node_features.detach()

    chunks = []
    for start in range(0, num_nodes, batch_size):
        stop = min(start + batch_size, num_nodes)
        distances = torch.cdist(points[start:stop], points)  # [stop - start, num_nodes]
        # nonzero enumerates row-major: ascending hyperedge, then ascending node index,
        # which is the same ordering the per-node loop produced.
        local_edge, node_idx = (distances <= radius_threshold).nonzero(as_tuple=True)
        chunks.append(torch.stack([node_idx, local_edge + start]))

    if not chunks:
        # Keep the [2, 0] shape so callers can still unpack `row, col = edge_index`.
        return torch.empty((2, 0), dtype=torch.long, device=node_features.device), 0

    return torch.cat(chunks, dim=1), num_nodes

In [5]:

# Toy point cloud: 100 uniformly random nodes with 3 features each.
num_nodes, num_features = 100, 3
node_features = torch.rand(num_nodes, num_features)
radius_threshold = 0.1

# Example usage
edge_index, num_hg = generate_hyperedges(node_features, radius_threshold)
print(f"Number of hyperedges: {num_hg}")
print(f"Edge index shape: {edge_index.shape}")

Number of hyperedges: 100
Edge index shape: torch.Size([2, 142])


In [6]:
print(edge_index[:, :10])  # first 10 (node, hyperedge) incidence pairs

tensor([[ 0,  1, 56, 76,  2,  3,  4,  5,  6,  7],
        [ 0,  1,  1,  1,  2,  3,  4,  5,  6,  7]])


In [7]:
# Densify the (node, hyperedge) incidence pairs; this is square only because
# generate_hyperedges creates exactly one hyperedge per node.
dense_matrix = to_dense_adj(edge_index)
dense_matrix.shape, dense_matrix.sum()

(torch.Size([1, 100, 100]), tensor(142.))

In [8]:
# Custom hypergraph message-passing layer that also returns its attention scores:

class HypergraphConvWithAttention(nn.Module):
    """
    Hypergraph convolution that pools hyperedge features into nodes using learned
    attention, and returns the attention weights alongside the node output.

    Args:
        in_node_dim: width of the node features ``x``.
        in_edge_dim: width of the hyperedge features ``e``.
        out_dim: hidden/output width shared by all projections.
        use_bias: whether the linear projections carry a bias term.
    """

    def __init__(self, in_node_dim, in_edge_dim, out_dim, use_bias=True):
        super().__init__()
        # Construction order is kept fixed so seeded parameter init is reproducible.
        self.node_proj = Linear(in_node_dim, out_dim, bias=use_bias)
        self.edge_proj = Linear(in_edge_dim, out_dim, bias=use_bias)
        # One scalar attention logit per (node, hyperedge) incidence.
        self.att_proj = Linear(out_dim, 1, bias=use_bias)
        self.out_proj = Linear(out_dim, out_dim, bias=use_bias)

    def forward(self, x, e, hyperedge_index):
        '''
        x : Node features [N, in_node_dim]
        e : Hyperedge features [M, in_edge_dim]
        hyperedge_index : [2, E] incidence pairs (row 0 node ids, row 1 hyperedge ids)
        output : Node features [N, out_dim] and attention weights [E]
        '''
        node_idx, edge_idx = hyperedge_index
        num_nodes = x.size(0)

        # Project both endpoints of every incidence into the shared hidden space: [E, H]
        h_node = self.node_proj(x[node_idx])
        h_edge = self.edge_proj(e[edge_idx])

        # Additive (tanh) attention logits, softmax-normalised over each node's hyperedges.
        logits = self.att_proj(torch.tanh(h_node + h_edge)).squeeze(-1)  # [E]
        att_weights = softmax(logits, index=node_idx)                     # [E]

        # Attention-weighted sum of hyperedge messages into their nodes: [N, H]
        weighted = h_edge * att_weights.unsqueeze(-1)
        aggregated = scatter_add(weighted, node_idx, dim=0, dim_size=num_nodes)

        return self.out_proj(aggregated), att_weights

In [9]:

class HypergraphConvModelWithAttentionOut(nn.Module):
    """
    Stack of HypergraphConvWithAttention layers (each followed by ReLU) plus a final
    HypergraphConvWithAttention layer, held in an ``nn.ModuleList``.

    Args:
        in_node_dim (int): Input node feature dimension.
        in_edge_dim (int): Input hyperedge feature dimension (shared by every layer,
            since the hyperedge features ``e`` are not updated between layers).
        hidden_dims (list[int]): Output width of each hidden HGConv layer.
        out_dim (int): Output width of the final HGConv layer.

    forward returns ``(node_output, last_attention_weights)`` from the final layer.
    """
    def __init__(
        self,
        in_node_dim: int, in_edge_dim: int,
        hidden_dims: list[int], out_dim: int
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.activations = nn.ModuleList()

        widths = [in_node_dim] + hidden_dims
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            self.layers.append(HypergraphConvWithAttention(d_in, in_edge_dim, d_out))
            self.activations.append(nn.ReLU())

        self.final_proj = HypergraphConvWithAttention(widths[-1], in_edge_dim, out_dim)

    def forward(self, x, e, hyperedge_index):
        h = x
        for layer, activation in zip(self.layers, self.activations):
            # Intermediate attention weights are discarded; only the last layer's are returned.
            h = activation(layer(h, e, hyperedge_index)[0])

        return self.final_proj(h, e, hyperedge_index)

In [10]:
# def buildHGConv( insize:int, outsize:int, features:list , dropout : bool )->Tensor: 
#     layers = [] 
#     layers.append((HypergraphConv( insize , features[0] , dropout=dropout) , 'x, hyperedge_index -> x') )
#     layers.append( nn.ReLU(inplace=True) )
#     for i in range( 1 , len( features ) ): 
#         layers.append(( HypergraphConv( features[i-1] , features[i]  , dropout=dropout) , 'x, hyperedge_index -> x') )
#         layers.append( nn.ReLU(inplace=True) )
#     layers.append((HypergraphConv(features[-1],outsize,dropout=dropout) , 'x, hyperedge_index -> x') )
#     return torch_geometric.nn.Sequential('x, hyperedge_index' , layers)

In [11]:
class HGNN(Module):
    """
    Hypergraph network that scores how strongly each node belongs to each hyperedge.

    Pipeline: embed nodes with an MLP -> build radius hyperedges in embedding space ->
    mean-pool raw features per hyperedge and embed them -> run the attention HGConv stack
    -> scatter the final layer's attention weights into a dense [N, M] matrix.

    Args:
        node_insize: raw node feature width.
        node_outsize / node_features: output width and hidden widths passed to buildMLP
            for the node embedding (buildMLP semantics assumed: in, out, hidden list).
        hpr_edge_outsize / hpr_edge_features: same, for the hyperedge embedding MLP.
        hg_outsize / hg_features: output and hidden widths of the HGConv stack.
        radius_threshold: radius used to form hyperedges in node-embedding space
            (default 0.1, the previously hard-coded value).
    """
    def __init__(
        self,
        node_insize: int,
        node_outsize: int,
        node_features: list,
        hpr_edge_outsize: int,
        hpr_edge_features: list,
        hg_outsize: int,
        hg_features: list,
        radius_threshold: float = 0.1,
    ):
        super(HGNN, self).__init__()
        self.radius_threshold = radius_threshold

        # MLP producing the node embedding in which hyperedges are built
        self.mlp_node = buildMLP(node_insize, node_outsize, node_features)

        # MLP embedding the pooled hyperedge features
        self.mlp_hyperedge = buildMLP(node_insize, hpr_edge_outsize, hpr_edge_features)

        # Attention hypergraph convolution over [raw || embedded] node / hyperedge features
        self.hgconv = HypergraphConvModelWithAttentionOut(
            in_node_dim=node_outsize + node_insize,
            in_edge_dim=hpr_edge_outsize + node_insize,
            hidden_dims=hg_features,
            out_dim=hg_outsize
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: raw node features [N, node_insize].

        Returns:
            Dense [N, M] tensor on x's device whose (i, j) entry is the final-layer
            attention weight of node i for hyperedge j (zero where node i is not in j).
            The attention layer in this notebook softmax-normalises over each node's
            incidences, so non-empty rows sum to 1 up to rounding.
        """
        node_embedding = self.mlp_node(x)

        # Build hyperedges from the node embeddings; make sure the index lives on x's
        # device (the helper may build it on CPU).
        edge_index, _ = generate_hyperedges(node_embedding, radius_threshold=self.radius_threshold)
        edge_index = edge_index.to(x.device)
        row, col = edge_index

        # Hyperedge features: mean of the raw features of each hyperedge's member nodes
        hyperedge_features = global_mean_pool(x[row], col)
        hyperedge_features_embdd = self.mlp_hyperedge(hyperedge_features)

        N = x.shape[0]
        M = hyperedge_features.shape[0]

        # Concatenate raw and embedded features for both hyperedges and nodes
        hyperedge_features = torch.cat([hyperedge_features, hyperedge_features_embdd], dim=-1)
        x = torch.cat([x, node_embedding], dim=-1)

        _, att_weights = self.hgconv(
            x=x, e=hyperedge_features, hyperedge_index=edge_index
        )

        # Allocate on the same device/dtype as the attention weights (torch.zeros would
        # always be a CPU float32 tensor and break index assignment on GPU).
        node_to_edge_scores = att_weights.new_zeros(N, M)
        node_to_edge_scores[row, col] = att_weights

        return node_to_edge_scores

In [12]:
x = torch.randn(100, 3)  # Example input tensor
y = torch.randn(100, 5)
torch.cat((x, y), dim=-1).shape  # feature-wise concatenation -> [100, 8]

torch.Size([100, 8])

In [13]:
example_config = dict(
    node_insize=3,
    node_outsize=5,
    node_features=[10, 20],
    hpr_edge_outsize=5,
    hpr_edge_features=[10, 20],
    hg_outsize=5,
    hg_features=[10, 20],
)
model = HGNN(**example_config)

# Dense [num_nodes, num_hyperedges] attention scores for the toy point cloud
model_example_out = model(node_features)
model_example_out.shape

torch.Size([100, 100])

In [14]:
(model_example_out == 0).sum()  # number of (node, hyperedge) pairs with no incidence

tensor(492)

In [15]:
# class EventLossFunction(nn.Module):
#     def __init__(self):
#         super(EventLossFunction, self).__init__()
#         self.bce_loss = nn.BCELoss()

#     def forward(self, softmax_output: Tensor, labels: Tensor) -> Tensor:
#         """
#         Compute the binary cross-entropy loss for the given softmax output and labels.

#         Args:
#             softmax_output (Tensor): Softmax output of shape (num_nodes, num_hyperedges).
#             labels (Tensor): Labels of shape (num_nodes).

#         Returns:
#             Tensor: Computed loss.
#         """
#         num_nodes, _  = softmax_output.shape

#         # Compute pairwise probabilities
#         pairwise_probs = torch.matmul(softmax_output, softmax_output.T)  # Shape: (num_nodes, num_nodes)

#         # Compute pairwise label agreement
#         pairwise_labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # Shape: (num_nodes, num_nodes)

#         # Mask diagonal (self-loops) to avoid self-comparison
#         mask = torch.eye(num_nodes, device=softmax_output.device)
#         pairwise_probs = pairwise_probs * (1 - mask)
#         pairwise_labels = pairwise_labels * (1 - mask)

#         # Flatten and compute BCE loss
#         loss = self.bce_loss(pairwise_probs.flatten(), pairwise_labels.flatten())

#         return loss

In [16]:
class EventLossFunction(nn.Module):
    """
    Pairwise "same particle" BCE loss on soft node-to-hyperedge assignments.

    For nodes i != j the predicted probability that they share a hyperedge is the dot
    product of their assignment rows, p_ij = <s_i, s_j>; the target is 1 when
    labels[i] == labels[j], else 0. The result is the mean BCE over all N*(N-1) ordered
    off-diagonal pairs. Rows are processed in batches so the full N x N matrix is never
    materialised at once.
    """

    def __init__(self, batch_size: int = 1024):
        """
        Args:
            batch_size (int): Number of rows of the pairwise matrix computed per step.
        """
        super(EventLossFunction, self).__init__()
        self.bce_loss = nn.BCELoss(reduction='sum')  # summed per batch, normalised at the end
        self.batch_size = batch_size

    def forward(self, softmax_output: Tensor, labels: Tensor) -> Tensor:
        """
        Args:
            softmax_output (Tensor): Non-negative assignments of shape (num_nodes, num_hyperedges),
                expected to have rows summing to 1.
            labels (Tensor): Particle labels of shape (num_nodes,).

        Returns:
            Tensor: Scalar mean BCE over off-diagonal pairs. With fewer than two nodes there are
            no pairs and a zero (still attached to the autograd graph) is returned.
        """
        num_nodes = softmax_output.shape[0]
        if num_nodes < 2:
            # No pairs to compare; avoid 0/0 while keeping backward() usable.
            return softmax_output.sum() * 0.0

        device = softmax_output.device
        labels = labels.to(device)
        total_loss = softmax_output.new_zeros(())

        for start in range(0, num_nodes, self.batch_size):
            stop = min(start + self.batch_size, num_nodes)
            rows = stop - start

            # Pairwise co-assignment probabilities for this slab: (rows, num_nodes).
            # Mathematically <= 1 for unit-sum rows; clamp round-off so BCELoss does not raise.
            pairwise_probs = (softmax_output[start:stop] @ softmax_output.T).clamp(0.0, 1.0)
            pairwise_labels = (labels[start:stop].unsqueeze(1) == labels.unsqueeze(0)).to(pairwise_probs.dtype)

            # Exclude self-pairs: local row r corresponds to global node start + r, so the
            # diagonal sits at column start + r (not column r) for batches after the first.
            keep = torch.ones(rows, num_nodes, dtype=torch.bool, device=device)
            local = torch.arange(rows, device=device)
            keep[local, local + start] = False

            total_loss = total_loss + self.bce_loss(pairwise_probs[keep], pairwise_labels[keep])

        # Normalise by the number of ordered off-diagonal pairs
        return total_loss / (num_nodes * (num_nodes - 1))


In [17]:
# 100 fake particle labels drawn from 10 classes
random_label = torch.randint(0, 10, (100,))
random_label

tensor([2, 1, 0, 5, 8, 4, 9, 9, 6, 6, 1, 3, 3, 5, 0, 7, 9, 0, 8, 7, 7, 9, 1, 0,
        3, 2, 7, 6, 8, 6, 3, 2, 0, 7, 7, 3, 7, 2, 7, 7, 5, 3, 9, 1, 6, 0, 5, 1,
        9, 0, 7, 4, 1, 2, 8, 4, 9, 0, 9, 3, 8, 8, 7, 8, 9, 1, 0, 9, 3, 7, 1, 8,
        3, 4, 9, 3, 6, 4, 8, 5, 9, 7, 2, 5, 1, 7, 3, 7, 8, 7, 2, 8, 5, 5, 3, 1,
        3, 0, 3, 4])

In [18]:
loss = EventLossFunction()
loss(softmax_output=model_example_out, labels=random_label)

tensor(0.4746, grad_fn=<DivBackward0>)

In [19]:
def create_disjoint_hypergraph(softmax_output: Tensor) -> dict:
    """
    Partition nodes into disjoint groups by their highest-scoring hyperedge.

    Args:
        softmax_output (Tensor): Scores of shape (num_nodes, num_hyperedges).

    Returns:
        dict: Maps each used hyperedge index (int, ascending key order) to a CPU int64
        tensor of the node indices assigned to it, in ascending node order.
    """
    assignment = softmax_output.argmax(dim=-1).cpu()
    if assignment.numel() == 0:
        return {}

    # A stable sort keeps node indices ascending inside each hyperedge group, and
    # unique(sorted) returns group sizes in the same ascending hyperedge order.
    order = torch.argsort(assignment, stable=True)
    hyperedge_ids, counts = torch.unique(assignment, return_counts=True)
    groups = torch.split(order, counts.tolist())

    return {hyperedge_id: members for hyperedge_id, members in zip(hyperedge_ids.tolist(), groups)}

In [20]:
disjoint_hypergraph = create_disjoint_hypergraph(softmax_output=model_example_out)
print(disjoint_hypergraph)

{2: tensor([12, 28, 33, 46]), 4: tensor([87]), 5: tensor([47]), 6: tensor([45, 57, 64]), 14: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 34, 35, 36, 38, 39, 40, 41,
        42, 43, 44, 48, 49, 50, 52, 53, 54, 55, 56, 58, 59, 60, 61, 62, 63, 65,
        66, 67, 68, 69, 71, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85,
        86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 99]), 19: tensor([ 8, 13, 37]), 23: tensor([51, 92]), 30: tensor([72]), 88: tensor([70])}


In [21]:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torchmetrics import Metric
import yaml 

In [22]:
class ParticlePurity(Metric):
    """
    Event-averaged particle purity of a disjoint track assignment.

    For each predicted track, the most common particle label among its hits is found.
    The track counts as a match when that label covers at least half of the track's hits
    AND the track holds at least half of that particle's hits. A matched track contributes
    ``intersection / hits_of_that_particle``. The per-event sum is divided by the number of
    distinct labels in the event, and ``compute`` averages over events.
    """
    def __init__(self):
        super().__init__()
        self.add_state("total_purity", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_events", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, disjoint_hypergraph: dict, labels: Tensor):
        """
        Update the metric state with a new disjoint hypergraph and corresponding labels.

        Args:
            disjoint_hypergraph (dict): A dictionary where keys are hyperedge indices and values are tensors of node indices.
            labels (Tensor): A tensor of shape (num_nodes,) containing the labels for each node.

        Events without any labels are skipped: their purity is undefined and would
        otherwise raise ZeroDivisionError.
        """
        num_particles = len(torch.unique(labels))
        if num_particles == 0:
            return

        event_purity = 0.0
        for nodes in disjoint_hypergraph.values():
            node_labels = labels[nodes]
            most_common_label = torch.mode(node_labels).values
            intersection = (node_labels == most_common_label).sum().item()

            num_particles_with_most_common_label = (labels == most_common_label).sum().item()

            # Double-majority match between the track and its dominant particle
            if (intersection >= 0.5 * len(node_labels)
                    and intersection >= 0.5 * num_particles_with_most_common_label):
                event_purity += intersection / num_particles_with_most_common_label

        self.total_purity += event_purity / num_particles
        self.num_events += 1

    def compute(self):
        """
        Compute the average particle purity over all events.

        Returns:
            Tensor: The average particle purity.
        """
        return self.total_purity / self.num_events

In [23]:
class TrackPurity(Metric):
    """
    Event-averaged track purity of a disjoint track assignment.

    For each predicted track, the most common particle label among its hits is found.
    The track counts as a match when that label covers at least half of the track's hits
    AND the track holds at least half of that particle's hits. A matched track contributes
    ``intersection / track_size``. The per-event sum is divided by the number of predicted
    tracks, and ``compute`` averages over events.
    """
    def __init__(self):
        super().__init__()
        self.add_state("total_purity", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_events", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, disjoint_hypergraph: dict, labels: Tensor):
        """
        Update the metric state with a new disjoint hypergraph and corresponding labels.

        Args:
            disjoint_hypergraph (dict): A dictionary where keys are hyperedge indices and values are tensors of node indices.
            labels (Tensor): A tensor of shape (num_nodes,) containing the labels for each node.

        Events with no predicted tracks are skipped: their purity is undefined and would
        otherwise raise ZeroDivisionError.
        """
        num_tracks = len(disjoint_hypergraph)
        if num_tracks == 0:
            return

        event_purity = 0.0
        for nodes in disjoint_hypergraph.values():
            node_labels = labels[nodes]
            most_common_label = torch.mode(node_labels).values
            intersection = (node_labels == most_common_label).sum().item()

            num_particles_with_most_common_label = (labels == most_common_label).sum().item()

            # Double-majority match between the track and its dominant particle
            if (intersection >= 0.5 * len(node_labels)
                    and intersection >= 0.5 * num_particles_with_most_common_label):
                event_purity += intersection / len(node_labels)

        self.total_purity += event_purity / num_tracks
        self.num_events += 1

    def compute(self):
        """
        Compute the average track purity over all events.

        Returns:
            Tensor: The average track purity.
        """
        return self.total_purity / self.num_events

In [24]:
from TrackML.Embedding.utils import PointCloudData,train_test_split

### Define pytorch lightning dataset : 
class EmbeddingDataset(pl.LightningDataModule):
    """
    LightningDataModule that loads PointCloudData and splits it into train/val/test.

    hparams keys read in ``setup``: dataset_path, detector_path, min_hits, max_r,
    drop_fake, valid_size, test_size, num_works.
    """

    def __init__(self, hparams) -> None:
        super().__init__()
        self.save_hyperparameters(hparams)

    # Former prepare_data, kept for reference (currently unused):
    # def prepare_data(self)->None:
        # self.detector = Preprocessing.load_detector_data(self.hparams['detector_path'])
        # get the list of event ids from the dataset folder :
        # self.eventids = [ code[:-9] for code in os.listdir(self.hparams['dataset_path']) if code.endswith('-hits.csv') ]
        # self.dataset = PointCloudData(dataset_path=self.hparams['dataset_path'] , detector_path=self.hparams['detector_path'] , min_nhits=self.hparams['min_hits'] )

    def setup(self, stage=None) -> None:
        hp = self.hparams
        self.dataset = PointCloudData(
            dataset_path=hp['dataset_path'],
            detector_path=hp['detector_path'],
            min_nhits=hp['min_hits'],
            max_r=hp['max_r'],
            drop_fake=hp['drop_fake'],
        )
        splits = train_test_split(
            dataset=self.dataset,
            valid_size=hp['valid_size'],
            test_size=hp['test_size'],
            num_works=hp['num_works'],
        )
        self.train_ds, self.val_ds, self.test_ds = splits

    # NOTE(review): the split objects are returned directly, so train_test_split presumably
    # already yields DataLoaders (it takes num_works) -- TODO confirm.
    def train_dataloader(self):
        return self.train_ds

    def val_dataloader(self):
        return self.val_ds

    def test_dataloader(self):
        return self.test_ds

In [25]:
# Model / data hyperparameters for the Lightning modules below.
with open('05-Hypergraph-Model.yml', 'r') as config_file:
    hparams = yaml.safe_load(config_file)

In [26]:
class HGNN_TrackML(LightningModule):
    """
    Lightning wrapper that trains an HGNN with EventLossFunction.

    Each batch is unpacked as ``(event_data, labels)`` and its leading dimension is
    squeezed in place (assumed to be a batch of one event -- TODO confirm against the
    dataloader). Validation/test steps also log the loss, group nodes into disjoint tracks
    (arg-max hyperedge per node) and accumulate particle/track purity.

    Args:
        hparams: mapping passed to ``save_hyperparameters``; keys read here are
            node_insize, node_outsize, node_features, hpr_edge_outsize,
            hpr_edge_features, hg_outsize, hg_features and lr.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        # Metrics (torchmetrics states carry their own DDP reduction)
        self.particle_purity = ParticlePurity()
        self.track_purity = TrackPurity()

        # Losses
        self.loss = EventLossFunction()

        self.model = HGNN(
            node_insize=self.hparams['node_insize'],
            node_outsize=self.hparams['node_outsize'],
            node_features=self.hparams['node_features'],
            hpr_edge_outsize=self.hparams['hpr_edge_outsize'],
            hpr_edge_features=self.hparams['hpr_edge_features'],
            hg_outsize=self.hparams['hg_outsize'],
            hg_features=self.hparams['hg_features']
        )

        # batch_idx of the most recent train/val/test step; read by the *_batch_end hooks.
        self.last_idx = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        event_data, labels = batch
        # Drop the leading batch dimension in place
        event_data.squeeze_(dim=0)
        labels.squeeze_(dim=0)

        output = self(event_data)

        loss = self.loss(output, labels)
        self.log(
            'Loss', loss,
            prog_bar=True, on_step=True, on_epoch=False
        )

        if torch.cuda.is_available():
            self.log(
                'Memory Allocated', torch.cuda.memory_allocated() / (1024 ** 3),
                prog_bar=True, on_step=True, on_epoch=True,
                reduce_fx='max'
            )
        self.last_idx = batch_idx
        return loss

    def _maybe_free_cuda_memory(self):
        """Release cached CUDA memory every 10th batch; no-op without CUDA."""
        if self.last_idx % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def _log_and_reset_purity(self):
        """Log the accumulated purity metrics for the epoch, then clear their state."""
        self.log_dict({
            "Particle Purity": self.particle_purity.compute(),
            "Track Purity": self.track_purity.compute(),
        }, prog_bar=True)
        self.particle_purity.reset()
        self.track_purity.reset()

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self._maybe_free_cuda_memory()

    # common logic for test and validation
    def _test_val_common_step_(self, batch, batch_idx, stage: str = "Val"):
        """Shared val/test step; ``stage`` prefixes the logged loss name."""
        with torch.no_grad():
            event_data, labels = batch
            event_data.squeeze_(dim=0)
            labels.squeeze_(dim=0)

            output = self(event_data)

            loss = self.loss(output, labels)
            # Previously computed but never logged. batch_size=1 stops Lightning from inferring
            # the batch size from the (squeezed) node dimension when averaging per epoch.
            self.log(
                f'{stage} Loss', loss,
                prog_bar=True, on_step=False, on_epoch=True, batch_size=1
            )

            disjoint_hypergraph = create_disjoint_hypergraph(output)
            self.particle_purity.update(disjoint_hypergraph, labels)
            self.track_purity.update(disjoint_hypergraph, labels)

            self.last_idx = batch_idx

            return loss

    def validation_step(self, batch, batch_idx):
        return self._test_val_common_step_(batch, batch_idx, stage="Val")

    def test_step(self, batch, batch_idx):
        return self._test_val_common_step_(batch, batch_idx, stage="Test")

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        self._maybe_free_cuda_memory()

    def on_validation_epoch_end(self):
        self._log_and_reset_purity()

    def on_test_epoch_end(self):
        self._log_and_reset_purity()

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        self._maybe_free_cuda_memory()

    # set the optimizer  :
    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=self.hparams['lr'])

In [27]:
from pytorch_lightning import Trainer

In [None]:
model = HGNN_TrackML(hparams)
ds = EmbeddingDataset(hparams)

# Smoke-test configuration: CPU only, 2 batches per loop, no checkpoints.
# For a real run, drop fast_dev_run and set e.g. max_epochs=1.
trainer = Trainer(
    fast_dev_run=2,
    enable_checkpointing=False,
    accelerator="cpu",
    devices="auto",
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Running in `fast_dev_run` mode: will run the requested loop using 2 batch(es). Logging and checkpointing is suppressed.


: 

In [None]:
if __name__ == "__main__":
    trainer.fit(model=model, datamodule=ds)


  | Name            | Type              | Params | Mode 
--------------------------------------------------------------
0 | particle_purity | ParticlePurity    | 0      | train
1 | track_purity    | TrackPurity       | 0      | train
2 | loss            | EventLossFunction | 0      | train
3 | model           | HGNN              | 6.0 K  | train
--------------------------------------------------------------
6.0 K     Trainable params
0         Non-trainable params
6.0 K     Total params
0.024     Total estimated model params size (MB)
33        Modules in train mode
0         Modules in eval mode
/Users/ashmitbathla/Documents/UGP-Local/TrackML/TrackMLVenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/ashmitbathla/Documents/UGP-Local/Track

Training: |          | 0/? [00:00<?, ?it/s]