In [None]:
import torch
import numpy as np
from analysis.core import Core, calculate_B_C, evaluate_quantum_results_with_uncertainties, build_histograms, \
    build_results
import vector

vector.register_awkward()
import awkward as ak

In [None]:
# Load the evaluation output: a list of per-batch dicts of tensors (indexed that
# way below). map_location='cpu' keeps this working even if the file was saved
# from GPU tensors -- every downstream step calls .numpy(), which requires CPU.
# NOTE: torch.load unpickles its input; only load files from a trusted source.
raw_data = torch.load('data/TT2L-pretrain.pt', map_location='cpu')

## Truth Analysis

- `t  -> W+  b,     W+  -> l+  v`
- `t~ -> W-  b~,    W-  -> l-  v~`

---

Extra truth-variable columns: `[pt, eta, phi, mass, energy, index, mother1, mother2]`

In [None]:
# Collect every LHE truth column ("EXTRA/lhe/<particle>") across all batches:
# strip the prefix and stack the per-batch tensors into one numpy array each.
raw_truth = dict(
    (
        name.replace('EXTRA/lhe/', ''),
        torch.cat([batch_dict[name] for batch_dict in raw_data]).numpy(),
    )
    for name in raw_data[0].keys()
    if name.startswith('EXTRA/lhe/')
)


def sanity_and_merge(pairs, data):
    """
    Merge pairs of mutually exclusive per-event arrays into single arrays.

    Parameters:
        pairs: iterable of (key_a, key_b, merged_key) triples.
        data: mapping of key -> 2-D array (events x columns); a row that is
            entirely NaN means "absent in this event".

    Returns:
        Dict merged_key -> array that takes the row from data[key_a] where it
        is present and the row from data[key_b] otherwise.

    Raises:
        ValueError: if any event has both key_a and key_b present.
    """
    def _present(arr):
        # A row counts as present unless every one of its entries is NaN.
        return ~np.all(np.isnan(arr), axis=1)

    merged = {}
    for key_a, key_b, merged_key in pairs:
        arr_a = data[key_a]
        arr_b = data[key_b]
        has_a = _present(arr_a)
        has_b = _present(arr_b)
        if (has_a & has_b).any():
            raise ValueError(f"Conflict: both {key_a} and {key_b} present in same event.")
        # Add a trailing axis so the per-event choice broadcasts across columns.
        merged[merged_key] = np.where(has_a[:, np.newaxis], arr_a, arr_b)
    return merged


# Merge electron and muon truth columns into flavour-agnostic slots, as
# (key_a, key_b, merged_key) triples. sanity_and_merge raises if an event has
# both flavours filled for the same slot.
pairs_to_merge = [
    ('e+', 'mu+', 'l+'),
    ('e-', 'mu-', 'l-'),
    ('nu(e)', 'nu(mu)', 'v'),
    ('nu(e)~', 'nu(mu)~', 'v~'),
]

# Run merging with sanity check, then copy the remaining truth particles
# through unchanged.
truth_data = sanity_and_merge(pairs_to_merge, raw_truth)
for k in ['W+', 'W-', 'b', 'b~', 't', 't~']:
    truth_data[k] = raw_truth[k]

# Replace each raw truth array with a vector array built from columns 0-3
# (pt, eta, phi, mass); the remaining columns are not used here.
# Only values are reassigned (no keys added/removed), so iterating keys() is safe.
for k in truth_data.keys():
    truth_data[k] = vector.zip({
        'pt': truth_data[k][:, 0],
        'eta': truth_data[k][:, 1],
        'phi': truth_data[k][:, 2],
        'mass': truth_data[k][:, 3],
    })

# Analyse the truth top pair (t, t~) with the charged leptons (l+, l-) as children.
truth_core = Core(
    main_particle_1=truth_data['t'],
    main_particle_2=truth_data['t~'],
    child1=truth_data['l+'],
    child2=truth_data['l-'],
)

# NOTE(review): 'mass' is presumably the t-tbar invariant mass, so this keeps
# the same window as the lowest m_tt unfolding bin -- confirm the column.
truth_result = truth_core.analyze()
truth_result = truth_result.query('mass < 400')
truth_hists = build_histograms(truth_result)

# calculate_B_C returns a nominal result plus up/down variants, which are
# combined into uncertainties below.
# NOTE(review): kappas=(1.0, -1.0) presumably are the spin-analysing powers of
# the two leptons -- confirm against calculate_B_C.
result, result_up, result_down = calculate_B_C(truth_hists, kappas=(1.0, -1.0))
# D is minus the sum of the diagonal C terms; it is not read later in this cell.
D = -(result['C_nn'] + result['C_rr'] + result['C_kk'])
final = evaluate_quantum_results_with_uncertainties(result, result_up, result_down)

