# Preamble: load data

In [1]:
import os
import uproot
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import awkward as ak # Using awkward array for easier handling of jagged data
import time # For timing steps

def load_root_file(file_path, branches=None, print_branches=False):
    """Read branches of the ``tree`` TTree in a ROOT file into NumPy arrays.

    Parameters
    ----------
    file_path : str
        Path of the ROOT file, opened with ``uproot.open``.
    branches : iterable of str, optional
        Branch names to read. When ``None``, every key of the tree is read.
    print_branches : bool
        When true, print all keys of the tree before reading anything.

    Returns
    -------
    dict
        Maps each branch name to ``tree[name].array(library="np")``, plus an
        ``'event'`` entry holding ``tree.num_entries``. A requested branch that
        raises ``uproot.KeyInFileError`` is reported and left out.
    """
    loaded = {}
    with uproot.open(file_path) as root_file:
        tree = root_file["tree"]
        if print_branches:
            print("Branches:", tree.keys())
        names = tree.keys() if branches is None else branches
        for name in names:
            try:
                loaded[name] = tree[name].array(library="np")
            except uproot.KeyInFileError as err:
                # Best effort: report the missing branch and keep reading the rest.
                print(f"KeyInFileError: {err}")
        loaded['event'] = tree.num_entries
    return loaded

# Per-T5 branches: radii, kinematics, fake flag and the two t5_t3_* row indices
branches_list = [
    't5_innerRadius', 't5_bridgeRadius', 't5_outerRadius',
    't5_pt', 't5_eta', 't5_phi',
    't5_isFake',
    't5_t3_idx0', 't5_t3_idx1',
]
# fake / prompt / displaced scores, variant 1 then variant 2
branches_list.extend(
    f't5_t3_{cls}Score{n}' for n in (1, 2) for cls in ('fake', 'prompt', 'displaced')
)
# Truth-matching information for T5s
branches_list.extend(['t5_pMatched', 't5_sim_vxy', 't5_sim_vz', 't5_matched_simIdx'])

# pLS branches
branches_list.extend([
    'pLS_eta', 'pLS_etaErr', 'pLS_phi', 'pLS_matched_simIdx',
    'pLS_circleCenterX', 'pLS_circleCenterY', 'pLS_circleRadius',
    'pLS_ptIn', 'pLS_ptErr',
    'pLS_px', 'pLS_py', 'pLS_pz',
    'pLS_isQuad', 'pLS_isFake',
])

# Hit-dependent branches: positions 0, 2, 4 of each t5_t3_* row, one branch per suffix
suffixes = ['r', 'z', 'eta', 'phi', 'layer']
branches_list.extend(f't5_t3_{i}_{suffix}' for i in (0, 2, 4) for suffix in suffixes)

# Read the selected branches; print_branches=True also lists every branch in the file.
file_path = "pls_t5_embed.root"
branches = load_root_file(file_path, branches=branches_list, print_branches=True)

