In [None]:
import json
import uproot
import awkward as ak
import vector
import numpy as np
import matplotlib.pyplot as plt
import warnings

# Plotting defaults applied to every figure created in this notebook.
plt.rcParams.update(
    {
        "figure.figsize": (10, 6),
        "axes.grid": True,
        "grid.alpha": 0.6,
        "grid.linestyle": "--",
        "font.size": 14,
        "figure.dpi": 200,
    }
)

# Suppress a harmless warning from the vector library with awkward arrays
warnings.filterwarnings("ignore", message="Passing an awkward array to a ufunc")

# Register the vector behaviors with awkward through vector's public helper.
# It performs the same `ak.behavior.update(...)` as before, but imports the
# awkward backend itself instead of relying on `vector.backends.awkward`
# already being loaded as an attribute of the package.
vector.register_awkward()

# --- CONFIGURATION ---
# JSON is UTF-8 by specification; state the encoding so the platform default
# does not matter.
with open("hh-bbbb-obj-config.json", "r", encoding="utf-8") as config_file:
    CONFIG = json.load(config_file)

In [None]:
# eta-phi plots of L1 jets and L1 puppi constituents

from data_loading_helpers import load_and_prepare_data

# Quick-look settings: how many events to draw, how many jets to print per
# event, and the cone radius used to count candidates around each jet.
N_EVENTS_TO_SHOW = 3
N_JETS_TO_PRINT = 5
JET_MATCH_DELTA_R = 0.4

# NOTE(review): this is an absolute, machine-specific path; consider moving it
# into the JSON config (or a DATA_DIR constant) so the notebook runs elsewhere.
events = load_and_prepare_data(
    "/Users/adityatandon/Documents/College/Physics/Year 4/Neural SBI/root-obj-perf/data/hh4b_puppi_pf/hh4b/data_*.root",
    "Events",
    ["L1puppiJetSC4NG", "L1BarrelExtPuppi", "GenPart"],  # Jets  # Candidates  # Labels
    max_events=None,
)
l1_col = "L1puppiJetSC4NG"
l1_puppi_col = "L1BarrelExtPuppi"

# Never index past the number of loaded events.
for idx in range(min(N_EVENTS_TO_SHOW, len(events[l1_col]))):
    print(f"\nEvent {idx}:")
    l1_jets = events[l1_col][idx]
    l1_cands = events[l1_puppi_col][idx]
    print(f"  L1 Jets: {len(l1_jets)}")
    print(f"  L1 Candidates: {len(l1_cands)}")

    # One figure per event: candidates as small gray dots, jets as red crosses.
    fig, ax = plt.subplots()
    ax.scatter(
        l1_cands.phi, l1_cands.eta, s=5, c="gray", alpha=0.5, label="L1 Candidates"
    )
    ax.scatter(l1_jets.phi, l1_jets.eta, s=50, c="red", marker="x", label="L1 Jets")
    ax.set_xlabel("Phi")
    ax.set_ylabel("Eta")
    ax.set_title(f"L1 Jets and Candidates - Event {idx}")
    ax.legend()
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate memory

    # Candidate coordinates as flat float64 arrays, extracted once per event so
    # the per-jet matching below is a handful of NumPy operations.
    cand_eta = ak.to_numpy(l1_cands.eta).astype(np.float64)
    cand_phi = ak.to_numpy(l1_cands.phi).astype(np.float64)

    for jdx in range(min(N_JETS_TO_PRINT, len(l1_jets))):
        jet = l1_jets[jdx]
        print(f"    Jet {jdx}: pt={jet.pt:.2f}, eta={jet.eta:.2f}, phi={jet.phi:.2f}")
        # Find constituents within DeltaR < 0.4.
        # DeltaR = sqrt(d_eta^2 + d_phi^2) with d_phi wrapped into [-pi, pi),
        # which is intended to match what vector's deltaR computed per pair.
        delta_phi = (cand_phi - float(jet.phi) + np.pi) % (2 * np.pi) - np.pi
        delta_r = np.hypot(cand_eta - float(jet.eta), delta_phi)
        constituents_in_jet = l1_cands[delta_r < JET_MATCH_DELTA_R]
        print(f"      Constituents in DeltaR<0.4: {len(constituents_in_jet)}")

In [None]:
import torch
from tqdm import tqdm
import importlib
import parT

importlib.reload(parT)
from parT import ParticleTransformer

# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Number of unique samples used for the small-batch overfit test.
N_OVERFIT_SAMPLES = 20

# Fix the seed so the initial weights (and the overfit curves) are reproducible.
torch.manual_seed(0)

model = ParticleTransformer(
    input_dim=17,
    embed_dim=128,
    num_pairwise_feat=4,
    num_heads=8,
    num_layers=5,
    num_cls_layers=3,
    dropout=0.0,
    num_classes=1,
    use_batch_norm=False,  # Disable BatchNorm for small batch overfit test
)
print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

data = np.load("l1_training_data.npz")

# Load data - check if particle_mask exists (for backwards compatibility)
all_x = torch.tensor(data["x"], dtype=torch.float32)
all_y = torch.tensor(data["y"], dtype=torch.float32)
if "particle_mask" in data.files:
    all_mask = torch.tensor(data["particle_mask"], dtype=torch.bool)
    print("Loaded particle mask from dataset")
else:
    # Fallback: treat a particle slot as real when its first feature is
    # non-zero (feature 0 is presumably the energy - TODO confirm).
    all_mask = all_x[..., 0] != 0
    print("No particle_mask in dataset, inferring from non-zero energy")