## Recon Analysis

### Point-cloud structure

- `full_input_point_cloud`: (N, 18, 7) array of particles in the event
    - columns: `[energy, pt, eta, phi, btag, isLepton, charge]`
- `t1 -> b1, l1`: `l1` has PDG ID 11 or 13 (e- or mu-)
- `t2 -> b2, l2`: `l2` has PDG ID -11 or -13 (e+ or mu+)

> In the LHE truth record the lepton charge labels are correct: `l+` denotes e+ and mu+. In the assignment labels, however, `l+` denotes e- and mu-, i.e. the sign convention is flipped.

In [None]:
def classify_TT2L(point_cloud, assignment_target):
    """
    Pick the reconstructed b-jet and lepton attached to each top in a TT2L event.

    Each top is assigned two point-cloud slots (its b-jet and its lepton, in
    either order). Among those two slots, the b-jet is the first one whose
    btag feature (column 4) exceeds 0.5. The lepton is the first one whose
    isLepton feature (column 5) exceeds 0.5.

    Parameters:
        point_cloud: (N, num_particles, num_features) array-like -- a CPU torch
            tensor or a numpy array. Columns 0 and 1 are inverted with expm1,
            i.e. they are presumably stored as log1p(energy) / log1p(pt).
        assignment_target: sequence of 2 index arrays (each (N, 2), tensor or
            numpy) giving the slots assigned to top 1 and top 2. Negative
            indices are treated as "no assignment" and are never selected.
            NOTE(review): assumes missing targets are encoded as -1 -- confirm.

    Returns:
        Dict with 'b1_recon', 'b2_recon', 'l1_recon', 'l2_recon'. Each is a
        vector array of length N built from (pt, eta, phi, energy). Events
        with no qualifying candidate get NaN in every component.
    """
    # np.asarray accepts both CPU torch tensors and numpy arrays.
    pc = np.asarray(point_cloud)
    rows = np.arange(pc.shape[0])[:, None]  # (N, 1), broadcasts against (N, 2) targets

    def gather(target):
        """Return the (N, 2, F) candidate features and an (N, 2) validity mask."""
        slots = np.asarray(target)
        valid = slots >= 0
        # A negative slot would otherwise wrap around to the last (padding)
        # particle; point it at slot 0 and rely on `valid` to veto it.
        candidates = pc[rows, np.where(valid, slots, 0)]
        return candidates, valid

    def select_object(candidates, valid, feature_idx, threshold=0.5):
        """
        Per event, take the first valid candidate whose feature exceeds
        `threshold`. Return a vector array of length N, NaN where none qualifies.
        """
        hit = (candidates[:, :, feature_idx] > threshold) & valid  # (N, 2)
        first = np.argmax(hit, axis=1)  # first True per event (0 if none)
        chosen = candidates[np.arange(candidates.shape[0]), first]  # (N, F), a copy
        chosen[~hit.any(axis=1)] = np.nan

        # The point cloud has no mass column (column 4 is btag), so the
        # four-vector is built from pt, eta, phi and energy.
        return vector.zip({
            'pt': np.expm1(chosen[:, 1]),
            'eta': chosen[:, 2],
            'phi': chosen[:, 3],
            'energy': np.expm1(chosen[:, 0]),
        })

    t1_candidates, t1_valid = gather(assignment_target[0])
    t2_candidates, t2_valid = gather(assignment_target[1])

    return {
        # B-jet: feature[4] > 0.5
        'b1_recon': select_object(t1_candidates, t1_valid, 4),
        'b2_recon': select_object(t2_candidates, t2_valid, 4),
        # Lepton: feature[5] > 0.5
        'l1_recon': select_object(t1_candidates, t1_valid, 5),
        'l2_recon': select_object(t2_candidates, t2_valid, 5),
    }