Branches: ['sim_pt', 'sim_eta', 'sim_phi', 'sim_pca_dxy', 'sim_pca_dz', 'sim_q', 'sim_event', 'sim_pdgId', 'sim_vx', 'sim_vy', 'sim_vz', 'sim_trkNtupIdx', 'sim_TC_matched', 'sim_TC_matched_mask', 'tc_pt', 'tc_eta', 'tc_phi', 'tc_type', 'tc_isFake', 'tc_isDuplicate', 'tc_matched_simIdx', 'sim_dummy', 'tc_dummy', 'pT5_matched_simIdx', 'pT5_hitIdxs', 'sim_pT5_matched', 'pT5_pt', 'pT5_eta', 'pT5_phi', 'pT5_isFake', 't5_sim_vxy', 't5_sim_vz', 'pT5_isDuplicate', 'pT5_score', 'pT5_layer_binary', 'pT5_moduleType_binary', 'pT5_matched_pt', 'pT5_rzChiSquared', 'pT5_rPhiChiSquared', 'pT5_rPhiChiSquaredInwards', 'sim_pT3_matched', 'pT3_pt', 'pT3_isFake', 'pT3_isDuplicate', 'pT3_eta', 'pT3_phi', 'pT3_score', 'pT3_foundDuplicate', 'pT3_matched_simIdx', 'pT3_hitIdxs', 'pT3_pixelRadius', 'pT3_pixelRadiusError', 'pT3_matched_pt', 'pT3_tripletRadius', 'pT3_rPhiChiSquared', 'pT3_rPhiChiSquaredInwards', 'pT3_rzChiSquared', 'pT3_layer_binary', 'pT3_moduleType_binary', 'sim_pLS_matched', 'pLS_matched_simIdx

In [2]:
# Scales used to normalise the T5 hit features.
# Events with no entries are skipped: np.max raises ValueError on a zero-size array.
z_max = np.max([np.max(event) for event in branches['t5_t3_4_z'] if len(event) > 0])
r_max = np.max([np.max(event) for event in branches['t5_t3_4_r'] if len(event) > 0])
# NOTE(review): the T5 features divide |z| by z_max, but z_max is the maximum of
# the *signed* t5_t3_4_z values — confirm max(z) ≈ max(|z|) for this sample.
eta_max = 2.5
phi_max = np.pi
n_events = np.shape(branches['t5_pt'])[0]

print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')

def delta_phi(phi1, phi2):
    """Return ``phi1 - phi2`` shifted once by ±2*pi into [-pi, pi].

    A single shift is enough when both angles already lie in [-pi, pi].
    """
    diff = phi1 - phi2
    if diff > np.pi:
        return diff - 2 * np.pi
    if diff < -np.pi:
        return diff + 2 * np.pi
    return diff

Z max: 267.2349853515625, R max: 110.10993957519531, Eta max: 2.5


# Preamble: T5 features

In [3]:
pMATCHED_THRESHOLD = 0.        # keep if t5_pMatched ≥ this
print(f"\nBuilding T5 features  (pMatched ≥ {pMATCHED_THRESHOLD}) …")

# Per-event outputs. Only events with ≥1 kept T5 get an entry, so list position
# k is NOT the event number; t5_event_ids[k] records the source event.
#   features_per_event[k]    float32 (n_kept, 30) — column layout below
#   eta_per_event[k]         float32 (n_kept,)    — signed eta of the first hit
#   displaced_per_event[k]   float32 (n_kept,)    — t5_sim_vxy
#   sim_indices_per_event[k] int64   (n_kept,)    — first matched sim index, -1 if none
# NOTE(review): later cells index these lists and the pLS lists by the same
# position, which is only correct if both cells skip exactly the same events.
features_per_event    = []
eta_per_event         = []
displaced_per_event   = []
sim_indices_per_event = []
t5_event_ids          = []

kept_tot, init_tot = 0, 0
for ev in range(n_events):

    n_t5 = len(branches['t5_t3_idx0'][ev])
    init_tot += n_t5
    if n_t5 == 0:
        continue

    feat_evt = []
    eta_evt  = []
    sim_evt  = []
    disp_evt = []

    for i in range(n_t5):
        if branches['t5_pMatched'][ev][i] < pMATCHED_THRESHOLD:
            continue

        # row indices into the per-event t5_t3_* hit branches
        idx0 = branches['t5_t3_idx0'][ev][i]
        idx1 = branches['t5_t3_idx1'][ev][i]

        # hit-level quantities -------------------------------------------------
        # hits 1-3: row idx0 at positions 0, 2, 4; hits 4-5: row idx1 at 2, 4.
        # eta1 keeps its sign (it is also the eta used for ΔR matching);
        # eta2..eta5 and all z values are absolute values.
        eta1 = (branches['t5_t3_0_eta'][ev][idx0])
        eta2 = abs(branches['t5_t3_2_eta'][ev][idx0])
        eta3 = abs(branches['t5_t3_4_eta'][ev][idx0])
        eta4 = abs(branches['t5_t3_2_eta'][ev][idx1])
        eta5 = abs(branches['t5_t3_4_eta'][ev][idx1])

        phi1 = branches['t5_t3_0_phi'][ev][idx0]
        phi2 = branches['t5_t3_2_phi'][ev][idx0]
        phi3 = branches['t5_t3_4_phi'][ev][idx0]
        phi4 = branches['t5_t3_2_phi'][ev][idx1]
        phi5 = branches['t5_t3_4_phi'][ev][idx1]

        z1 = abs(branches['t5_t3_0_z'][ev][idx0])
        z2 = abs(branches['t5_t3_2_z'][ev][idx0])
        z3 = abs(branches['t5_t3_4_z'][ev][idx0])
        z4 = abs(branches['t5_t3_2_z'][ev][idx1])
        z5 = abs(branches['t5_t3_4_z'][ev][idx1])

        r1 = branches['t5_t3_0_r'][ev][idx0]
        r2 = branches['t5_t3_2_r'][ev][idx0]
        r3 = branches['t5_t3_4_r'][ev][idx0]
        r4 = branches['t5_t3_2_r'][ev][idx1]
        r5 = branches['t5_t3_4_r'][ev][idx1]

        inR  = branches['t5_innerRadius' ][ev][i]
        brR  = branches['t5_bridgeRadius'][ev][i]
        outR = branches['t5_outerRadius' ][ev][i]

        s1_fake   = branches['t5_t3_fakeScore1'     ][ev][i]
        s1_prompt = branches['t5_t3_promptScore1'   ][ev][i]
        s1_disp   = branches['t5_t3_displacedScore1'][ev][i]
        d_fake    = branches['t5_t3_fakeScore2'     ][ev][i] - s1_fake
        d_prompt  = branches['t5_t3_promptScore2'   ][ev][i] - s1_prompt
        d_disp    = branches['t5_t3_displacedScore2'][ev][i] - s1_disp

        # Feature columns (30). Downstream pair builders read column 0
        # (eta1/eta_max) and columns 1-2 (cos/sin phi1), so keep this order.
        #   0-4   first hit: eta/eta_max, cos(phi), sin(phi), |z|/z_max, r/r_max
        #   5-20  hits 2..5, four each: Δ|eta|, Δphi, Δ|z|/z_max, Δr/r_max
        #         relative to the previous hit
        #   21-23 1/innerRadius, 1/bridgeRadius, 1/outerRadius
        #   24-26 score-1 fake, prompt, displaced
        #   27-29 (score 2 − score 1) fake, prompt, displaced
        f = [
            eta1 / eta_max,
            np.cos(phi1),
            np.sin(phi1),
            z1 / z_max,
            r1 / r_max,

            eta2 - abs(eta1),
            delta_phi(phi2, phi1),
            (z2 - z1) / z_max,
            (r2 - r1) / r_max,

            eta3 - eta2,
            delta_phi(phi3, phi2),
            (z3 - z2) / z_max,
            (r3 - r2) / r_max,

            eta4 - eta3,
            delta_phi(phi4, phi3),
            (z4 - z3) / z_max,
            (r4 - r3) / r_max,

            eta5 - eta4,
            delta_phi(phi5, phi4),
            (z5 - z4) / z_max,
            (r5 - r4) / r_max,

            1.0 / inR,
            1.0 / brR,
            1.0 / outR,

            s1_fake, s1_prompt, s1_disp,
            d_fake,  d_prompt,  d_disp
        ]
        feat_evt.append(f)
        eta_evt.append(eta1)
        disp_evt.append(branches['t5_sim_vxy'][ev][i])

        # first (or only) matched sim-index, -1 if none -----------------------
        simIdx_list = branches['t5_matched_simIdx'][ev][i]
        sim_evt.append(simIdx_list[0] if len(simIdx_list) else -1)

    # push to global containers ----------------------------------------------
    if feat_evt:                                # skip events with no survivors
        features_per_event.append(np.asarray(feat_evt, dtype=np.float32))
        eta_per_event.append(np.asarray(eta_evt,  dtype=np.float32))
        displaced_per_event.append(np.asarray(disp_evt, dtype=np.float32))
        sim_indices_per_event.append(np.asarray(sim_evt, dtype=np.int64))
        t5_event_ids.append(ev)
        kept_tot += len(feat_evt)

# Guard against an input with no T5s at all (would divide by zero).
kept_pct = kept_tot / init_tot * 100 if init_tot else 0.0
print(f"\nKept {kept_tot} / {init_tot} T5s "
      f"({kept_pct:.2f} %) that passed the pMatched cut.")
print(f"Total events with ≥1 kept T5: {len(features_per_event)}")


Building T5 features  (pMatched ≥ 0.0) …

Kept 3630781 / 3630781 T5s (100.00 %) that passed the pMatched cut.
Total events with ≥1 kept T5: 500


# Preamble: pLS features

In [4]:
KEEP_FRAC_PLS = 0.40
# Seeded generator for the random pLS down-sampling below, so that
# Restart & Run All keeps the same pLS subset.
PLS_SAMPLING_SEED = 42
pls_rng = np.random.default_rng(PLS_SAMPLING_SEED)
print(f"\nBuilding pLS features …")

# Per-event outputs. Only events with ≥1 kept pLS get an entry, so list position
# k is NOT the event number; pLS_event_ids[k] records the source event.
#   pLS_features_per_event[k]    float32 (n_kept, 10) — column layout below
#   pLS_eta_per_event[k]         float32 (n_kept,)    — pLS_eta
#   pLS_sim_indices_per_event[k] int64   (n_kept,)    — first matched sim index, -1 if none
pLS_features_per_event    = []
pLS_eta_per_event         = []
pLS_sim_indices_per_event = []
pLS_event_ids             = []

kept_tot_pls, init_tot_pls = 0, 0
for ev in range(n_events):
    n_pls = len(branches['pLS_eta'][ev])
    init_tot_pls += n_pls
    if n_pls == 0:
        continue

    feat_evt, eta_evt, sim_evt = [], [], []

    for i in range(n_pls):
        # drop fake pLSs, then keep a random KEEP_FRAC_PLS fraction of the rest
        if branches['pLS_isFake'][ev][i]:
            continue
        if pls_rng.random() > KEEP_FRAC_PLS:
            continue

        # ――― hit‑level quantities -------------------------------------------
        eta = branches['pLS_eta'][ev][i]
        etaErr = branches['pLS_etaErr'][ev][i]
        phi = branches['pLS_phi'][ev][i]
        circleCenterX = np.abs(branches['pLS_circleCenterX'][ev][i])
        circleCenterY = np.abs(branches['pLS_circleCenterY'][ev][i])
        circleRadius = branches['pLS_circleRadius'][ev][i]
        ptIn = branches['pLS_ptIn'][ev][i]
        ptErr = branches['pLS_ptErr'][ev][i]
        isQuad = branches['pLS_isQuad'][ev][i]

        # ――― build feature vector -------------------------------------------
        # Columns (10): eta/4, etaErr/0.00139, cos(phi), sin(phi), 1/ptIn,
        # log10(ptErr), isQuad, log10|circleCenterX|, log10|circleCenterY|,
        # log10(circleRadius). The pLS-T5 pair builders read column 0 (eta/4)
        # and columns 2-3 (cos/sin phi), so keep this order.
        # NOTE(review): 4.0 and 0.00139 are fixed scale constants, and log10 of
        # a zero value yields -inf — confirm these inputs are always > 0.
        f = [
            eta/4.0,
            etaErr/.00139,
            np.cos(phi),
            np.sin(phi),
            1.0 / ptIn,
            np.log10(ptErr),
            isQuad,
            np.log10(circleCenterX),
            np.log10(circleCenterY),
            np.log10(circleRadius),
        ]

        feat_evt.append(f)
        eta_evt.append(eta)

        sim_list = branches['pLS_matched_simIdx'][ev][i]
        sim_evt.append(sim_list[0] if len(sim_list) else -1)

    # ――― store per‑event containers -----------------------------------------
    if feat_evt:              # skip events with no survivors
        pLS_features_per_event   .append(np.asarray(feat_evt, dtype=np.float32))
        pLS_eta_per_event        .append(np.asarray(eta_evt,  dtype=np.float32))
        pLS_sim_indices_per_event.append(np.asarray(sim_evt, dtype=np.int64))
        pLS_event_ids.append(ev)
        kept_tot_pls += len(feat_evt)

# Guard against an input with no pLSs at all (would divide by zero).
kept_pct_pls = kept_tot_pls / init_tot_pls * 100 if init_tot_pls else 0.0
print(f"\nKept {kept_tot_pls} / {init_tot_pls} pLSs "
      f"({kept_pct_pls:.2f} %) that passed the selections.")
print(f"Total events with ≥1 kept pLS: {len(pLS_features_per_event)}")


Building pLS features …

Kept 3721929 / 11314631 pLSs (32.89 %) that passed the selections.
Total events with ≥1 kept pLS: 500


# T5-T5 pairs, for-loop

In [5]:
import time, random, math, numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from sklearn.model_selection import train_test_split

DELTA_R2_CUT = 0.02

# ---------------------------------------------------------------------------
def _delta_phi(phi1, phi2):
    """same helper you already defined, but inline for the worker"""
    d = phi1 - phi2
    if d > math.pi:
        d -= 2 * math.pi
    elif d < -math.pi:
        d += 2 * math.pi
    return d

# ---------------------------------------------------------------------------
def _pairs_single_event(evt_idx,
                        F, S, D,
                        max_sim, max_dis,
                        invalid_sim):
    """
    Worker run in a separate process.
    Returns two Python lists with the selected (i,j) indices per event.
    """
    t0 = time.time()
    n = F.shape[0]
    if n < 2:
        return evt_idx, [], []

    eta1 = F[:, 0] * eta_max
    phi1 = np.arctan2(F[:, 2], F[:, 1])

    sim_pairs, dis_pairs = [], []

    # similar pairs (same sim-index)
    buckets = {}
    for idx, s in enumerate(S):
        if s != invalid_sim:
            buckets.setdefault(s, []).append(idx)

    for lst in buckets.values():
        if len(lst) < 2:
            continue
        for a in range(len(lst) - 1):
            i = lst[a]
            for b in range(a + 1, len(lst)):
                j = lst[b]
                dphi = _delta_phi(phi1[i], phi1[j])
                dr2  = (eta1[i] - eta1[j])**2 + dphi**2
                if dr2 < DELTA_R2_CUT:
                    sim_pairs.append((i, j))

    # dissimilar pairs (different sim)
    for i in range(n - 1):
        si, ei, pi = S[i], eta1[i], phi1[i]
        for j in range(i + 1, n):
            # skip fake-fake pairs
            if si == invalid_sim and S[j] == invalid_sim:
                continue
            if (si == S[j]) and si != invalid_sim:
                continue
            dphi = _delta_phi(pi, phi1[j])
            dr2  = (ei - eta1[j])**2 + dphi**2
            if dr2 < DELTA_R2_CUT:
                dis_pairs.append((i, j))

    # down-sample
    if len(sim_pairs) > max_sim:
        sim_pairs = random.sample(sim_pairs, max_sim)
    if len(dis_pairs) > max_dis:
        dis_pairs = random.sample(dis_pairs, max_dis)


    dt = time.time() - t0
    print(f"[evt {evt_idx:4d}]  T5s={n:5d}  sim={len(sim_pairs):3d}  dis={len(dis_pairs):3d} | {dt:.1f} seconds")
    return evt_idx, sim_pairs, dis_pairs


# T5-T5 pairs, vectorized

In [6]:
import time
import numpy as np
import random

DELTA_R2_CUT = 0.02

def _pairs_single_event_vectorized(evt_idx,
                                   F, S, D,
                                   max_sim, max_dis,
                                   invalid_sim):
    t0 = time.time()
    n = F.shape[0]
    eta1, phi1 = F[:, 0] * eta_max, np.arctan2(F[:, 2], F[:, 1])

    # upper-triangle (non-diagonal) indices
    idx_l, idx_r = np.triu_indices(n, k=1)
    idxs_triu = np.stack((idx_l, idx_r), axis=-1)

    simidx_l = S[idx_l]
    simidx_r = S[idx_r]

    eta_l = eta1[idx_l]
    eta_r = eta1[idx_r]
    phi_l = phi1[idx_l]
    phi_r = phi1[idx_r]
    dphi = np.abs(phi_l - phi_r)
    dphi[dphi > np.pi] -= 2 * np.pi  # adjust to [-pi, pi]
    dr2 = (eta_l - eta_r)**2 + dphi**2

    dr2_valid = (dr2 < DELTA_R2_CUT)
    sim_idx_same = (simidx_l == simidx_r)
    sim_mask = dr2_valid & sim_idx_same & (simidx_l != invalid_sim)
    dis_mask = dr2_valid & ~sim_idx_same

    sim_pairs = idxs_triu[sim_mask]
    dis_pairs = idxs_triu[dis_mask]

    # down-sample
    if len(sim_pairs) > max_sim:
        sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]
    if len(dis_pairs) > max_dis:
        dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]

    dt = time.time() - t0
    print(f"[evt {evt_idx:4d}]  T5s={n:5d}  sim={len(sim_pairs):3d}  dis={len(dis_pairs):3d} | {dt:.1f} seconds vectorized")
    return evt_idx, sim_pairs, dis_pairs


