In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import uproot
import torch
import awkward as ak
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import ot

In [2]:
from sklearn.metrics import pairwise_distances_argmin
from sklearn.cluster import kmeans_plusplus

def energy_weighted_kmeans(event, n_clusters=6, max_iter=100, tol=1e-4, plot=True, show_message=False):
    """
    Performs energy-weighted K-Means clustering on particle production vertices.

    Parameters:
        event: awkward.Array    - Event data with true particle-level information.
                                 Must contain fields: 'prod_x_true', 'prod_y_true', 'prod_z_true', 'p_true'
        n_clusters: int         - Number of clusters to form (default: 6)
        max_iter: int           - Maximum number of iterations for convergence (default: 100)
        tol: float              - Convergence threshold on centroid shift (default: 1e-4)
        plot: bool              - If True, plot 3D clustering result (default: True)
        show_message: bool      - If True, print message upon convergence (default: False)

    Behavior:
        - Initializes centroids using KMeans++ from sklearn
        - Assigns particles to nearest centroid based on Euclidean distance
        - Updates centroids as energy-weighted mean of assigned points
        - Reinitializes empty clusters with random points
        - Stops iterating when centroid shift is below tolerance

    Returns:
        labels: ndarray (N,)         - Cluster label assigned to each particle
        centroids: ndarray (n x 3)   - Final centroid positions in 3D space
    """


    np.random.seed(0)
    x = ak.to_numpy(event['prod_x_true'])
    y = ak.to_numpy(event['prod_y_true'])
    z = ak.to_numpy(event['prod_z_true'])
    
    X = np.stack((x, y, z), axis=1)  # shape (N, 3)
    # ! Not actually the energy, but momentum
    energy = ak.to_numpy(event['p_true'])
    
    N = X.shape[0]

    # Initialize centroids from points
    centroids, _ = kmeans_plusplus(X, n_clusters=n_clusters, random_state=0)

    for i in range(max_iter):
        # Assignment step (standard, unweighted distance)
        labels = pairwise_distances_argmin(X, centroids)

        # Update step (energy-weighted mean)
        new_centroids = np.zeros_like(centroids)
        for k in range(n_clusters):
            mask = labels == k
            if np.sum(mask) == 0:
                # Reinitialize empty cluster to a random point
                new_centroids[k] = X[np.random.randint(0, N)]
            else:
                weights = energy[mask]
                points = X[mask]
                new_centroids[k] = np.average(points, axis=0, weights=weights)

        # Convergence check
        shift = np.linalg.norm(centroids - new_centroids)
        centroids = new_centroids
        if shift < tol:
            if show_message:
                print(f"Converged after {i} iterations")
            break
            
    if plot:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        scatter = ax.scatter(x, y, z, c=labels, cmap='plasma', s=10)
        ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], color='black', marker='x', s=100)
        plt.title("Energy-Weighted K-Means Clustering")
        plt.show()
    
    return labels, centroids


In [10]:
def compute_ewc_from_clustered_event(event):
    """
    Compute energy-weighted centroids from an event where particles are grouped by cluster,
    and include total pt per cluster.

    Parameters:
        event: awkward.Record
            - Must contain fields: 'cl_cell_xCells', 'cl_cell_yCells', 'cl_cell_zCells', 'cl_cell_pt'

    Returns:
        result: numpy array (n_clusters, 4)
            - Columns: [x_centroid, y_centroid, z_centroid, total_pt]
    """
    centroids_with_pt = []

    n_clusters = len(event['cl_cell_xCells'])

    for i in range(n_clusters):
        total_pt = np.array(event['cl_pt'][i])
        if total_pt < 1e4:
            continue
        x = ak.to_numpy(event['cl_cell_xCells'][i])
        y = ak.to_numpy(event['cl_cell_yCells'][i])
        z = ak.to_numpy(event['cl_cell_zCells'][i])
        energy = ak.to_numpy(event['cl_cell_pt'][i])

        if len(energy) == 0:
            centroids_with_pt.append([np.nan, np.nan, np.nan, 0.0])
            continue

        points = np.stack((x, y, z), axis=1)
        weighted_mean = np.average(points, axis=0, weights=energy)
        #total_pt = np.sum(energy)

        centroids_with_pt.append(list(weighted_mean) + [total_pt])

    return np.array(centroids_with_pt)