def extract_batch_assignments(batch, classify_fn, process="TT2L"):
    """
    Flatten one evaluation batch into a dict of per-event numpy arrays.

    Parameters:
        batch: mapping with 'full_input_point_cloud' (N, P, F) plus
            'assignment_target', 'assignment_target_mask' and
            'assignment_prediction', each keyed by process name. Predictions
            store their indices under [process]['best_indices']. Values may be
            CPU torch tensors or numpy arrays.
        classify_fn: callable(point_cloud, target_list) -> dict of per-event
            reconstructed objects (e.g. classify_TT2L).
        process: process key to read from the assignment dicts.

    Returns:
        Insertion-ordered dict containing:
          'num_lepton' / 'num_bjet': point-cloud column 5 / column 4 summed
              over particles, cast to int32;
          every entry returned by classify_fn;
          f'{process}_{i}': True where every index of group i equals the target;
          f'{process}_{i}_mask': the target mask of group i.
    """
    point_cloud = batch['full_input_point_cloud']
    target_list = batch['assignment_target'][process]
    pred_process = batch['assignment_prediction'][process]['best_indices']
    mask_process = batch['assignment_target_mask'][process]

    # Reduce over the particle axis once and read both flag columns from it.
    # NOTE(review): int32 truncation assumes isLepton/btag are 0/1 flags.
    feature_totals = np.asarray(point_cloud.sum(axis=1))
    process_match = {
        'num_lepton': feature_totals[:, 5].astype(np.int32),
        'num_bjet': feature_totals[:, 4].astype(np.int32),
    }
    process_match.update(classify_fn(point_cloud, target_list))

    for p_idx, (target, prediction, target_mask) in enumerate(
            zip(target_list, pred_process, mask_process)):
        target = np.asarray(target)
        prediction = np.asarray(prediction)

        # A group counts as matched only if every particle index in it is correct.
        process_match[f"{process}_{p_idx}"] = (target == prediction).all(axis=1)
        process_match[f"{process}_{p_idx}_mask"] = np.asarray(target_mask)

    return process_match


# Process every batch independently. Each call returns a dict of per-event
# arrays: object counts, reconstructed objects, and per-group assignment
# matches/masks.
dfs = []  # NOTE: despite the name, this is a list of dicts, not DataFrames
for batch in raw_data:
    out = extract_batch_assignments(batch, classify_fn=classify_TT2L)
    dfs.append(out)

# Concatenate each field across batches and zip all fields into one awkward
# record array (field names and order come from the first batch's dict).
recon_data = ak.zip({
    k: ak.concatenate([out[k] for out in dfs])
    for k in dfs[0].keys()
})

# Map the assignment labels onto truth particles: index 1 (b1, l1, t1) pairs
# with the anti-top side (b~, l-, t~) and index 2 (b2, l2, t2) with the top
# side (b, l+, t).
truth_particle = {
    'b1': truth_data['b~'],
    'b2': truth_data['b'],
    'l1': truth_data['l-'],
    'l2': truth_data['l+'],
    't1': truth_data['t~'],
    't2': truth_data['t'],
}

# Attach each truth particle as a '<name>_truth' field. This relies on the
# truth arrays and recon_data sharing event order -- both are built by
# concatenating raw_data batches in sequence.
for p, v in truth_particle.items():
    recon_data = ak.with_field(recon_data, v, f'{p}_truth')

# Truth-level result on these events: (t, t~) with (l+, l-), in the same
# argument order as truth_core.
# NOTE(review): this rebinds `truth_result` from the earlier truth cell.
truth_result = Core(
    main_particle_1=recon_data.t2_truth,
    main_particle_2=recon_data.t1_truth,
    child1=recon_data.l2_truth,
    child2=recon_data.l1_truth,
).analyze()

# Same truth tops, but reconstructed leptons, so only the lepton
# reconstruction/assignment differs from truth_result.
recon_result = Core(
    main_particle_1=recon_data.t2_truth,
    main_particle_2=recon_data.t1_truth,
    child1=recon_data.l2_recon,
    child2=recon_data.l1_recon,
).analyze()

# Pair truth and reconstructed results; this is the input to the unfolding.
full = build_results(truth_result, recon_result)

### Unfolding

In [None]:
from analysis.unfold import main

# Unfolding binning: m_tt is split into four windows, and every angular
# observable (6 B terms, 9 C terms) uses n_angle_bins equal-width bins on [-1, 1].
n_angle_bins = 10
bin_nums = n_angle_bins + 1  # number of bin *edges*, not bins

# Same key order as before: B_Ak..B_Br, then C_kk..C_rr.
angular_keys = (
    [f"B_{which}{axis}" for which in "AB" for axis in "knr"]
    + [f"C_{ax1}{ax2}" for ax1 in "knr" for ax2 in "knr"]
)
bin_edges = {
    "m_tt": np.array([0, 400, 500, 800, np.inf]),
    # A fresh linspace per key, so no two observables share one edges buffer.
    **{key: np.linspace(-1, 1, bin_nums) for key in angular_keys},
}

unfolded = main(full, bin_edges=bin_edges)

