# GNN model with HGCAL data

## Importing the required packages

In [1]:
import uproot
import awkward as ak
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool

from sklearn.neighbors import NearestNeighbors


## ROOT -----> Numpy arrays

In [4]:
root_file="ntuple_pi+_100GeV_100keve.root"
tree_name="AllLayers"

In [3]:
file=uproot.open(root_file)

In [None]:
#file.arrays(["hit_x"],library="ak")

In [5]:
class RootEventReader:
    def __init__(self, root_file, tree_name):
        file = uproot.open(root_file)
        tree = file[tree_name]

        # Read all required branches ONCE
        arrays = tree.arrays(["hit_x", "hit_y", "hit_z", "hit_l", "hit_E"],library="ak")

        # Store as awkward arrays (event-wise)
        self.hit_x = arrays["hit_x"]
        self.hit_y = arrays["hit_y"]
        self.hit_z = arrays["hit_z"]
        self.hit_l = arrays["hit_l"]
        self.hit_E = arrays["hit_E"]

        self.n_events = len(self.hit_x)

    def get_event(self, idx):
        """
        Returns NumPy arrays for a single event
        (same interface as before)
        """
        return (ak.to_numpy(self.hit_x[idx]),
                ak.to_numpy(self.hit_y[idx]),
                ak.to_numpy(self.hit_z[idx]),
                ak.to_numpy(self.hit_l[idx]),
                ak.to_numpy(self.hit_E[idx]),)


## Graph construction

In [152]:
def create_graph(x,y,z,l,E,k=8):
    #Remove the zero energy hits
    mask=E>0
    x,y,z,l,E=x[mask],y[mask],z[mask],l[mask],E[mask]
    coords=np.column_stack((x,y,z))
    print(coords.shape) #
    N=len(coords)
    if N<k+1:
        return None
    #Node features
    node_features=torch.from_numpy(np.column_stack((x,y,z,l,E))).float()
    #Global KNN(vectorized- once per event)
    knn=NearestNeighbors(n_neighbors=k+1,algorithm="kd_tree").fit(coords)
    knn_dist,knn_idx=knn.kneighbors(coords)
    print(knn_dist[8,:],knn_idx[8,:])#

    knn_idx=knn_idx[:,1:] # Removes the self hit
    #Layer adjancency/layer window mask
    #valid=np.abs(l[:,None]-l[knn_idx])<=layer_window
    #knn_idx=np.where(valid,knn_idx,-1)
    #Vectorised edge constructions
    src=np.repeat(np.arange(N),k)
    dst=knn_idx.reshape(-1)
    mask=dst>=0
    src,dst=src[mask],dst[mask]
    edge_index=torch.tensor(np.vstack([src,dst]),dtype=torch.long)
    #Edge attributes
    #l_diff=np.abs(l[src]-l[dst])
    #E_diff=np.abs(E[src]-E[dst])
    #edge_attr=torch.tensor(l_diff[:,None],E_diff[:,None],dtype=torch.float)
    return Data(x=node_features,edge_index=edge_index)

In [153]:
graph0=dataset[0]

(404, 3)
[0.         0.89       2.115      3.005      3.005      3.895
 4.22028731 5.12       5.91      ] [ 8  7  9 10  5  4 11 12  3]


In [157]:
dataset[0].x[6]

(404, 3)
[0.         0.89       2.115      3.005      3.005      3.895
 4.22028731 5.12       5.91      ] [ 8  7  9 10  5  4 11 12  3]


tensor([-3.8798e+00, -3.3600e+00,  2.0562e+01,  6.0000e+00,  1.9298e-05])

## PyTorch Geometric dataset

In [48]:
class HGCALDataset(Dataset):
    def __init__(self, root_reader, k=8):
        #super().__init__()
        self.reader = root_reader
        self.k = k
    def __len__(self):
        return self.reader.n_events
    def __getitem__(self, idx):
        event = self.reader.get_event(idx)
        graph= create_graph(*event, k=self.k)
        return graph