# T5-T5 pairs, comparing for-loop and vectorized

In [7]:
import time, random, math, numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed

def create_t5_pairs_balanced_parallel(features_per_event,
                                      sim_indices_per_event,
                                      displaced_per_event,
                                      *,
                                      max_similar_pairs_per_event=100,
                                      max_dissimilar_pairs_per_event=450,
                                      invalid_sim_idx=-1,
                                      n_workers=None,
                                      n_compare_events=5):
    """
    Cross-check the loop-based and vectorized T5-T5 pair builders.

    For each of the first ``n_compare_events`` events, the function runs
    ``_pairs_single_event`` and ``_pairs_single_event_vectorized`` with
    ``random`` reseeded to 42 before each call. The count is capped at the
    number of events available; ``None`` means all events. It then prints
    whether the similar / dissimilar pair lists are identical in order
    ("equal") and identical after sorting ("equivalent").

    NOTE(review): despite its name this runs sequentially, returns None and
    ignores ``n_workers``; it is a validation harness, not a pair producer.
    """
    n_avail = len(features_per_event)
    n_evt = n_avail if n_compare_events is None else min(n_compare_events, n_avail)

    for evt_idx in range(n_evt):
        work_args = (evt_idx,
                     features_per_event[evt_idx],
                     sim_indices_per_event[evt_idx],
                     displaced_per_event[evt_idx],
                     max_similar_pairs_per_event,
                     max_dissimilar_pairs_per_event,
                     invalid_sim_idx,
                     )

        # run both versions with the same random state so down-sampling matches
        random.seed(42)
        _, sim_pairs, dis_pairs = _pairs_single_event(*work_args)
        random.seed(42)
        _, sim_pairs_vec, dis_pairs_vec = _pairs_single_event_vectorized(*work_args)

        # compare: turn the vectorized (k, 2) arrays into lists of tuples
        sim_pairs_vec = [(i, j) for i, j in sim_pairs_vec]
        dis_pairs_vec = [(i, j) for i, j in dis_pairs_vec]
        sim_equal = (sim_pairs == sim_pairs_vec)
        dis_equal = (dis_pairs == dis_pairs_vec)
        sim_equiv = sorted(sim_pairs) == sorted(sim_pairs_vec)
        dis_equiv = sorted(dis_pairs) == sorted(dis_pairs_vec)
        print(f"[evt {evt_idx:4d}]  sim. pairs equal: {sim_equal}. Equivalent: {sim_equiv}")
        print(f"[evt {evt_idx:4d}]  dis. pairs equal: {dis_equal}. Equivalent: {dis_equiv}")
        print("")