# 1) Reshape all unfolded arrays for each variable (except m_tt) into
#    (m_tt bin, observable bin).
#    NOTE(review): assumes main() flattens each variable m_tt-major -- confirm.
mtt_nbins = len(bin_edges['m_tt']) - 1

unfolded_temp = {
    key: {
        'edges': edges,
        'counts': unfolded[f'{key}_recon_unfold_content'].to_numpy().reshape(mtt_nbins, len(edges) - 1),
        'errors': unfolded[f'{key}_recon_unfold_error'].to_numpy().reshape(mtt_nbins, len(edges) - 1),
    }
    for key, edges in bin_edges.items()
    if key != 'm_tt'
}

# 2) Split by m_tt bin; keys look like "m_tt < <upper edge>".
unfolded_hists = {
    f"m_tt < {mtt_right}": {
        key: {
            'edges': data['edges'],
            'counts': data['counts'][idx],
            'errors': data['errors'][idx],
        }
        for key, data in unfolded_temp.items()
    }
    for idx, mtt_right in enumerate(bin_edges['m_tt'][1:])
}

# Use the lowest m_tt window (the same `mass < 400` cut as the truth-level cell).
# The key is built from the same edge value as above instead of a hard-coded
# 'm_tt < 400.0' literal whose spelling depends on numpy float formatting.
low_mtt_key = f"m_tt < {bin_edges['m_tt'][1]}"
result, result_up, result_down = calculate_B_C(unfolded_hists[low_mtt_key], kappas=(1.0, -1.0))
D = -(result['C_nn'] + result['C_rr'] + result['C_kk'])
final = evaluate_quantum_results_with_uncertainties(result, result_up, result_down)

## Plotting

In [None]:
from pathlib import Path
import os
from downstreams.plotting.unfolding import plot_uncertainty_with_ratio

bins_mtt = bin_edges["m_tt"]


def _mtt_bin_label(lo, hi):
    """LaTeX label for the m_tt window [lo, hi); open-ended at 0 and at +inf."""
    if lo <= 0:
        return rf"$m_{{t\bar{{t}}}} < {hi:g}$"
    if np.isinf(hi):
        return rf"$m_{{t\bar{{t}}}} \geq {lo:g}$"
    return rf"${lo:g} < m_{{t\bar{{t}}}} < {hi:g}$"


# One label per m_tt window, derived from bins_mtt so the labels cannot drift
# out of sync with the unfolding binning.
mtt_labels = [_mtt_bin_label(lo, hi) for lo, hi in zip(bins_mtt[:-1], bins_mtt[1:])]

common_labels = {
    # 6 B terms: B_{A,B}{n,r,k}
    **{
        f"B_{which}{axis}": {
            "name": rf"$\cos\theta^{{{which}}}_{{{axis}}}$",
            "labels": [f"bin {i}" for i in range(len(bin_edges[f"B_{which}{axis}"]) - 1)],
        }
        for which in ['A', 'B']
        for axis in ['n', 'r', 'k']
    },
    # 9 C terms: C_{axis1}{axis2} for axis1, axis2 in {n,r,k}
    **{
        f"C_{ax1}{ax2}": {
            "name": rf"$\cos\theta^A_{{{ax1}}} \cos\theta^B_{{{ax2}}}$",
            "labels": [f"bin {i}" for i in range(len(bin_edges[f"C_{ax1}{ax2}"]) - 1)],
        }
        for ax1 in ['n', 'r', 'k']
        for ax2 in ['n', 'r', 'k']
    },
}

plot_dir = Path.cwd() / "plots"

for var, var_cfg in common_labels.items():
    # NOTE(review): the baseline plots the unfolded *content* while the second
    # method plots the unfolded *error*; for an uncertainty comparison both
    # should presumably be errors -- confirm. "EveNet - Scratch" looks like a
    # placeholder until a scratch-model result is available.
    methods = [
        {
            "name": r"EveNet - Pretrain", "color": "green",
            "data": unfolded[f"{var}_recon_unfold_content"]
        },
        {
            # Distinct colour: with both series "green" they were indistinguishable.
            "name": r"EveNet - Scratch", "color": "blue",
            "data": unfolded[f"{var}_recon_unfold_error"]
        },
    ]

    plot_uncertainty_with_ratio(
        mtt_labels, var_cfg["labels"], var_cfg['name'], methods,
        ratio_baseline_name=r"EveNet - Pretrain",
        p_dir=plot_dir,
        save_name=f"unfolded_{var}.pdf",
        ratio_baseline_max=0.25,
        ratio_baseline_min=-0.05,
        ratio_y_label=r"Improvement to EveNet - Pretrain",
    )