In [21]:
def create_clustered_truth_array(truth_array):
    """
    Applies energy-weighted K-Means clustering to each event and groups true particle features by cluster.

    Parameters:
        truth_array: awkward.Array - Array of events with particle-level truth information.
                                        Each event must contain fields: 
                                        'prod_x_true', 'prod_y_true', 'prod_z_true',
                                        'calo_x_true', 'calo_y_true', 'calo_z_true',
                                        'px_true', 'py_true', 'pz_true',
                                        'p_true', 'times_ATLAS_true', 'n_true'

    Behavior:
        - Runs energy-weighted K-Means clustering on each event with `n_clusters` clusters
        - Extracts a defined set of fields per cluster
        - Computes transverse momentum (pt = sqrt(px² + py²)) per particle and stores it as 'cell_pt'
        - Skips clusters with no assigned particles
        - Aggregates results into a list of structured dictionaries, one per event

    Returns:
        clustered_truth_array: awkward.Array - Structured array of clustered particle data per event,
                                                including cluster-wise field groupings and per-cluster pt
    """

    n_clusters = 6
    clustered_truth_particles = []

    # Fields to extract and group per cluster
    cluster_fields = [
        'prod_x_true', 'prod_y_true', 'prod_z_true',
        'calo_x_true', 'calo_y_true', 'calo_z_true',
        'px_true', 'py_true', 'pz_true',
        'p_true', 'times_ATLAS_true'
    ]

    for event_idx, event in enumerate(truth_array):
        labels, centroids = energy_weighted_kmeans(event, n_clusters=n_clusters, plot=False, show_message=False)

        # Initialize field containers
        clustered_data = {field: [] for field in cluster_fields}
        cell_pt = []
        cluster_pt = []
        average_time = []

        for cluster_id in range(n_clusters):
            mask = labels == cluster_id
            if np.sum(mask) == 0:
                continue

            try:
                for field in cluster_fields:
                    clustered_data[field].append(event[field][mask])
            except IndexError:
                print(f"Skipping inconsistent event at index {event_idx}")
                continue

            # Compute transverse momentum (pt) per particle in this cluster
            px = event['px_true'][mask]
            py = event['py_true'][mask]
            times = event['times_ATLAS_true'][mask]

            # ? Should I use p instead of pt?
            pt = np.sqrt(px**2 + py**2)
            cell_pt.append(pt)
            cluster_pt.append(np.sum(pt))
            average_time.append(np.mean(times))

    # ! I should add the time average for the cluster here and then fix the last cell
        clustered_truth_particles.append({
            'n_true': event['n_true'],
            **clustered_data,
            'cell_pt': cell_pt,
            'centroid_x': [c[0] for c in centroids],
            'centroid_y': [c[1] for c in centroids],
            'centroid_z': [c[2] for c in centroids],
            'cl_pt': cluster_pt,
            'cl_time': average_time
        })
        # är det inte konstigt att den convergar med så få iteratiorner?
    return ak.Array(clustered_truth_particles)

In [5]:
def extract_features_awkward(arr, features=("x", "y", "z", "pt")):
    """
    Converts an Awkward Array of objects with specified feature keys into a NumPy array.

    Parameters:
        arr: awkward.Array - Input Awkward Array with feature fields
        features: tuple - Names of the features to extract, in order

    Returns:
        np.ndarray of shape (n_objects, len(features)) if flat,
        or (n_events, variable-length, len(features)) if nested.
    """
    columns = [ak.to_numpy(arr[feature]) for feature in features]
    stacked = np.stack(columns, axis=1)

    return stacked

In [6]:
from scipy.spatial.distance import cdist
import ot

