# Performance metrics and plots for pre-trained models

In [None]:
#Adding all necessary imports

import numpy             as np
import matplotlib.pyplot as plt
import awkward           as ak
import matplotlib        as mpl
import vector
import sys
import os

# Move into the project working directory and add it to the import path, so
# that the project-local `utils` / `modules` packages imported below resolve
# and the relative 'out/...' paths used later are interpreted from there.
os.chdir('/workspace/workDir/')
sys.path.append(os.getcwd())

from utils.misc import sparse_to_awkward, awkward_to_vector, sort_and_pad
from utils.jets import truth_match

## Loading all samples and models into memory

In [None]:
from utils.image import pad

sample_names = ['vbfhh', 'jz']
conv_size    = 7

# Width handed to `pad` for both tower images: half of `conv_size`, rounded down.
pad_width = conv_size // 2

# data[<sample name>][<collection name>] -> array read from the test parquet file
data = {}
for i_sample in sample_names:
    print(f'[INFO]: Loading {i_sample} sample')
    in_file = ak.from_parquet(f'/workspace/samples/{i_sample}/{i_sample}_test.parquet')

    # Tower images are converted to NumPy and padded; jet collections are
    # stored exactly as they come out of the parquet file.
    data[i_sample] = {
        "Towers_PU"  : pad(in_file["Towers"].to_numpy(), pad_width),
        "Towers_NoPu": pad(in_file["Towers_NoPU"].to_numpy(), pad_width),
        "GenJets"    : in_file["GenJet"],
        "Jet_NoPU"   : in_file["Jet_NoPU"],
        "Jets"       : in_file["Jet"],
    }

from tensorflow     import keras
from json           import load
from modules.layers import SlidingConeSum, LocalMaxMask

# Project-defined Keras layers, passed to `load_model` by name so the saved
# model can be rebuilt.
custom_objects = {
    'SlidingConeSum': SlidingConeSum,
    'LocalMaxMask'  : LocalMaxMask,
}

model = keras.models.load_model('out/trained_model_sggF_SM_HH4b_train_w[7, 3]_b512_l2_nc[16, 32]_nd[32, 16].keras', custom_objects=custom_objects)

# Read the training history through a context manager so the file handle is
# closed as soon as the JSON has been parsed (the bare `load(open(...))` form
# leaves it open until garbage collection).
# NOTE(review): the model file is prefixed 'trained_model_' while the history
# file is prefixed 'train_model_' — confirm both match what training writes.
with open('out/train_model_sggF_SM_HH4b_train_w[7, 3]_b512_l2_nc[16, 32]_nd[32, 16]_history.json', 'r') as history_file:
    train_data = load(history_file)

In [None]:
# Plot the values stored under the 'loss' key of the training history

fig_loss, ax_loss = plt.subplots(figsize=(10,5))
ax_loss.plot(train_data['loss'], label='Training Loss')
ax_loss.set(
    xlabel='Epoch',
    ylabel='Loss',
    title='Training loss function evolution',
)
ax_loss.grid()
ax_loss.legend()
plt.show()

In [None]:
# Running inference on the vbfhh test sample
#
# `model_type` selects how the network output is turned into jets:
#   'PU'        -> predictions are tower images that are clustered into jets
#   otherwise   -> predictions are converted directly into jet candidates
# Both branches leave `predictions` and `conv_jets` defined for later cells.

from matplotlib.lines import Line2D
from utils.image import tower_to_vector  # needed by BOTH branches, so import it up front

model_type = 'JetFinder'

if model_type == 'PU':
    # The loading cell stores the pile-up tower images under 'Towers_PU'
    # (there is no 'Towers' key in `data`).
    predictions = model.predict(data['vbfhh']['Towers_PU'])
    # NOTE(review): `unpad` and `antikt_jets` are not imported anywhere in the
    # visible notebook — import them before switching to this branch.
    predictions = unpad(predictions, 1)

    tower_vectors = sparse_to_awkward(tower_to_vector(predictions))
    jets          = antikt_jets(tower_vectors, 10)
    conv_jets     = sparse_to_awkward(jets)

else:
    predictions = model.predict(data['vbfhh']['Towers_NoPu'])
    conv_jets   = sparse_to_awkward(tower_to_vector(predictions))

# Proxy legend handles for the jet circles drawn in the event displays.
# Defined unconditionally so the name exists whichever branch ran.
legend_elem = [
    Line2D([0],[0], marker='o', color='w', label='Anti-kT jets', markeredgecolor='red' , markersize=10),
    Line2D([0],[0], marker='o', color='w', label='Pred. jets'  , markeredgecolor='blue', markersize=10),
]

In [None]:
# Event displays for a couple of events
import matplotlib.patches as mpatches

n_events = 2

# Axis range handed to imshow: [eta_min, eta_max, phi_min, phi_max]
extent = [-2.5, 2.5, -np.pi, np.pi]


