In [None]:
import glob
import pandas
import matplotlib.pyplot as plt
import seaborn
import json
import numpy as np
import os

import mplhep as hep
# Activate the "CMS" style shipped with mplhep before any figure is created.
hep.style.use("CMS")

In [None]:
# Training histories, indexed as losses[model][train_size].
# Only the model names listed here are accepted: a run directory with any
# other name raises KeyError in the loop below.
losses = {
    "ParticleTransformer": {},
    "OmniParT": {},
    "OmniFeedforward": {},
    "LorentzNet": {},
    "SimpleDNN": {},
}

# The indexing below reads each path as
#   .../<prefix>_<train_size>[_...]/<subdir>/<model>/history.json
# <subdir> is not part of the key, so if two paths share the same model and
# train size, the one that sorts last wins.
for path in sorted(glob.glob("../training-outputs/240812_3vars/*/*/*/history.json")):
    # Normalise separators before splitting: on Windows glob joins the
    # wildcard-matched components with "\\", so splitting on "/" alone
    # would pick the wrong components.
    spl = os.path.normpath(path).split(os.sep)
    model = spl[-2]
    train_size = float(spl[-4].split("_")[1])
    print(model, train_size)

    # Context manager so the file handle is closed as soon as it is parsed
    # (a bare open() inside json.load() leaves it to the garbage collector).
    with open(path) as history_file:
        losses[model][train_size] = json.load(history_file)

In [None]:
# Per model: the train fractions present in `losses`, and the minimum of each
# run's "losses_validation" list.
# The fractions are sorted numerically, not left in dict insertion order (the
# lexicographic order of the file paths), so a line plot of these two lists
# runs left to right instead of zig-zagging. Both lists are built from the
# same sorted keys, so they stay aligned element by element.
train_fracs = {model: sorted(losses[model]) for model in losses}
best_val_losses = {
    model: [np.min(losses[model][frac]["losses_validation"]) for frac in train_fracs[model]]
    for model in losses
}

In [None]:
# Summary figure: best validation loss of each model versus the training-set
# fraction, on log-log axes.
# Each entry is (key in train_fracs / best_val_losses, legend label);
# uncomment an entry to add that curve to the figure.
curves = [
    # ("SimpleDNN", "DeepSet"),
    # ("LorentzNet", "LorentzNet"),
    ("ParticleTransformer", "ParT"),
    ("OmniParT", "OmniParT"),
    # ("OmniFeedforward", "OmniDeepSet"),
]
for curve_key, curve_label in curves:
    plt.plot(train_fracs[curve_key], best_val_losses[curve_key], marker="o", label=curve_label)

plt.xscale("log")
plt.yscale("log")
plt.legend(loc="best")
plt.xlabel("train dataset fraction")
plt.ylabel("jet reg validation loss")
plt.title("240812_3vars")

In [None]:
# Keys of the ParticleTransformer histories, in ascending order; the next cell
# draws one panel per entry.
ks = sorted(losses["ParticleTransformer"])

In [None]:
# Validation-loss curves per epoch, ParT vs OmniParT, one panel per entry of ks.
n_cols = 4
# Ceil-divide so there is a panel for every entry of ks. A fixed 4x4 grid
# would silently drop everything past the 16th entry, because zip() stops at
# the shorter input. With exactly 16 entries this is still 4x4 at 16x16 inches.
n_rows = max(1, (len(ks) + n_cols - 1) // n_cols)
# squeeze=False keeps axs two-dimensional whatever the grid size.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
panels = axs.flatten()

for ax, k in zip(panels, ks):
    ax.plot(losses["ParticleTransformer"][k]["losses_validation"], label="ParT")
    ax.plot(losses["OmniParT"][k]["losses_validation"], label="OmniParT")
    ax.legend(loc="best", fontsize=10)
    ax.set_title("frac={}".format(k), fontsize=10)
    ax.set_yscale("log")
    ax.set_ylabel("validation loss", fontsize=10)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.tick_params(axis='both', which='minor', labelsize=8)
    # Same y-range on every panel so the curves can be compared by eye.
    ax.set_ylim(10**-4, 10**1)
    ax.set_xlabel("epoch", fontsize=10)

# Hide the grid cells that have no entry of ks instead of leaving empty frames.
for ax in panels[len(ks):]:
    ax.set_visible(False)
#plt.tight_layout()