# invoke: compare both T5-T5 builders with caps large enough to keep every pair.
# Caps are ints because random.sample() rejects float sample sizes.
create_t5_pairs_balanced_parallel(
    features_per_event,
    sim_indices_per_event,
    displaced_per_event,
    max_similar_pairs_per_event    = 1_000_000,
    max_dissimilar_pairs_per_event = 1_000_000,
    invalid_sim_idx                = -1,
    n_workers                      = None,
)


[evt    0]  T5s= 3864  sim=25040  dis=19778 | 3.5 seconds
[evt    0]  T5s= 3864  sim=25040  dis=19778 | 0.2 seconds vectorized
[evt    0]  sim. pairs equal: False. Equivalent: True
[evt    0]  dis. pairs equal: True. Equivalent: True

[evt    1]  T5s= 4786  sim=39667  dis=34564 | 5.4 seconds
[evt    1]  T5s= 4786  sim=39667  dis=34564 | 0.3 seconds vectorized
[evt    1]  sim. pairs equal: False. Equivalent: True
[evt    1]  dis. pairs equal: True. Equivalent: True

[evt    2]  T5s= 4648  sim=28559  dis=32798 | 5.1 seconds
[evt    2]  T5s= 4648  sim=28559  dis=32798 | 0.3 seconds vectorized
[evt    2]  sim. pairs equal: False. Equivalent: True
[evt    2]  dis. pairs equal: True. Equivalent: True