# Filter for unique samples (avoid duplicates with conflicting labels).
# Samples are compared through a hash of their raw float32 bytes; the first
# occurrence of each distinct sample is kept, in dataset order.
print(f"Total samples in dataset: {len(all_x)}")
unique_indices = []
seen_hashes = set()
for i in range(len(all_x)):
    h = hash(all_x[i].numpy().tobytes())
    if h not in seen_hashes:
        seen_hashes.add(h)
        unique_indices.append(i)

print(f"Unique samples: {len(unique_indices)}")

# Use the first N_OVERFIT_SAMPLES unique samples
selected_indices = unique_indices[:N_OVERFIT_SAMPLES]
x = all_x[selected_indices]
y = all_y[selected_indices]
mask = all_mask[selected_indices]

print("Labels: ", y)
# pos_weight = torch.sum(1 - y) / (torch.sum(y) + 1e-6)
pos_weight = torch.tensor(1.0)
print("Pos weight: ", pos_weight)

# Move the model and the selected batch to the chosen device.
model.to(device)
x = x.to(device)
y = y.to(device)
mask = mask.to(device)

In [None]:
# overfit testing
# Full-batch training on the small selected set: the loss should fall towards
# zero if the model and data pipeline are wired correctly.

N_OVERFIT_STEPS = 400

optim = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=800, T_mult=2)
# `.to(x.device)` also moves the criterion's pos_weight tensor, so the loss
# never mixes a CPU weight with accelerator logits.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(x.device)

loss_vals = []
lr_vals = []
grad_norms = []

model.train()
for epoch in tqdm(range(N_OVERFIT_STEPS)):
    optim.zero_grad()
    outputs = model(x, particle_mask=mask).squeeze()
    loss = criterion(outputs, y)
    loss.backward()
    optim.step()
    # scheduler.step()
    lr_vals.append(optim.param_groups[0]["lr"])
    loss_vals.append(loss.item())

    # Global L2 norm of all gradients. The per-parameter norms are reduced on
    # the device, so only one host sync happens per step.
    per_param_norms = [
        p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None
    ]
    total_norm = torch.stack(per_param_norms).norm(2).item() if per_param_norms else 0.0
    grad_norms.append(total_norm)

fig, ax = plt.subplots()
ax.plot(loss_vals)
ax.set_xlabel("Optimisation step")
ax.set_ylabel("BCE-with-logits loss")
ax.set_title("Overfit test: training loss")
plt.show()

fig, ax = plt.subplots()
ax.plot(grad_norms)
ax.set_xlabel("Optimisation step")
ax.set_ylabel("Global gradient L2 norm")
ax.set_title("Overfit test: gradient norm")
plt.show()

In [None]:
# Score the training batch in eval mode and compare against the labels.
model.eval()
# Pass the same particle mask that was used during training; without it the
# padded slots would be treated as real particles. no_grad avoids building an
# autograd graph that is not needed for inference.
with torch.no_grad():
    logits = model(x, particle_mask=mask).squeeze()
outputs = torch.sigmoid(logits).cpu().numpy()
print("Labels: ", y.cpu().numpy())
print("Outputs: ", np.round(outputs, 3))

fig, ax = plt.subplots()
ax.hist(outputs, bins=40, range=(0, 1))
ax.set_xlabel("Predicted probability")
ax.set_ylabel("Samples")
ax.set_title("Model outputs on the overfit batch")
plt.show()

In [None]:
# Report the L2 norm of the gradient currently stored on each named parameter.
print("Gradient norms for model parameters:")
for i, (name, param) in enumerate(model.named_parameters()):
    # Handle the "no gradient stored" case first, then fall through to the norm.
    if param.grad is None:
        print(f"Param {i} [{name}]: No gradient")
        continue
    grad_norm = param.grad.data.norm(2).item()
    print(f"Param {i} [{name}]: {grad_norm}")

In [None]:
# Check if samples are actually distinct, and whether any near-duplicates
# carry conflicting labels.

# Work on CPU copies: one transfer here instead of a device sync per pair.
x_cpu = x.detach().cpu()
y_cpu = y.detach().cpu()

n_identical = 0
n_conflicting = 0
print("Checking sample uniqueness...")
for i in range(len(x_cpu)):
    for j in range(i + 1, len(x_cpu)):
        # L1 distance over all features; below 1e-3 counts as "nearly identical".
        diff = (x_cpu[i] - x_cpu[j]).abs().sum().item()
        if diff < 1e-3:
            print(f"Samples {i} and {j} are nearly identical! diff={diff}")
            n_identical += 1
            # Only a label mismatch makes a duplicate pair a data problem.
            label_i = y_cpu[i].item()
            label_j = y_cpu[j].item()
            if label_i != label_j:
                print(f"  Conflicting labels: {label_i} vs {label_j}")
                n_conflicting += 1
print(f"Number of nearly identical samples: {n_identical}")

# Check sample statistics
print("\nSample-wise statistics (mean of features):")
for i in range(len(x_cpu)):
    print(
        f"Sample {i}: mean={x_cpu[i].mean().item():.4f}, std={x_cpu[i].std().item():.4f}, label={y_cpu[i].item()}"
    )

# Check how many non-zero particles per sample
# (a particle counts as non-zero when the absolute sum of its features > 1e-6)
print("\nNon-zero particles per sample:")
for i in range(len(x_cpu)):
    nonzero = (x_cpu[i].abs().sum(dim=-1) > 1e-6).sum().item()
    print(f"Sample {i}: {nonzero} non-zero particles out of {x_cpu.shape[1]}")

# Report a label conflict only when one was actually observed.
if n_conflicting > 0:
    print("\nDATA ISSUE: Identical samples have different labels!")
    print("Please check the dataset for consistency.")
elif n_identical > 0:
    print("\nNearly identical samples found, but their labels agree.")
else:
    print("\nAll samples are unique.")