def match_showers_to_clusters(
    X,
    Y,
    w_pos=1.0,
    w_pt=0.05,
    reg=3,
    threshold=0.0001,
    use_dummy=True,
    dummy_cost=5.0,
    return_plan=False,
    verbose=False,
    plot=False
):
    """
    Matches particle showers to detector clusters using Sinkhorn Optimal Transport.

    Parameters:
        X: ndarray (n x d) - Features of true showers (e.g. [x, y, z, pt])
        Y: ndarray (m x d) - Features of measured clusters
        w_pos: float       - Weight for spatial features
        w_pt: float    - Weight for pt feature
        threshold: float   - Threshold on transport matrix to consider a match
        use_dummy: bool    - Whether to add a dummy cluster to absorb unmatched showers
        dummy_cost: float  - Cost assigned to dummy cluster
        return_plan: bool  - If True, return the full transport plan P

    Returns:
        matches: list of (i, j, score) - List of matched shower/cluster pairs with confidence scores
        (Optional) P: transport plan matrix (n x m(+1) if dummy used)
    """
    # Normalize features
    all_data = np.vstack([X, Y])
    mean = all_data.mean(axis=0)
    std = all_data.std(axis=0)
    X_norm = (X - mean) / std
    Y_norm = (Y - mean) / std

    # Compute cost matrix
    C = (
        w_pos * cdist(X_norm[:, :3], Y_norm[:, :3], metric='sqeuclidean') +
        w_pt * cdist(X_norm[:, 3:4], Y_norm[:, 3:4], metric='sqeuclidean')
    )
    # Normalize cost matrix
    C /= C.std()

    if verbose:
        print("Min:", C.min(), "Mean:", C.mean(), "Max:", C.max())
        print("C range:", C.min(), C.mean(), C.max(), " | reg =", reg)
        print("exp(-C.max()/reg) =", np.exp(-C.max() / reg))

    # Add dummy cluster if enabled
    if use_dummy:
        dummy_col = dummy_cost * np.ones((X.shape[0], 1))
        C = np.hstack([C, dummy_col])

    # Define source and target distributions
    a = np.ones(X.shape[0]) / X.shape[0]
    b = np.ones(C.shape[1]) / C.shape[1]

    # Compute Sinkhorn transport plan
    P = ot.sinkhorn(a, b, C, reg=reg)

    # print dummy mass
    dummy_mass = P[:, -1].sum()
    if verbose:
        print(f"Total dummy mass: {dummy_mass:.2f}")

    if plot:
        plt.imshow(C, cmap="plasma", aspect="auto")
        plt.colorbar()
        plt.title("Cost Matrix")

    # Extract matches above threshold
    matches = []
    for i in range(P.shape[0]):
        for j in range(P.shape[1]):
            if P[i, j] > threshold and (not use_dummy or j < Y.shape[0]):
                matches.append((i, j, P[i, j]))

    return (matches, P) if return_plan else matches

In [15]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph
from torch_geometric.utils import dense_to_sparse

#from sklearn.neighbors import NearestNeighbors

def generate_graph_list(clustered_truth_array, calo_array, radius_edges=False):
    graph_list = []

    for event_idx in range(len(clustered_truth_array)):
        matched_clusters = []

        try:
            # === Get feature arrays for matching
            X = extract_features_awkward(clustered_truth_array[event_idx], features=("centroid_x", "centroid_y", "centroid_z", "cl_pt"))
            Y = compute_ewc_from_clustered_event(calo_array[event_idx])

            if len(X) == 0 or len(Y) == 0:
                continue

            # === Run Sinkhorn
            matches, P = match_showers_to_clusters(X, Y, reg=0.2, threshold=0.011, dummy_cost=0.5, return_plan=True, verbose=False)

            for i, j, p_ij in matches:
                if j >= Y.shape[0]:
                    continue

                true_time = np.array(clustered_truth_array[event_idx]["cl_time"][i])

                cell_x = ak.to_numpy(calo_array[event_idx]['cl_cell_xCells'][j])
                cell_y = ak.to_numpy(calo_array[event_idx]['cl_cell_yCells'][j])
                cell_z = ak.to_numpy(calo_array[event_idx]['cl_cell_zCells'][j])
                cell_E = ak.to_numpy(calo_array[event_idx]['cl_cell_E'][j])
                cell_T = ak.to_numpy(calo_array[event_idx]['cl_cell_TimeCells'][j])

                # --- Step 2: Build node feature matrix
                node_features = torch.tensor(np.stack([cell_x, cell_y, cell_z, cell_E, cell_T], axis=1), dtype=torch.float)

                # --- Step 3: Build edges using Radius on (x, y, z)
                if radius_edges:
                    pos = node_features[:, :3] # Gets (x, y, z) from features
                    edge_index = radius_graph(pos, r=150)
                else:
                    num_nodes = node_features.size(0)
                    adj = torch.ones((num_nodes, num_nodes)) - torch.eye(num_nodes)  # full connection minus self-loops
                    edge_index, _ = dense_to_sparse(adj)

                # --- Step 4: Cluster time (target value)
                """### Need to be changed for actual target values"""
                cluster_times = np.zeros(len(cell_T))        
                target = torch.tensor(true_time, dtype=torch.float)

                # --- Step 5: Build graph
                graph = Data(x=node_features, edge_index=edge_index, y=target)
                graph_list.append(graph)

        except Exception as e:
            print(f"Skipping event {event_idx}: {e}")
            continue
    return graph_list