[evt    3]  T5s= 5868  sim=38960  dis=50595 | 8.1 seconds
[evt    3]  T5s= 5868  sim=38960  dis=50595 | 0.5 seconds vectorized
[evt    3]  sim. pairs equal: False. Equivalent: True
[evt    3]  dis. pairs equal: True. Equivalent: True

[evt    4]  T5s= 5981  sim=43421  dis=46855 | 8.4 seconds
[e

# T5-pLS pairs, for-loop

In [8]:
# pairing hyper-parameters
DELTA_R2_CUT_PLS_T5 = 0.02   # pLS-T5 pairs need Δeta² + Δphi² below this
DISP_VXY_CUT       = 0.1     # T5 displaced flag: per-T5 displacement value > this
INVALID_SIM_IDX    = -1      # sim index used for unmatched objects
# Per-event caps on similar / dissimilar pairs. These must be ints:
# random.sample() raises TypeError for a float sample size.
MAX_SIM            = 1_000_000  # 1000
MAX_DIS            = 1_000_000  # 1000

In [9]:
import numpy as np
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
import time


def _pairs_pLS_T5_single(evt_idx,
                         F_pLS, S_pLS,
                         F_T5,  S_T5, D_T5,
                         max_sim, max_dis,
                         invalid_sim):
    """
    Build similar / dissimilar pLS-T5 pairs for a single event,
    printing per-event summary.
    """
    t0 = time.time()
    n_p, n_t = F_pLS.shape[0], F_T5.shape[0]
    sim_pairs, dis_pairs = [], []

    # if either collection is empty, report zeros and bail
    if n_p == 0 or n_t == 0:
        print(f"[evt {evt_idx:4d}]  pLSs={n_p:5d}  T5s={n_t:5d}  sim={0:4d}  dis={0:4d}")
        return evt_idx, []

    # un-normalize eta and compute phi angles
    eta_p = F_pLS[:,0] * 4.0
    phi_p = np.arctan2(F_pLS[:,3], F_pLS[:,2])
    eta_t = F_T5[:,0] * eta_max
    phi_t = np.arctan2(F_T5[:,2], F_T5[:,1])

    # bucket T5 by sim-idx for similar
    buckets = {}
    for j,s in enumerate(S_T5):
        if s != invalid_sim:
            buckets.setdefault(s, []).append(j)
    for i,s in enumerate(S_pLS):
        if s == invalid_sim:
            continue
        for j in buckets.get(s, []):
            dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi
            dr2  = (eta_p[i] - eta_t[j])**2 + dphi**2
            if dr2 < DELTA_R2_CUT_PLS_T5:
                sim_pairs.append((i,j))

    # find dissimilar (different sim-idx) pairs
    for i in range(n_p):
        for j in range(n_t):
            if S_pLS[i] == S_T5[j] and S_pLS[i] != invalid_sim:
                continue
            dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi
            dr2  = (eta_p[i] - eta_t[j])**2 + dphi**2
            if dr2 < DELTA_R2_CUT_PLS_T5:
                dis_pairs.append((i,j))

    # down-sample to limits
    if len(sim_pairs) > max_sim:
        sim_pairs = random.sample(sim_pairs, max_sim)
    if len(dis_pairs) > max_dis:
        dis_pairs = random.sample(dis_pairs, max_dis)

    # print per-event summary
    dt = time.time() - t0
    print(f"[evt {evt_idx:4d}]  pLSs={n_p:5d}  T5s={n_t:5d}  "
          f"sim={len(sim_pairs):4d}  dis={len(dis_pairs):4d} | {dt:.1f} seconds vectorized")

    # pack into (feature, feature, label, displaced_flag)
    packed = []
    for i,j in sim_pairs:
        packed.append((F_pLS[i], F_T5[j], 0, D_T5[j] > DISP_VXY_CUT))
    for i,j in dis_pairs:
        packed.append((F_pLS[i], F_T5[j], 1, D_T5[j] > DISP_VXY_CUT))

    return evt_idx, packed


# T5-pLS pairs, vectorized

In [10]:
def _pairs_pLS_T5_single_vectorized(evt_idx,
                                    F_pLS, S_pLS,
                                    F_T5,  S_T5, D_T5,
                                    max_sim, max_dis,
                                    invalid_sim):
    """
    Build similar / dissimilar pLS-T5 pairs for a single event,
    printing per-event summary.
    """
    t0 = time.time()
    n_p, n_t = F_pLS.shape[0], F_T5.shape[0]
    sim_pairs, dis_pairs = [], []

    # if either collection is empty, report zeros and bail
    if n_p == 0 or n_t == 0:
        print(f"[evt {evt_idx:4d}]  pLSs={n_p:5d}  T5s={n_t:5d}  sim={0:4d}  dis={0:4d}")
        return evt_idx, []

    # un-normalize eta and compute phi angles
    eta_p = F_pLS[:,0] * 4.0
    phi_p = np.arctan2(F_pLS[:,3], F_pLS[:,2])
    eta_t = F_T5[:,0] * eta_max
    phi_t = np.arctan2(F_T5[:,2], F_T5[:,1])

    # make all possible pairs (i, j)
    idx_p, idx_t = np.indices( (n_p, n_t) )
    idx_p, idx_t = idx_p.flatten(), idx_t.flatten()

    # calculate angles
    dphi = (phi_p[idx_p] - phi_t[idx_t] + np.pi) % (2 * np.pi) - np.pi
    dr2 = (eta_p[idx_p] - eta_t[idx_t])**2 + dphi**2
    dr2_valid = (dr2 < DELTA_R2_CUT_PLS_T5)

    # compare sim indices
    simidx_p = S_pLS[idx_p]
    simidx_t = S_T5[idx_t]
    sim_idx_same = (simidx_p == simidx_t)

    # create masks for similar and dissimilar pairs
    sim_mask = dr2_valid & sim_idx_same & (simidx_p != invalid_sim)
    dis_mask = dr2_valid & ~sim_idx_same

    # get the pairs
    sim_pairs = np.column_stack((idx_p[sim_mask], idx_t[sim_mask]))
    dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))

    # down-sample
    if len(sim_pairs) > max_sim:
        sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]
    if len(dis_pairs) > max_dis:
        dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]

    # print per-event summary
    dt = time.time() - t0
    print(f"[evt {evt_idx:4d}]  pLSs={n_p:5d}  T5s={n_t:5d}  "
          f"sim. pairs={len(sim_pairs):4d}  dis. pairs={len(dis_pairs):4d} | {dt:.1f} seconds vectorized")


    # pack into (feature, feature, label, displaced_flag)
    packed = []
    for i,j in sim_pairs:
        packed.append((F_pLS[i], F_T5[j], 0, D_T5[j] > DISP_VXY_CUT))
    for i,j in dis_pairs:
        packed.append((F_pLS[i], F_T5[j], 1, D_T5[j] > DISP_VXY_CUT))

    return evt_idx, packed


