In [2]:
# Change the working directory to the parent folder (presumably the project root, so that the
# project-local packages imported below resolve - confirm).
# NOTE: this is relative and not idempotent; re-running the cell moves up one more level each time.
%cd ..

/home1/giorgian/projects/trigger-detection-pipeline/sPHENIX/tracking-GNN


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [8]:
from dataclasses import replace
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
import os.path
import sys
import logging
import pickle
from collections import defaultdict
from sklearn.linear_model import LinearRegression
from models.bgn_st_tracking import GNNSegmentClassifier
from icecream import ic
from numpy.linalg import inv
import sklearn.metrics as metrics
from datasets import get_data_loaders
from tqdm.notebook import tqdm
from itertools import islice

In [74]:
def get_track_endpoints(hits, good_layers):
    """Return the first and last populated layer hit of every track.

    Parameters:
      hits: array of shape (n_tracks, 5, 3), one (x, y, z) triple per layer slot.
      good_layers: array of shape (n_tracks, 5) flagging which layer slots hold a hit.

    Returns:
      (first_hits, last_hits), each of shape (n_tracks, 3).

    Every track is assumed to have at least one hit; with exactly one hit the two
    endpoints coincide. A row without any flagged layer yields slot 0 and slot 4.
    """
    slots = np.arange(5)
    missing = 1 - good_layers

    def _pick(layer_index):
        # Broadcast one layer index per track over the three coordinates of that layer.
        gathered = np.take_along_axis(hits, layer_index[..., None, None], axis=-2)
        return gathered.squeeze(1)

    # Populated slots keep their own index while empty ones are pushed past the end,
    # so the smallest key belongs to the innermost populated slot.
    first_index = np.argmin(good_layers * slots + missing * (slots + 5), axis=-1)
    # Mirror image: populated slots are lifted above all empty ones, so the largest
    # key belongs to the outermost populated slot.
    last_index = np.argmax(good_layers * (slots + 5) + missing * slots, axis=-1)
    return _pick(first_index), _pick(last_index)

def get_predicted_pz(track_hits, good_layers, radius):
    """Estimate a longitudinal-momentum proxy (dz per unit turning angle) for each track.

    Parameters:
      track_hits: array reshapeable to (n_tracks, 5, 3) holding (x, y, z) per layer slot.
      good_layers: (n_tracks, 5) flags of populated layer slots.
      radius: per-track circle radius (as returned by get_approximate_radii).

    Returns:
      Array of length n_tracks with dz / dtheta, where dtheta is the angle subtended by
      the chord between the first and last hit on a circle of the given radius.
      NaN results are replaced by 0 via np.nan_to_num.
    """
    hits = track_hits.reshape(-1, 5, 3)
    first_hit, last_hit = get_track_endpoints(hits, good_layers)

    # The factors 100 and 100**2 presumably convert the hit units to those of `radius`
    # (get_approximate_radii applies the same factor of 100) - confirm.
    span = last_hit - first_hit
    dz = span[:, -1] / 100
    chord2 = (span[:, 0] ** 2 + span[:, 1] ** 2) / 10000

    # Law of cosines on the isosceles triangle (radius, radius, chord):
    # cos(dtheta) = (2 r^2 - chord^2) / (2 r^2).
    r2 = 2 * radius**2
    with np.errstate(invalid='ignore'):
        # A zero denominator is replaced by 1; out-of-range cosines become NaN silently.
        cos_dtheta = (r2 - chord2) / (r2 + (r2 == 0))
        dtheta = np.arccos(cos_dtheta)
    # Avoid dividing by a zero turning angle.
    dtheta += (dtheta == 0)
    pz = dz / dtheta
    return np.nan_to_num(pz)

def matmul_3D(A, B):
    """Batched matrix product: multiply A[n] (i x j) by B[n] (j x k) for every n."""
    batched_product = np.einsum('nab,nbc->nac', A, B)
    return batched_product


