# TrackML Particle Tracking 


**Dataset used: https://www.kaggle.com/c/trackml-particle-identification.**


The Large Hadron Collider (LHC) is the world's largest and most powerful particle accelerator, located at CERN (European Organization for Nuclear Research) on the Franco-Swiss border. In this project, we analyze data produced from the collisions by looking at the information from silicon detectors. The objective of this work is to identify the original tracks of the particles by joining the 3D coordinates of the hits.

To this end, we train a graph neural network (GNN) in which the hits are the nodes and the candidate connections between hits are the edges. The model must recover the real track of each particle by selecting the correct edges between hits.

# Environment Setup

The trackml library makes the dataset easier to load:

In [None]:
# !pip install git+https://github.com/LAL/trackml-library.git

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import plotly.express as px

# Use GPU if available (Essential for GNN training)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on: {device}")

# Data loading

We'll start with a small subset of one event (about 5000 hits), so that everything can run on the CPU. We focus on the central (barrel) volumes of the detector to reduce the computation time, and keep only the hits close to the center along the beam axis ($|z| < 50$ mm).

In [None]:
from trackml.dataset import load_event

hits, cells, particles, truth = load_event('train_sample/train_100_events/event000001000')

# Join hits with truth data to identify which hit belongs to which particle
hits = hits.merge(truth[['hit_id', 'particle_id']], on='hit_id')

# Keep only the barrel volumes of interest:
# 8 (pixel barrel), 13 (short-strip barrel) and 17 (long-strip barrel).
# Then take a manageable sample by keeping hits near the center along the beam axis (|z| < 50 mm).
hits = hits[
    (hits.volume_id.isin([8, 13, 17])) &
    (hits.z > -50) & (hits.z < 50)
].copy()


print(hits.shape)

We have considerably reduced the number of instances.

# Defining the GNN Architecture (Interaction Network)

##  Cylindrical coordinates

The $(x, y, z)$ coordinates give the spatial position of a "hit" (impact) within the detector, measured with micrometric precision, in a Cartesian coordinate system:
- Origin $(0, 0, 0)$: The geometric center of the detector, where the proton beams collide.
- Z-axis: Extends along the beam pipe (where protons travel before colliding). Particles traveling "forward" or "backward" have very large $z$ values (positive or negative).
- XY transverse plane: The plane perpendicular to the beam. This is where the detector's magnetic field curves particle trajectories so we can measure their momentum.
- Units: All $(x, y, z)$ coordinates are expressed in millimeters (mm).

For tracking algorithms and our GNN, we often convert these Cartesian coordinates into cylindrical ones ($r, \phi, z$), as the detector has a cylindrical symmetry:
- Radius ($r$): Calculated as $r = \sqrt{x^2 + y^2}$. It indicates how far from the center the particle has reached.
- Azimuthal Angle ($\phi$): Calculated as $\phi = \arctan2(y, x)$. It indicates the particle's direction in the detector's circle.
- Z: Remains the same.
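
As a quick check of the formulas above, here is a minimal NumPy sketch of the conversion. The helper name `to_cylindrical` is ours for illustration and is not part of the trackml library.

```python
import numpy as np

def to_cylindrical(x, y, z):
    """Convert Cartesian (x, y, z) in mm to cylindrical (r, phi, z)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    z = np.asarray(z, dtype=float)
    r = np.hypot(x, y)        # r = sqrt(x^2 + y^2)
    phi = np.arctan2(y, x)    # azimuthal angle in radians, between -pi and pi
    return r, phi, z

# Two illustrative hits (made-up coordinates, in mm)
r, phi, z = to_cylindrical([3.0, 0.0], [4.0, -2.0], [10.0, -5.0])
print(r, phi, z)
```

Note that $\phi$ is periodic: values just below $\pi$ and just above $-\pi$ are physically close, which matters when angular differences are computed later.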

## Graph Building


Our goal is to build an Interaction Network. This model will perform Edge Classification: it looks at a connection between two hits and predicts a value between 0 and 1.

- Target 1: True segment (both hits belong to the same particle_id).

- Target 0: Fake segment (noise or different particles).

However, to build a graph that is actually usable for a GNN, we cannot simply connect every hit to every other hit. The number of edges would grow quadratically with the number of hits, which is computationally infeasible. To reduce the number of edges, we add some constraints based on the physics of the problem:

- **Layer-to-Layer Constraint:** It's the most obvious one. Particles are created at the center and travel outward. A hit in an inner layer can only be logically followed by a hit in an outer layer. By enforcing this, we eliminate millions of connections that make no physical sense.

- **Angular Difference ($\Delta \phi$):** The LHC detectors use a powerful magnetic field to curve the paths of charged particles. However, high-energy particles (the ones we care about) have a high momentum and do not curve sharply. Their angular direction ($\phi$) changes very little between adjacent layers. 

- **Z-Distance ($\Delta z$):** Since the particles originate from the beam-collision point at the center, their trajectory in the $Z$ direction (along the beam pipe) is relatively straight. If two hits are far apart along the beam axis (large $\Delta z$) but close in the $XY$ plane (which may happen if they are in adjacent layers), it is highly unlikely that they belong to the same track.
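
The two geometric cuts can be sketched on their own before they are embedded in the graph builder. This is only an illustration: the helper names `wrapped_delta_phi` and `passes_cuts` are ours, and the default thresholds (0.5 rad and 20 mm) simply mirror the defaults used by `build_graph` in this notebook.

```python
import numpy as np

def wrapped_delta_phi(phi_a, phi_b):
    """Smallest absolute angular difference (radians), accounting for periodicity."""
    d = np.abs(np.asarray(phi_a, dtype=float) - np.asarray(phi_b, dtype=float))
    return np.where(d > np.pi, 2 * np.pi - d, d)

def passes_cuts(phi_a, z_a, phi_b, z_b, delta_phi_max=0.5, delta_z_max=20.0):
    """True if a candidate edge between hits a and b survives both cuts."""
    ok_phi = wrapped_delta_phi(phi_a, phi_b) < delta_phi_max
    ok_z = np.abs(np.asarray(z_b, dtype=float) - np.asarray(z_a, dtype=float)) < delta_z_max
    return ok_phi & ok_z

# Hits on either side of the +pi / -pi boundary are close in angle
print(wrapped_delta_phi(3.1, -3.1))
print(passes_cuts(3.1, 0.0, -3.1, 5.0))
```

Without the wrap-around, the pair at $\phi = 3.1$ and $\phi = -3.1$ would appear to be 6.2 rad apart and would be rejected, even though the two hits are nearly aligned.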


Let's build a function that constructs the graph (nodes, candidate edges and edge labels) of our model, taking the previous constraints into consideration:

In [None]:
import numpy as np
import torch
from torch_geometric.data import Data

def build_graph(hits, delta_phi_max=0.5, delta_z_max=20):
   
    hits['r'] = np.sqrt(hits.x**2 + hits.y**2)
    hits['phi'] = np.arctan2(hits.y, hits.x)
    
    # Create a unique layer identifier (Global Layer)
    # Combining volume_id and layer_id to distinguish layers across different volumes
    hits['global_layer'] = hits.volume_id * 100 + hits.layer_id
    
    # Get a sorted list of all unique global layers present in the data
    ordered_layers = np.sort(hits.global_layer.unique())
    
    # Normalization: bring all features to a roughly unit range
    # (x, y, z and r are divided by 1000 mm; phi is divided by pi so it lies in [-1, 1])
    x_features = torch.tensor(hits[['x', 'y', 'z', 'r', 'phi']].values, dtype=torch.float)
    x_features[:, :3] /= 1000.0
    x_features[:, 3] /= 1000.0 
    x_features[:, 4] /= np.pi  
    
    # Extract coordinate arrays for faster vectorized filtering
    phi = np.arctan2(hits.y.values, hits.x.values)
    z = hits.z.values
    layer_array = hits.global_layer.values
    particle_ids = hits.particle_id.values
    
    senders, receivers, y = [], [], []

    # Iterate layer by layer (except the last one) to find adjacent connections
    for layer_idx in range(len(ordered_layers) - 1):
        current_layer = ordered_layers[layer_idx]
        next_layer = ordered_layers[layer_idx + 1]
        
        # Get indices of hits belonging to the current and the next adjacent layer
        current_indices = np.where(layer_array == current_layer)[0]
        next_indices = np.where(layer_array == next_layer)[0]
        
        for i in current_indices:
            # Filter candidates ONLY in the immediately succeeding layer

            d_phi = np.abs(phi[next_indices] - phi[i])
            d_phi = np.where(d_phi > np.pi, 2*np.pi - d_phi, d_phi)   # Periodicity Management

            potential_neighbors = next_indices[(d_phi < delta_phi_max) & (np.abs(z[next_indices] - z[i]) < delta_z_max)]
            
            for j in potential_neighbors:
                senders.append(i)
                receivers.append(j)
                
                # Assign Truth Label: 1 if both hits belong to the same particle (and not noise)
                is_same_particle = (particle_ids[i] == particle_ids[j]) and (particle_ids[i] != 0)
                y.append(1 if is_same_particle else 0)

    # Convert edge lists to PyTorch Tensors
    edge_index = torch.tensor([senders, receivers], dtype=torch.long)
    edge_label = torch.tensor(y, dtype=torch.float)
    
    # Return the PyTorch Geometric Data object
    return Data(x=x_features, edge_index=edge_index, y=edge_label)

# Process the first event
graph_data = build_graph(hits)
print(f"Graph created with {graph_data.num_nodes} nodes and {graph_data.num_edges} edges.")

## Graph visualization

First, we visualize both true and false edges.

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_graph_3d(data, n_edges=500):

    # We count the number of false and true edges represented
    number_false_edges=0
    number_true_edges=0

    # Extract Node Positions (Rescale back to mm)
    pos = data.x.cpu().numpy() * 1000.0 
    
    # Create the Hit (Node) Scatter Plot
    node_trace = go.Scatter3d(
        x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
        mode='markers',
        marker=dict(size=2, color='blue', opacity=0.8),
        name='Hits'
    )

    # Create Edge Traces
    # We only plot a subset of edges to avoid crashing the browser
    edge_index = data.edge_index.cpu().numpy()
    edge_y = data.y.cpu().numpy()
    
    # Separate True and Fake edges for coloring
    edge_x, edge_y_coords, edge_z = [], [], []
    true_edge_x, true_edge_y, true_edge_z = [], [], []

    for i in range(min(n_edges, data.num_edges)):
        start_node = edge_index[0, i]
        end_node = edge_index[1, i]
        
        # Segment coordinates
        x_coords = [pos[start_node, 0], pos[end_node, 0], None]
        y_coords = [pos[start_node, 1], pos[end_node, 1], None]
        z_coords = [pos[start_node, 2], pos[end_node, 2], None]
        
        if edge_y[i] == 1:
            true_edge_x.extend(x_coords); true_edge_y.extend(y_coords); true_edge_z.extend(z_coords)
            number_true_edges+=1
        else:
            edge_x.extend(x_coords); edge_y_coords.extend(y_coords); edge_z.extend(z_coords)
            number_false_edges+=1

    # Fake edges (Red)
    fake_trace = go.Scatter3d(
        x=edge_x, y=edge_y_coords, z=edge_z,
        mode='lines', line=dict(color='red', width=1),
        name='Fake Edges (Proposed)', opacity=0.3
    )
    
    # True edges (Green)
    true_trace = go.Scatter3d(
        x=true_edge_x, y=true_edge_y, z=true_edge_z,
        mode='lines', line=dict(color='green', width=3),
        name='True Edges (Signal)'
    )

    # Final Layout
    fig = go.Figure(data=[node_trace, fake_trace, true_trace])
    fig.update_layout(
        title="3D Graph Representation: Internal Pixel Tracker",
        scene=dict(xaxis_title='X (mm)', yaxis_title='Y (mm)', zaxis_title='Z (mm)', aspectmode='data'),
        margin=dict(l=0, r=0, b=0, t=40)
    )
    fig.show()

    print(f"The representation contains {number_true_edges} true edges and {number_false_edges} false edges")

# Visualize all the edges of the graph (lower n_edges if the browser struggles)
plot_graph_3d(graph_data, n_edges=graph_data.num_edges)

To get a clearer visualization, we can plot only the true edges.

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_true_edges_3d(data, n_edges=500):
    """
    Plots a 3D representation of the hits and ONLY the true (signal) edges.
    """
    # Counter for true edges found in the sampled range
    number_true_edges = 0

    # Extract Node Positions (Rescale back to mm)
    # Assuming features are normalized, we scale by 1000 for mm units
    pos = data.x.cpu().numpy() * 1000.0 
    
    # Create the Hit (Node) Scatter Plot
    node_trace = go.Scatter3d(
        x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
        mode='markers',
        marker=dict(size=2, color='blue', opacity=0.8),
        name='Hits (Detector Hits)'
    )

    # Create True Edge Trace
    edge_index = data.edge_index.cpu().numpy()
    edge_y = data.y.cpu().numpy()
    
    # Lists to store coordinates for true edges
    true_edge_x, true_edge_y, true_edge_z = [], [], []

    # Iterate through the edges up to the specified limit
    for i in range(min(n_edges, data.num_edges)):
        # Only process if it is a TRUE edge (y = 1)
        if edge_y[i] == 1:
            start_node = edge_index[0, i]
            end_node = edge_index[1, i]
            
            # Append coordinates for the line segment [Start, End, None to break the line]
            true_edge_x.extend([pos[start_node, 0], pos[end_node, 0], None])
            true_edge_y.extend([pos[start_node, 1], pos[end_node, 1], None])
            true_edge_z.extend([pos[start_node, 2], pos[end_node, 2], None])
            
            number_true_edges += 1

    # Define the Green Trace for True Edges
    true_trace = go.Scatter3d(
        x=true_edge_x, y=true_edge_y, z=true_edge_z,
        mode='lines', 
        line=dict(color='green', width=3),
        name='True Edges (Signal Tracks)',
        opacity=1.0
    )

    # Final Layout Configuration
    fig = go.Figure(data=[node_trace, true_trace])
    fig.update_layout(
        title="3D Particle Tracking: Ground Truth Signal Edges",
        scene=dict(
            xaxis_title='X (mm)', 
            yaxis_title='Y (mm)', 
            zaxis_title='Z (mm)',
            aspectmode='data' # Maintains physical proportions
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )
    
    fig.show()

    print(f"Visualization complete. Represented {number_true_edges} true edges.")

plot_true_edges_3d(graph_data, n_edges=graph_data.num_edges)

## Model architecture

We'll start by deciding how the GNN will learn the relationships between nodes.

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

class TrackNet(MessagePassing):
    def __init__(self, node_in=5, edge_in=15, hidden=128):
        super(TrackNet, self).__init__(aggr='add')
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, node_in)
        )
        self.current_logits = None       # To store the edge scores

    def forward(self, x, edge_index):    # Model output 
        self.propagate(edge_index, x=x)
        return self.current_logits.squeeze(-1)

    def message(self, x_i, x_j):

        delta_pos = x_j[:, :3] - x_i[:, :3]
        delta_r = x_j[:, 3:4] - x_i[:, 3:4]
        delta_phi = x_j[:, 4:5] - x_i[:, 4:5]        
        
        edge_input = torch.cat([x_i, x_j, delta_pos, delta_r, delta_phi], dim=1)
        
        self.current_logits = self.edge_mlp(edge_input)
        return self.current_logits
    
    def update(self, aggr_out, x):
        """
        FINAL STEP OF THE ALGORITHM:
        aggr_out: The sum of scores (logits) of all edges touching each hit.
        x: The original hit features [x, y, z, r, phi].
        """
        # Squash the summed message so it doesn't "drown out" the x coordinates
        msg_scaled = torch.sigmoid(aggr_out)
        
        # Concatenate the hit features with the squashed sum of its messages (5 + 1 = 6 columns).
        # x (normalized) and msg_scaled (0-1) are now on comparable scales.
        node_input = torch.cat([x, msg_scaled], dim=1)
        
        # The node_mlp processes this union and returns the "contextualized" node
        return self.node_mlp(node_input)

And we create our model object.

In [None]:
model = TrackNet().to(device)

## Training

First, we can observe that the edge classes are strongly imbalanced.

In [None]:
num_pos = (graph_data.y == 1).sum().item()
num_neg = (graph_data.y == 0).sum().item()
print(f"Positive edges: {num_pos}, Negative edges: {num_neg}")

For training, we use Binary Cross Entropy (BCE) because this is a binary classification problem, with a larger weight on the positive class to compensate for the imbalance. We use the Adam optimizer for its fast convergence.

In [None]:
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import time

# Initialize model
optimizer = optim.Adam(model.parameters(), lr=0.001)        # Adam optimizer

# Weight of the positive class: ratio of negative to positive edges (computed above)
weight = torch.tensor([num_neg / num_pos]).to(device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weight)   # BCE

# Move data to the same device as the model
graph_data = graph_data.to(device)

print(f"Training on: {device}")
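
To make the role of `pos_weight` concrete, here is a small NumPy sketch of the quantity that a weighted BCE on logits computes, following the formula given in the PyTorch documentation for `BCEWithLogitsLoss` with mean reduction. The function name `weighted_bce_with_logits` is ours, and the numbers are made up for illustration.

```python
import numpy as np

def weighted_bce_with_logits(logits, targets, pos_weight):
    """Mean of -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Numerically stable log(sigmoid(x)) and log(1 - sigmoid(x))
    log_sig = -np.logaddexp(0.0, -logits)
    log_one_minus_sig = -np.logaddexp(0.0, logits)
    loss = -(pos_weight * targets * log_sig + (1.0 - targets) * log_one_minus_sig)
    return loss.mean()

# One positive and one negative edge, both with an undecided logit of 0
plain = weighted_bce_with_logits([0.0, 0.0], [1.0, 0.0], pos_weight=1.0)
weighted = weighted_bce_with_logits([0.0, 0.0], [1.0, 0.0], pos_weight=3.0)
print(plain, weighted)
```

With `pos_weight=3`, a mistake on the positive edge costs three times as much as the same mistake on a negative edge, which is how the rare true edges keep influencing the gradient.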

We'll use the AUC-ROC metric to evaluate the binary classification of edges. First, we'll build a function that runs one complete training epoch on our data.

In [None]:
def train_one_epoch():
    model.train()
    optimizer.zero_grad() 
    
    edge_logits = model(graph_data.x, graph_data.edge_index)
    
    #  We calculate the loss using the model's logits
    # 'graph_data.y' are the ground truth labels (1 if it is a real trajectory, 0 if not)
    loss = criterion(edge_logits, graph_data.y.float())
    
    loss.backward()

    # Gradient Clipping to prevent gradients from exploding
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    
    # Return the loss and the scores (logits) to monitor progress
    return loss.item(), edge_logits

Now, we can apply it to our data to train for 400 epochs.

In [None]:
# Training for 400 epochs
for epoch in range(1, 401):
    loss_val, scores = train_one_epoch()
    
    if epoch % 10 == 0:
        # Calculate AUC-ROC on the training edges: a value of 1.0 means perfect classification
        auc = roc_auc_score(graph_data.y.cpu().numpy(), scores.detach().cpu().numpy())
        print(f"Epoch {epoch:03d} | Loss: {loss_val:.4f} | AUC: {auc:.4f}")

We can observe that we get a considerably high AUC (near 0.95). Keep in mind that it is measured on the same edges used for training.
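
For intuition about this number, the AUC equals the probability that a randomly chosen true edge receives a higher score than a randomly chosen fake one. The sketch below computes it by brute force on a toy example; the function name `pairwise_auc` is ours, and since it compares every positive with every negative it is only suitable for small inputs, not for the full graph.

```python
import numpy as np

def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]   # every positive against every negative
    correct = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return correct / (len(pos) * len(neg))

# Toy example: 3 of the 4 positive/negative pairs are ordered correctly
toy_auc = pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(toy_auc)
```

Because the AUC only looks at the ranking, it can be high even when very few edges pass a fixed probability threshold, or when the fake edges that do pass outnumber the true ones.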

## Expected score 

We build a function to predict the test edges:

In [None]:
import torch

@torch.no_grad()   # Disables gradient calculation to save memory
def predict_test_edges(model, test_data):
    model.eval()   # Sets the model to evaluation mode (fixes BatchNorm/Dropout)
    
    # Forward pass: 
    edge_logits = model(test_data.x, test_data.edge_index)
     
    # Diagnostic 
    print(f"Max Logit: {edge_logits.max().item():.4f}")
    print(f"Min Logit: {edge_logits.min().item():.4f}")
    print(f"Mean Logits: {edge_logits.mean().item():.4f}")

    # Convert logits to probabilities
    # We use the sigmoid function to map from range (-inf, +inf) to range (0, 1)
    probs = torch.sigmoid(edge_logits)
    
    # Move to CPU and NumPy format for the submission dataframe
    return probs.cpu().numpy()

We'll start by using a _connected components_ solution to reconstruct the tracks:

In [None]:
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

def get_submission_df(test_data, edge_probs, threshold=0.5):
    # Filter edges by probability (work on NumPy arrays so the boolean mask applies directly)
    mask = edge_probs > threshold
    edge_index = test_data.edge_index.cpu().numpy()
    rows = edge_index[0, mask]
    cols = edge_index[1, mask]
    
    # Create adjacency matrix and calculate connected components
    num_nodes = test_data.x.size(0)
    adj = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(num_nodes, num_nodes))
    _, labels = connected_components(csgraph=adj, directed=False)
    
    # One row per node with its track label.
    # Note: 'hit_id' here is the node index in the graph, not the original TrackML hit_id.
    return pd.DataFrame({
        'hit_id': np.arange(num_nodes),
        'track_id': labels
    })
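
To see how sensitive this approach is to misclassified edges, consider a toy graph with two three-hit tracks. The sketch below uses a small pure-Python union-find instead of SciPy so that it is self-contained; the function name `connected_labels` and the edge lists are made up for illustration.

```python
def connected_labels(num_nodes, edges):
    """Label each node with the index of its connected component (union-find)."""
    parent = list(range(num_nodes))

    def find(a):
        # Follow parents up to the root, compressing the path on the way
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Relabel roots as consecutive integers 0, 1, 2, ...
    remap = {}
    return [remap.setdefault(find(i), len(remap)) for i in range(num_nodes)]

# Two true tracks: hits 0-1-2 and hits 3-4-5
true_edges = [(0, 1), (1, 2), (3, 4), (4, 5)]
clean = connected_labels(6, true_edges)
# A single fake edge (2, 3) that passes the threshold
merged = connected_labels(6, true_edges + [(2, 3)])
print(clean, merged)
```

One false edge is enough to collapse both tracks into a single component, so with a strongly imbalanced edge set even a small false-positive rate can produce very large merged "tracks".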

Finally, we can plot the predicted probabilities of the test edges. Note that the test graph is built from the same hits used for training, so this is not a held-out evaluation.

In [None]:
import matplotlib.pyplot as plt

test_data = build_graph(hits).to(device)   # Same device as the model

truth_test_df = truth[truth.hit_id.isin(hits.hit_id)].copy()

# We get the probabilities 
probs = predict_test_edges(model, test_data)

# We can plot them as a histogram
plt.hist(probs, bins=50)
plt.title("Distribution of Predicted Probabilities")
plt.show()
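
The histogram shows the shape of the scores but not how a given threshold would perform. A minimal sketch of edge-level precision and recall at a threshold is given below; the function name `edge_precision_recall` and the toy numbers are ours and do not come from the trained model.

```python
import numpy as np

def edge_precision_recall(labels, probs, threshold):
    """Precision and recall of the edges whose probability exceeds the threshold."""
    labels = np.asarray(labels).astype(bool)
    selected = np.asarray(probs, dtype=float) > threshold
    true_positives = np.sum(selected & labels)
    precision = true_positives / max(selected.sum(), 1)   # avoid division by zero
    recall = true_positives / max(labels.sum(), 1)
    return float(precision), float(recall)

# Toy example: 2 true edges and 3 fake ones
p, r = edge_precision_recall([1, 1, 0, 0, 0], [0.9, 0.6, 0.85, 0.2, 0.1], threshold=0.8)
print(p, r)
```

With the objects defined in this notebook, it could be called as `edge_precision_recall(test_data.y.cpu().numpy(), probs, 0.8)` to compare candidate thresholds before reconstructing tracks.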

We get the predicted tracks and represent them: 

In [None]:
predicted_tracks = get_submission_df(test_data, probs, threshold=0.8)

print(f"Unique tracks found: {predicted_tracks['track_id'].nunique()}")
print(f"Total hits: {len(predicted_tracks)}")

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_final_tracks_3d(test_data, predicted_tracks, num_tracks=30, min_hits=3):
    """
    Visualizes the trajectories reconstructed by the model.
    """
    # Retrieve coordinates (Make sure to multiply by the normalization factor used)
    coords = test_data.x[:, :3].cpu().numpy() * 1000.0 
    
    df_plot = predicted_tracks.copy()
    df_plot['x'], df_plot['y'], df_plot['z'] = coords[:, 0], coords[:, 1], coords[:, 2]
    
    # Calculate Radius (r) to SORT the hits
    df_plot['r'] = np.sqrt(df_plot['x']**2 + df_plot['y']**2)
    
    # Filter tracks for visualization
    track_counts = df_plot['track_id'].value_counts()
    
    # We keep tracks with a plausible number of hits (at least min_hits and fewer than 30)
    # We avoid plotting a "super track" if it exists, as it would hide everything else
    valid_tracks = track_counts[(track_counts >= min_hits) & (track_counts < 30)].index.tolist()
    
    if not valid_tracks:
        print("No valid tracks were found with the current threshold and filters.")
        return

    selected_ids = np.random.choice(valid_tracks, min(num_tracks, len(valid_tracks)), replace=False)
    
    fig = go.Figure()

    for t_id in selected_ids:
        # Extract hits of the track and sort by radius r (from the center outward)
        track_data = df_plot[df_plot['track_id'] == t_id].sort_values('r')
        
        fig.add_trace(go.Scatter3d(
            x=track_data['x'], y=track_data['y'], z=track_data['z'],
            mode='lines+markers',
            line=dict(width=4),
            marker=dict(size=3, opacity=0.8),
            name=f'Track {t_id} ({len(track_data)} hits)'
        ))

    fig.update_layout(
        title=f"Visualization of {len(selected_ids)} Reconstructed Trajectories",
        scene=dict(xaxis_title='X (mm)', yaxis_title='Y (mm)', zaxis_title='Z (mm)', aspectmode='data'),
        margin=dict(l=0, r=0, b=0, t=40)
    )
    fig.show()

# Execution
plot_final_tracks_3d(test_data, predicted_tracks)

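The connected-components reconstruction gives very poor results: a single fake edge above the threshold is enough to merge two tracks into one component. We can use a greedy solution instead: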

In [None]:
import plotly.graph_objects as go
import numpy as np

def plot_all_reconstructed_tracks(test_data, hits_df, edge_probs, threshold=0.9, max_tracks=50):
    """
    Reconstructs multiple tracks using the Greedy method and plots them in 3D.
    """
    # Retrieve necessary data
    src = test_data.edge_index[0].cpu().numpy()
    dst = test_data.edge_index[1].cpu().numpy()
    probs = edge_probs.cpu().detach().numpy() if torch.is_tensor(edge_probs) else edge_probs
    pos = test_data.x[:, :3].cpu().numpy() * 1000.0  # Convert back to mm (features were divided by 1000)
    radii = test_data.x[:, 3].cpu().numpy()
    
    num_nodes = test_data.x.size(0)
    used_hits = np.zeros(num_nodes, dtype=bool)
    fig = go.Figure()

    # Identify Seeds (Innermost layer)
    min_layer = hits_df.global_layer.min()
    seeds = np.where(hits_df.global_layer.values == min_layer)[0]
    
    tracks_found = 0
    
    # Reconstruction Loop
    for seed in seeds:
        if tracks_found >= max_tracks: break
        if used_hits[seed]: continue
        
        track = [seed]
        curr = seed
        
        # Build the track following the most probable path
        for _ in range(15):
            out_mask = (src == curr)
            if not np.any(out_mask): break
            
            # Filters: Probability, Unused, and Increasing radius (outward)
            candidates_dst = dst[out_mask]
            candidates_probs = probs[out_mask]
            
            valid_mask = (candidates_probs > threshold) & \
                         (~used_hits[candidates_dst]) & \
                         (radii[candidates_dst] > radii[curr] + 0.0001)
            
            if not np.any(valid_mask): break
            
            # Select the best candidate and advance
            best_idx = np.argmax(np.where(valid_mask, candidates_probs, -1))
            curr = candidates_dst[best_idx]
            track.append(curr)
        
        # If the track is valid, add it to the plot
        if len(track) >= 3:
            for h in track: used_hits[h] = True # Mark hits as used to prevent 'super tracks'
            
            track_pos = pos[track]
            fig.add_trace(go.Scatter3d(
                x=track_pos[:, 0], y=track_pos[:, 1], z=track_pos[:, 2],
                mode='lines+markers',
                line=dict(width=3),
                marker=dict(size=2),
                name=f'Track {tracks_found}'
            ))
            tracks_found += 1

    # Plot configuration
    fig.update_layout(
        title=f"Visualization of {tracks_found} Greedy Trajectories (Threshold {threshold})",
        scene=dict(xaxis_title='X (mm)', yaxis_title='Y (mm)', zaxis_title='Z (mm)', aspectmode='data'),
        margin=dict(l=0, r=0, b=0, t=40)
    )
    fig.show()
    print(f"{tracks_found} individual trajectories have been plotted.")

plot_all_reconstructed_tracks(test_data, hits, probs, threshold=0.70, max_tracks=200)

The results look more precise, but they are still far from perfect, likely because of the class imbalance and other factors we have not yet explored.

**Better solutions are being explored.**