def _show_towers(ax, towers, title, extent=extent):
    """Draw one 2D tower image on *ax* with a log colour scale and eta/phi labels.

    The image is transposed before plotting, exactly as in the original
    per-panel code, and a fresh ``LogNorm`` is created for every panel.
    """
    ax.imshow(towers.T, origin='lower', norm=mpl.colors.LogNorm(), extent=extent, aspect='auto')
    ax.set_title(title)
    ax.set_xlabel('eta')
    ax.set_ylabel('phi')


def _draw_jet_circles(ax, jets, color):
    """Overlay an unfilled circle of radius 0.4 at (eta, phi) of every jet in *jets*."""
    for jet in jets:
        ax.add_patch(mpatches.Circle((jet.eta, jet.phi), 0.4, color=color, fill=False))


for i in range(n_events):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # 1. Towers with PU
    _show_towers(axs[0], data['vbfhh']['Towers_PU'][i, ..., 0], 'Towers with PU')

    # 2. Towers without PU
    _show_towers(axs[1], data['vbfhh']['Towers_NoPu'][i, ..., 0], 'Towers without PU')

    # 3. Model Predictions
    if model_type == 'PU':
        _show_towers(axs[2], predictions[i, ..., 0], 'Model Predictions')

    elif model_type == 'JetFinder':
        # NOTE(review): the title mentions GenJets, but the circles drawn are
        # the 'Jet_NoPU' collection and the predicted jets — confirm wording.
        _show_towers(axs[2], data['vbfhh']['Towers_NoPu'][i, ..., 0], 'Jet Seeds & GenJets')

        _draw_jet_circles(axs[2], data['vbfhh']['Jet_NoPU'][i], 'red')
        _draw_jet_circles(axs[2], conv_jets[i], 'blue')

        # The legend is built from explicit proxy handles, so no per-artist
        # labels are needed (the former per-iteration `axs[2].set_label(...)`
        # calls only overwrote the Axes' own label and had no visible effect).
        axs[2].legend(handles=legend_elem, loc='upper right')

        # Optional truth-jet overlay, disabled by default:
        # _draw_jet_circles(axs[2], data['vbfhh']['GenJets'][i], 'yellow')

    plt.show()

### Look into the physics metrics (e.g. trigger efficiency)

In [None]:
from utils.image import tower_to_vector


# Predicted and truth jets, both passed through `sort_and_pad(..., 10)`
# (presumably pT-ordered and padded to 10 jets per event — TODO confirm).
calib_jets = sort_and_pad(conv_jets,10)
truth_jets = sort_and_pad(awkward_to_vector(data['vbfhh']['GenJets']), 10)

# Ten log-spaced edges between 100 and 800, i.e. nine truth-pT bins
bin_edges   = np.logspace(np.log10(100), np.log10(800), 10)
bin_centres = 0.5*(bin_edges[:-1] + bin_edges[1:])
bin_widths  = np.diff(bin_edges)

# Keep only events where entry 3 along the last axis has rho > 0 in BOTH collections
mask = (calib_jets[...,3].rho > 0) & (truth_jets[...,3].rho > 0)

matched_calib_pt = calib_jets[mask].pt
matched_truth_pt = truth_jets[mask].pt

# Response (predicted pT / truth pT), split by truth-pT bin.
# A bin with no entries is replaced by a single 0 so every violin still
# receives a non-empty dataset.
pt_response = matched_calib_pt/matched_truth_pt
my_data = []
for min_pt, max_pt in zip(bin_edges[:-1], bin_edges[1:]):
    in_bin   = (matched_truth_pt > min_pt) & (matched_truth_pt < max_pt)
    response = pt_response[in_bin]
    my_data.append(response if len(response) != 0 else np.array([0]))

violin_pt = plt.violinplot(
    my_data,
    positions   = bin_centres,
    widths      = 0.8*bin_widths,
    showmeans   = True,
    showextrema = False,
    quantiles   = [[0.32, 0.68] for _ in my_data],
    points      = 1000,
)

# Reference line at a response of exactly 1
plt.hlines(1, bin_edges.min(), bin_edges.max(), colors='k', linestyles='dashed')
plt.ylim(0, 2)
plt.xlabel('Truth jet $p_T$ [GeV]')
plt.ylabel('Calibrated jet $p_T$ / Truth jet $p_T$')
plt.title('Calibration performance')
plt.grid()
plt.show()

In [None]:
# 2D histogram of truth vs predicted jet values
# NOTE(review): the arrays plotted are the pT arrays from the previous cell,
# while the labels below speak of energy — confirm the intended quantity.

truth_energy_flat = ak.flatten(matched_truth_pt).to_numpy()
pred_energy_flat  = ak.flatten(matched_calib_pt).to_numpy()

fig, ax = plt.subplots(figsize=(8,6))

# Same [100, 800] window on both axes, 50 bins each, log colour scale
axis_window = [100,800]
h = ax.hist2d(
    truth_energy_flat,
    pred_energy_flat,
    bins=50,
    range=[axis_window, axis_window],
    cmap='viridis',
    norm=mpl.colors.LogNorm(),
)