def _invert_or_zero(matrices):
    """Invert a stack of square matrices; exactly singular ones become all-zero blocks."""
    try:
        return np.linalg.inv(matrices)
    except np.linalg.LinAlgError:
        # At least one matrix is exactly singular. Redo the inversion one matrix at a
        # time so that a single degenerate track cannot abort the whole batch.
        inverses = np.zeros(matrices.shape, dtype=float)
        for k, matrix in enumerate(matrices):
            try:
                inverses[k] = np.linalg.inv(matrix)
            except np.linalg.LinAlgError:
                # Deliberately left as zeros: the caller turns this into radius 0.
                continue
        return inverses


def get_approximate_radii(track_hits, good_layers, n_layers):
    """Fit a circle in the x-y plane to the populated hits of every track.

    For each track the algebraic least-squares problem
    ``x^2 + y^2 + c0*x + c1*y + c2 = 0`` is solved through the normal equations.

    Parameters:
      track_hits: (n_tracks, 3 * n_slots) array laid out as x, y, z per layer slot.
      good_layers: (n_tracks, n_slots) flags of populated layer slots.
      n_layers: (n_tracks,) number of populated slots per track; must match the row
          sums of ``good_layers``.

    Returns:
      (r, centers): ``r`` has shape (n_tracks,) and holds ``sqrt(c0^2 + c1^2 - 4*c2) / 200``
      (the fitted radius divided by 100); ``centers`` has shape (n_tracks, 2) and holds
      ``(-c0/2, -c1/2)`` in the units of ``track_hits``.
      Tracks with fewer than three populated slots, and tracks whose normal matrix is
      exactly singular (e.g. perfectly collinear hits), get radius 0 and centre (0, 0).
    """
    track_hits = np.asarray(track_hits)
    # A boolean mask is required below; a 0/1 integer array would silently be treated
    # as fancy row indices instead.
    good_layers = np.asarray(good_layers, dtype=bool)
    n_layers = np.asarray(n_layers)

    n_tracks = track_hits.shape[0]
    n_slots = track_hits.shape[1] // 3
    x_all = track_hits[:, 0:3 * n_slots:3]
    y_all = track_hits[:, 1:3 * n_slots:3]

    r = np.zeros(n_tracks)
    centers = np.zeros((n_tracks, 2))
    # Three points are the minimum needed to determine a circle. Tracks are grouped by
    # their number of hits so that each group forms a rectangular batch.
    for n_layer in range(3, n_slots + 1):
        selected = n_layers == n_layer
        n_selected = int(np.count_nonzero(selected))
        if n_selected == 0:
            continue

        present = good_layers[selected]
        x_values = x_all[selected][present].reshape(n_selected, n_layer)
        y_values = y_all[selected][present].reshape(n_selected, n_layer)

        # Design matrix with columns [x, y, 1] and right-hand side -(x^2 + y^2).
        A = np.ones((n_selected, n_layer, 3))
        A[:, :, 0] = x_values
        A[:, :, 1] = y_values
        y = (- x_values**2 - y_values**2).reshape((n_selected, n_layer, 1))

        AT = np.transpose(A, axes=(0, 2, 1))
        gram_inverse = _invert_or_zero(np.einsum('lij,ljk->lik', AT, A))
        projector = np.einsum('lij,ljk->lik', gram_inverse, AT)
        c = np.einsum('lij,ljk->lik', projector, y)[..., 0]

        r[selected] = np.sqrt(c[:, 0]**2 + c[:, 1]**2 - 4*c[:, 2])/200
        centers[selected] = np.stack([-c[:, 0]/2, -c[:, 1]/2], axis=-1)

    return r, centers

def get_length(start, end):
    """Euclidean distance between matching rows of two (n, d) arrays."""
    offset = start - end
    squared_distance = (offset * offset).sum(axis=1)
    return np.sqrt(squared_distance)