In [25]:
FILE_NAME_1 = "LLP_true_times.caloCells_combined.1.root"
FILE_NAME_2 = "LLP_true_times.caloCells_combined.2.root"
STEP_SIZE = 500

DATA_TREE = "caloCells"
TRUTH_TREE = "LLPTruthTree"

# Gets the the latest tree versions
calo_tree_1 = uproot.open(FILE_NAME_1 + ":" + DATA_TREE)
truth_tree_1 = uproot.open(FILE_NAME_1 + ":" + TRUTH_TREE)

"""calo_tree_2 = uproot.open(FILE_NAME_2 + ":" + DATA_TREE)
truth_tree_2 = uproot.open(FILE_NAME_2 + ":" + TRUTH_TREE)"""

graph_list_all = []  # this will hold all graphs

i = 0
for calo_batch, truth_batch in zip(
    calo_tree_1.iterate(filter_name=["cl_cell_*", "cl_pt"], library="ak", step_size=STEP_SIZE),
    truth_tree_1.iterate(library="ak", step_size=STEP_SIZE)
):
    print(f"iteration: {i*STEP_SIZE} - {(i+1)*STEP_SIZE}")

    clustered_truth_array = create_clustered_truth_array(truth_batch)
    graph_list = generate_graph_list(clustered_truth_array, calo_batch)

    graph_list_all.extend(graph_list)  # add batch to master list
    print(f"Total graphs so far: {len(graph_list_all)}")
    i+=1

"""for calo_batch, truth_batch in zip(
    calo_tree_2.iterate(filter_name=["cl_cell_*", "cl_pt"], library="ak", step_size=STEP_SIZE),
    truth_tree_2.iterate(library="ak", step_size=STEP_SIZE)
):
    print(f"Event iteration: {i*STEP_SIZE} - {(i+1)*STEP_SIZE}")

    clustered_truth_array = create_clustered_truth_array(truth_batch)
    graph_list = generate_graph_list(clustered_truth_array, calo_batch)

    graph_list_all.extend(graph_list)  # add batch to master list
    print(f"Total graphs so far: {len(graph_list_all)}")
    i+=1"""

torch.save(graph_list_all, "graph_dataset.pt")

iteration: 0 - 500
Total graphs so far: 3384
iteration: 500 - 1000
Total graphs so far: 6927
iteration: 1000 - 1500
Total graphs so far: 10047
iteration: 1500 - 2000
Total graphs so far: 13407
iteration: 2000 - 2500
Total graphs so far: 16479
iteration: 2500 - 3000
Total graphs so far: 19638
iteration: 3000 - 3500
Total graphs so far: 22425
iteration: 3500 - 4000
Total graphs so far: 25569
iteration: 4000 - 4500
Total graphs so far: 28421
iteration: 4500 - 5000
Total graphs so far: 31078
iteration: 5000 - 5500
Total graphs so far: 33830
iteration: 5500 - 6000
Total graphs so far: 37022
iteration: 6000 - 6500
Skipping inconsistent event at index 434
Skipping inconsistent event at index 434
Skipping inconsistent event at index 434
Skipping event 434: all input arrays must have the same shape
Total graphs so far: 39844
iteration: 6500 - 7000
Total graphs so far: 42936
iteration: 7000 - 7500
Total graphs so far: 45874
iteration: 7500 - 8000
Total graphs so far: 49121
iteration: 8000 - 8500

In [None]:
graph_dataset = torch.load("graph_dataset.pt", weights_only=False)
print(len(graph_dataset))

612


In [None]:
from torch_geometric.data import Data

def graphs_are_equal(list1, list2, verbose=False):
    if len(list1) != len(list2):
        if verbose: print(f"Length mismatch: {len(list1)} vs {len(list2)}")
        return False

    for i, (g1, g2) in enumerate(zip(list1, list2)):
        if type(g1) != type(g2):
            if verbose: print(f"Type mismatch at index {i}")
            return False

        for attr in ['x', 'edge_index', 'y']:
            t1 = getattr(g1, attr, None)
            t2 = getattr(g2, attr, None)

            if t1 is None and t2 is None:
                continue
            if t1 is None or t2 is None:
                if verbose: print(f"Missing attribute '{attr}' at index {i}")
                return False
            if not torch.equal(t1, t2):
                if verbose: print(f"Mismatch in '{attr}' at index {i}")
                return False

    return True

same = graphs_are_equal(graph_list_all, graph_dataset, verbose=True)
print("Graphs match!" if same else "Graphs do not match.")