# T5-pLS pairs, comparing for-loop and vectorized

In [11]:
# Cross-check the loop-based and vectorized pLS-T5 pair builders.

def tup_equal(a, b):
    """True when two packed (pLS feats, T5 feats, label, displaced) tuples match."""
    return (np.array_equal(a[0], b[0]) and
            np.array_equal(a[1], b[1]) and
            a[2] == b[2] and
            np.array_equal(a[3], b[3]))

def packed_equal(a, b):
    """True when two packed pair lists have equal length and match element-wise, in order."""
    return len(a) == len(b) and all(tup_equal(a_i, b_i) for a_i, b_i in zip(a, b))

# Compare on the first few events, capped at what is available so a short
# sample does not raise IndexError.
# NOTE(review): the T5 and pLS lists are indexed by the same position here;
# this assumes both feature cells skipped exactly the same events.
n_evt = min(5, len(features_per_event), len(pLS_features_per_event))

for evt_idx in range(n_evt):
    work_args = (
        evt_idx,
        pLS_features_per_event[evt_idx],
        pLS_sim_indices_per_event[evt_idx],
        features_per_event[evt_idx],
        sim_indices_per_event[evt_idx],
        displaced_per_event[evt_idx],
        MAX_SIM, MAX_DIS, INVALID_SIM_IDX,
        )

    # run both versions with the same random state so down-sampling matches
    random.seed(42)
    _, packed = _pairs_pLS_T5_single(*work_args)
    random.seed(42)
    _, packed_vec = _pairs_pLS_T5_single_vectorized(*work_args)

    equal = packed_equal(packed, packed_vec)
    print(f"[evt {evt_idx:4d}]  pLS-T5 pairs equal: {equal}")
    print("")