In [28]:
root_reader=RootEventReader(root_file,tree_name)

In [145]:
dataset=HGCALDataset(root_reader,k=8)

In [150]:
graph0=dataset[0]

(404, 3)
(404, 9) (404, 9)


In [124]:
print(graph0)

Data(x=[404, 5], edge_index=[2, 3232])


In [125]:
def graph_neighbors(data, node_idx):
    edge_index = data.edge_index
    src = edge_index[0]
    dst = edge_index[1]

    mask = src == node_idx
    neighbors = dst[mask]
    central_pos = data.x[node_idx, :3]         # first 3 columns are coordinates
    neighbor_pos = data.x[neighbors, :3]

    distances = torch.norm(neighbor_pos - central_pos, dim=1)

    return neighbors.cpu().numpy(), distances.cpu().numpy()



In [141]:
ci = 8
gnbrs = graph_neighbors(graph0, ci)

print("Graph neighbors:", gnbrs[0])
print("Number of neighbors:", len(gnbrs[0]))


Graph neighbors: [ 7  9 10  5  4 11 12  3]
Number of neighbors: 8


In [142]:
#x0=graph0.x[:,1]
#y0=graph0.x[:,2]

In [143]:
coords = graph0.x[:, :3].numpy()   # x, y, z
#ci = 8

dist = np.linalg.norm(coords - coords[ci], axis=1)
manual_nbrs = np.argsort(dist)[1:9]   # exclude itself

print("Manual nearest neighbors:", manual_nbrs)
print("Distances:",np.sort(dist))

Manual nearest neighbors: [ 7  9  5 10  4 11 12  3]
Distances: [  0.          0.8899994   2.1150017   3.0049992   3.005001    3.8949986
   4.220288    5.120001    5.91        5.9474716   6.01        6.1134686
   6.315323    6.5208592   6.799999    7.925001    8.460877    8.815
   8.815001    9.704999   10.076598   10.730001   11.620001   13.535002
  14.425001   16.340002   16.378342   16.492825   17.230001   17.374998
  17.978191   19.325003   19.35743    20.455      20.48564    20.668522
  22.83       22.857456   22.857456   22.857456   22.857456   22.857456
  22.91227    22.939627   23.26541    23.960001   23.986164   23.986164
  23.986164   23.986164   24.038404   24.064482   24.142546   24.194447
  24.297918   24.297918   26.0995     26.123522   26.123522   26.123522
  26.123522   26.123522   26.171495   26.171495   26.171495   26.171495
  26.171495   26.195448   26.195448   26.195448   26.267181   26.267181
  26.267181   26.314892   26.314892   26.481207   27.105501   27.12863
  2

In [129]:
# data.x shape: [num_nodes, 4] (x, y, z, E)
z_all = graph0.x[:, 3]   # extract column 2 â†’ z
print(z_all)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  6.,  7.,  8.,  9., 10., 10., 11., 11.,
        12., 12., 12., 13., 13., 14., 14., 15., 16., 17., 18., 19., 19., 19.,
        20., 20., 20., 21., 21., 22., 22., 22., 23., 23., 23., 23., 23., 23.,
        23., 23., 23., 24., 24., 24., 24., 24., 24., 24., 24., 24., 24., 24.,
        25., 25., 25., 25., 25., 25., 25., 25., 25., 25., 25., 25., 25., 25.,
        25., 25., 25., 25., 25., 25., 26., 26., 26., 26., 26., 26., 26., 26.,
        26., 26., 26., 26., 26., 26., 26., 26., 26., 26., 27., 27., 27., 27.,
        27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27.,
        27., 27., 27., 27., 28., 28., 28., 28., 28., 28., 28., 28., 28., 28.,
        28., 28., 28., 28., 28., 28., 28., 28., 28., 28., 28., 28., 28., 28.,
        28., 28., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29.,
        29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29.,
        29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 29., 2

In [135]:
node_idx = 8  # example
l_node = graph0.x[node_idx, 3]
print(f"Node {node_idx} layer: {l_node}")


Node 8 layer: 8.0