def _to_numpy(values):
    """Return ``values`` as a host-side NumPy array.

    Torch tensors are detached and copied to the CPU; anything else goes through np.asarray.
    """
    if hasattr(values, 'detach'):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _hit_to_cartesian(r, phi, z, scale=3):
    """Convert one cylindrical hit (r, phi, z) to a scaled Cartesian [x, y, z] list.

    ``r`` and ``z`` are multiplied by ``scale`` before the conversion.
    NOTE(review): the factor 3 presumably undoes a normalisation applied upstream - confirm.
    """
    r = r * scale
    z = z * scale
    return [r * np.cos(phi), r * np.sin(phi), z]


def port_event(batch, batch_output, ip_output, trigger_output, output_file):
    """
    Build per-track hit vectors from one event and save them, with fit results, as an npz file.

    Each row of batch.x_intt is one track candidate carrying two hits. MVTX hits are attached
    to a candidate through the edges the model scored positively.

    Parameters:
      batch: an object with attributes:
            - x_intt: shape (N, 10). The code reads the columns as
              [r1, phi1, z1, r2, phi2, z2, ?, ?, layer_id1, layer_id2]; columns 6 and 7 are
              not used here (presumably n_pixels of each hit - confirm).
            - x_mvtx: shape (M, 5). The code reads columns 0-2 as (r, phi, z) and column 4
              as layer_id; column 3 is not used (presumably n_pixels - confirm).
              NOTE(review): earlier documentation listed layer_id before n_pixels; the code
              takes the layer from the LAST column - confirm against the data loader.
            - edge_index: shape (2, num_edges) linking x_intt indices (first row) to x_mvtx
              indices (second row)
            - interaction_point: the true interaction point; only entry [0] is saved
            - trigger: trigger flag; only entry [0] is saved
      batch_output: per-edge predictions (an edge counts as "true" when batch_output > 0)
      ip_output: model's predicted interaction point (saved as interaction_point_pred)
      trigger_output: model's predicted trigger output (saved as trigger_pred)
      output_file: destination path for the npz output file.

    Only the first entry of interaction_point and trigger is written, so the batch is
    expected to hold a single event.

    Raises:
      KeyError: if a layer id is not in the layer map, or an edge references an x_intt
          index outside ``range(N)``.
    """
    # Mapping from raw layer_id to track layer (0-indexed):
    # 0 -> 0, 1 -> 1, 2 -> 2, 3 -> 3, 4 -> 3, 5 -> 4, 6 -> 4.
    layer_map = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 4, 6: 4}

    x_intt = _to_numpy(batch.x_intt)   # shape (N, 10)
    x_mvtx = _to_numpy(batch.x_mvtx)   # shape (M, 5)

    num_tracks = x_intt.shape[0]
    # One row per track: 5 layers x 3 coordinates = 15 values.
    track_hits = np.zeros((num_tracks, 15), dtype=np.float32)

    # Move the edges to host memory once. Iterating a device tensor element by element
    # would force a device-to-host synchronisation for every single edge.
    edge_index = _to_numpy(batch.edge_index)       # shape (2, num_edges)
    true_mask = _to_numpy(batch_output) > 0
    true_edges = edge_index[:, true_mask]          # shape (2, num_true_edges)

    # Group the accepted MVTX hit indices by the track candidate they are linked to.
    track_to_mvtx = {i: [] for i in range(num_tracks)}
    for intt_idx, mvtx_idx in zip(true_edges[0].tolist(), true_edges[1].tolist()):
        track_to_mvtx[int(intt_idx)].append(int(mvtx_idx))

    for i in range(num_tracks):
        row = x_intt[i]
        # (r, phi, z, layer_id) of the two INTT hits, followed by every linked MVTX hit.
        candidate_hits = [
            (row[0], row[1], row[2], row[8]),
            (row[3], row[4], row[5], row[9]),
        ]
        for mvtx_idx in track_to_mvtx[i]:
            hit = x_mvtx[mvtx_idx]
            candidate_hits.append((hit[0], hit[1], hit[2], hit[4]))

        # Collect the Cartesian hits of each track layer (0 through 4).
        hits_per_layer = {layer: [] for layer in range(5)}
        for r, phi, z, layer_id in candidate_hits:
            mapped_layer = layer_map[int(layer_id)]
            hits_per_layer[mapped_layer].append(_hit_to_cartesian(r, phi, z))

        # The track vector is arranged as
        # [layer0_x, layer0_y, layer0_z, ..., layer4_x, layer4_y, layer4_z];
        # several hits on one layer are averaged, layers without hits stay 0.
        for layer in range(5):
            if hits_per_layer[layer]:
                avg_coord = np.mean(np.array(hits_per_layer[layer]), axis=0)
                track_hits[i, 3*layer:3*layer+3] = avg_coord

    # A layer counts as populated when any of its three coordinates is non-zero.
    good_layers = np.any(track_hits.reshape(num_tracks, 5, 3) != 0, axis=-1)
    n_layers = np.sum(good_layers, axis=-1)

    # True collision vertex of the (single) event in the batch.
    ip = tuple(float(x) for x in _to_numpy(batch.interaction_point)[0])

    radii, centers = get_approximate_radii(track_hits, good_layers, n_layers)
    p_z = get_predicted_pz(track_hits, good_layers, radii)

    # Tracks are written in x_intt order (no shuffling).
    np.savez(
        output_file,
        collision_vertex=ip,
        tracks=track_hits,
        radii=radii,
        p_z=p_z,
        centers=centers,
        trigger=_to_numpy(batch.trigger)[0],
        trigger_pred=_to_numpy(trigger_output),
        interaction_point_pred=_to_numpy(ip_output)
    )