[evt    0]  pLSs= 4837  T5s= 3864  sim=5566  dis=11187 | 12.4 seconds vectorized
[evt    0]  pLSs= 4837  T5s= 3864  sim. pairs=5566  dis. pairs=11187 | 0.4 seconds vectorized
[evt    0]  pLS-T5 pairs equal: True

[evt    1]  pLSs= 5349  T5s= 4786  sim=7691  dis=20348 | 17.1 seconds vectorized
[evt    1]  pLSs= 5349  T5s= 4786  sim. pairs=7691  dis. pairs=20348 | 0.6 seconds vectorized
[evt    1]  pLS-T5 pairs equal: True

[evt    2]  pLSs= 6306  T5s= 4648  sim=7480  dis=19931 | 19.5 seconds vectorized
[evt    2]  pLSs= 6306  T5s= 4648  sim. pairs=7480  dis. pairs=19931 | 0.7 seconds vectorized
[evt    2]  pLS-T5 pairs equal: True

[evt    3]  pLSs= 5885  T5s= 5868  sim=8893  dis=28644 | 23.1 seconds vectorized
[evt    3]  pLSs= 5885  T5s= 5868  sim. pairs=8893  dis. pairs=28644 | 0.8 seconds vectorized
[evt    3]  pLS-T5 pairs equal: True

[evt    4]  pLSs= 6463  T5s= 5981  sim=8938  dis=21937 | 25.9 seconds vectorized
[evt    4]  pLSs= 6463  T5s= 5981  sim. pairs=8938  dis. pairs=2193