# Element 3 of hist2d's return value is the drawn image, used for the colour bar
counts_image = h[3]
fig.colorbar(counts_image, ax=ax, label='Counts')

# Diagonal spanning the current x-range: where truth equals prediction
x_min, x_max = ax.get_xlim()
diagonal = [x_min, x_max]
ax.plot(diagonal, diagonal, 'r--', linewidth=2, label='Tr(Et)=Pred(Et)')

ax.set(
    xlabel='Truth Jet Energy [GeV]',
    ylabel='Pred. Jet Energy [GeV]',
    title='Truth Matched Energy Composition',
)
ax.legend()
ax.grid(True, alpha=0.3)

plt.show()

In [None]:
# Get background samples to find cut for turn-on curve

from tqdm import tqdm

# Rate assigned to the full (uncut) sample; `get_rate` scales the passing
# fraction by this value.
# NOTE(review): presumably the 40 MHz collision rate expressed in kHz — confirm units.
base_rate = 40e3


def get_rate(x, weights, threshold):
    """Return the rate of entries passing ``x > threshold``.

    The rate is the weighted fraction of entries strictly above *threshold*,
    multiplied by the module-level ``base_rate``.

    Args:
        x: Array-like of per-entry values the cut is applied to.
        weights: Array-like of per-entry weights, same length as *x*.
        threshold: Cut value; entries with ``x > threshold`` pass.

    Returns:
        ``sum(weights[x > threshold]) / sum(weights) * base_rate``.
    """
    # Convert both inputs once so that plain Python sequences work for the
    # numerator AND the denominator (previously only the numerator was wrapped,
    # so list weights failed on `weights.sum()`).
    x = np.asarray(x)
    weights = np.asarray(weights)
    return weights[x > threshold].sum() / weights.sum() * base_rate

def get_threshold(x, weights, target_rate, tol=1e-12, max_iter=100):
    """Bisect for the cut on *x* whose ``get_rate`` is closest to *target_rate*.

    Args:
        x: Per-entry values the cut is applied to.
        weights: Per-entry weights forwarded to ``get_rate``.
        target_rate: Rate the cut should reach.
        tol: Absolute tolerance on ``|rate - target_rate|`` for an early exit.
        max_iter: Maximum number of bisection steps.

    Returns:
        ``0.0`` if the rate at a threshold of 0 is already below
        *target_rate*; otherwise the midpoint at which the tolerance was met,
        or the midpoint of the final bracket after *max_iter* steps.
    """
    # No cut needed: even a threshold of 0 is under the target
    if get_rate(x, weights, threshold=0) < target_rate:
        return 0.0

    # Bracket the answer between the smallest and largest value of x
    lower, upper = np.min(x), np.max(x)
    for _ in tqdm(range(max_iter)):
        midpoint = 0.5 * (lower + upper)
        current_rate = get_rate(x, weights, threshold=midpoint)

        if abs(current_rate - target_rate) < tol:
            return midpoint

        # Rate too high -> cut is too loose, move the lower bound up;
        # otherwise tighten the upper bound.
        if current_rate > target_rate:
            lower = midpoint
        else:
            upper = midpoint

    return 0.5 * (lower + upper)

# Thresholds derived from the 'jz' sample's GenJets with unit weights.
# Each entry: (key in `thresholds`, jet index within the event, target rate).
threshold_specs = (
    ('ptj1', 0, 25),
    ('ptj4', 3, 100),
)

jz_gen_jets = data['jz']['GenJets']

thresholds = {}
for threshold_name, jet_index, wanted_rate in threshold_specs:
    jet_rho = jz_gen_jets[:,jet_index].rho
    unit_weights = np.ones(len(jet_rho))
    thresholds[threshold_name] = get_threshold(jet_rho, unit_weights, target_rate=wanted_rate)

print(thresholds)

In [None]:
from utils.image import plot_turn_on

fig, ax = plt.subplots(figsize=(5, 5))

# Match predicted jets to truth jets within dR < 0.4, then bring the matched
# truth jets back to the same sort_and_pad(…, 10) layout as `truth_jets`.
matched_reco, matched_truth = truth_match(calib_jets, truth_jets, max_dR=0.4)
matched_truth = sort_and_pad(matched_truth, 10)

# x-axis quantity: rho of entry 0 of the truth jets
leading_truth_rho = truth_jets[...,0].rho

# Pass/fail flag per event: entry 0 of the matched truth jets above the 'ptj1' cut.
# NOTE(review): `matched_reco` is not used — confirm whether the cut was meant
# to be applied to the matched predicted jet instead of the matched truth jet.
passes_ptj1 = matched_truth[...,0].rho>thresholds['ptj1']

plot_turn_on(
    leading_truth_rho,
    passes_ptj1,
    weights = None,
    bins    = np.linspace(100, 300, 20),
    ax      = ax,
    label   = 'CNN Calibrated images',
)