In [63]:
# Debug view of the two column groups that port_event reads as the first and second hit of
# every x_intt row.
# NOTE: relies on `batch` left over from the export loop in a later cell; run that cell first.
batch.x_intt[:, [0, 1, 2, 6, 8]],batch.x_intt[:, [3, 4, 5, 7, 9]]

(tensor([[ 7.1390e+00,  1.2387e+00, -1.0972e+01,  1.0000e+00,  3.0000e+00],
         [ 7.5817e+00, -2.9723e+00, -1.0972e+01,  2.0000e+00,  3.0000e+00],
         [ 7.5708e+00, -2.9554e+00, -1.0972e+01,  1.0000e+00,  3.0000e+00],
         [ 7.6388e+00, -1.8861e-01,  6.8275e+00,  1.0000e+00,  3.0000e+00],
         [ 7.6416e+00, -1.8352e-01,  6.8275e+00,  7.0000e+00,  3.0000e+00],
         [ 7.4943e+00,  3.4347e-01,  1.0028e+01,  1.0000e+00,  3.0000e+00],
         [ 7.6591e+00, -2.3313e+00,  1.1628e+01,  1.0000e+00,  3.0000e+00],
         [ 7.5010e+00,  1.7474e-01,  1.3628e+01,  1.0000e+00,  3.0000e+00],
         [ 7.4757e+00,  3.0029e-01,  1.5628e+01,  2.0000e+00,  3.0000e+00],
         [ 8.1788e+00, -2.9820e+00, -1.4572e+01,  1.0000e+00,  4.0000e+00],
         [ 8.1550e+00, -1.1833e-02,  4.2755e-01,  5.0000e+00,  4.0000e+00],
         [ 8.1824e+00, -8.0037e-02,  6.8275e+00,  6.0000e+00,  4.0000e+00],
         [ 8.2764e+00, -2.0345e+00,  1.5628e+01,  2.0000e+00,  4.0000e+00],
         [ 8

In [44]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [64]:
# create model and load checkpoint
model_result_folder = '../tracking_results/agnn/agnn-lr0.0008127498598898657-b24-d64-PReLU-gi1-ln-False-n50000/experiment_2025-03-25_09:52:45/'
config_file = os.path.join(model_result_folder, 'config.pkl')
# NOTE: unpickling can execute arbitrary code; only load config files written by our own
# training runs.
with open(config_file, 'rb') as config_handle:
    config = pickle.load(config_handle)
data_config = config.get('data')

model_config = config.get('model', {})
# These entries are not passed to the model constructor; tolerate configs that lack them.
for non_constructor_key in ('loss_func', 'name'):
    model_config.pop(non_constructor_key, None)
model = GNNSegmentClassifier(**model_config).to(DEVICE)

def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Parameters:
      checkpoint_file: path to a file written with torch.save that contains a 'model'
          state dict and, when an optimizer is passed, an 'optimizer' state dict.
      model: module whose parameters are overwritten in place.
      optimizer: optional optimizer whose state is overwritten in place.

    Returns:
      ``model`` when no optimizer is given, otherwise the tuple ``(model, optimizer)``.
    """
    # Tensors are first materialised on the CPU; load_state_dict then copies them into the
    # model's existing parameters.
    # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    if optimizer is None:
        return model
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer

# load_checkpoint
checkpoint_dir = os.path.join(model_result_folder, 'checkpoints')
# Lexicographic order; with zero-padded indices (e.g. model_checkpoint_017) the last
# entry is the most recent checkpoint.
checkpoint_files = sorted(
    os.path.join(checkpoint_dir, f)
    for f in os.listdir(checkpoint_dir)
    if f.startswith('model_checkpoint')
)
if not checkpoint_files:
    raise FileNotFoundError(f"no 'model_checkpoint*' file found in {checkpoint_dir}")
checkpoint_file = checkpoint_files[-1]
print(checkpoint_file)
model = load_checkpoint(checkpoint_file, model)
print('Successfully reloaded!')

../tracking_results/agnn/agnn-lr0.0008127498598898657-b24-d64-PReLU-gi1-ln-False-n50000/experiment_2025-03-25_09:52:45/checkpoints/model_checkpoint_017.pth.tar
Successfully reloaded!


  checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))


In [34]:
# Export settings. One event per batch is required because the export loop and port_event
# only read element [0] of batch.filename, batch.trigger and batch.interaction_point.
data_config.update(
    batch_size=1,
    n_train=1,
    n_valid=34000,
)
# Optional overrides for a second input directory (disabled):
#data_config['input_dir2'] = '/ssd1/giorgian/hits-data-august-2022-ctypes/trigger/1'
#data_config['force_inputdir2_nontrigger'] = True

In [35]:
# Single-process (non-distributed) loaders; only valid_data_loader is consumed by the export loop.
train_data_loader, valid_data_loader = get_data_loaders(
    distributed=False,
    rank=0,
    n_ranks=0,
    **data_config
)


In [68]:
# Destination folders for the exported events, split by the true trigger flag.
output_dirs = (
    '/disks/disk1/giorgian/bbar-data-march-2025/trigger/',
    '/disks/disk1/giorgian/bbar-data-march-2025/nontrigger/',
)
trigger_output_dir, nontrigger_output_dir = output_dirs

# Create both folders up front; existing folders are left untouched.
for output_dir in output_dirs:
    os.makedirs(output_dir, exist_ok=True)


In [75]:
# Inference only: switch layers such as dropout / batch norm to evaluation behaviour and
# skip building the autograd graph, which would otherwise cost memory for every event.
model.eval()
with torch.no_grad():
    for batch in tqdm(valid_data_loader):
        # Run the model to get predictions
        batch = batch.to(DEVICE)
        batch_output, ip_output, trigger_output = model(batch)

        # The event keeps its original file name; the true trigger flag selects the folder.
        # Batches hold a single event, hence the [0].
        fname = batch.filename[0]
        event_output_dir = trigger_output_dir if batch.trigger[0] else nontrigger_output_dir
        output_file = os.path.join(event_output_dir, os.path.basename(fname))

        # Process the batch and save the output.
        port_event(batch, batch_output, ip_output, trigger_output, output_file)

  0%|          | 0/68000 [00:00<?, ?it/s]

In [61]:
# Debug view of the true interaction point of the last batch produced by the export loop
# (port_event saves entry [0] of this tensor as `collision_vertex`).
batch.interaction_point

tensor([[1.5534e-02, 1.2005e-03, 4.0199e+01]], device='cuda:0')