In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Legend labels for the decay-mode classes. The t-SNE plots index this list with the
# integer values found in d["dm_multiclass"] (dm_labels[dm]), so position i is class i.
# The 3-prong + neutrals entry states its pi0 multiplicity explicitly (>= 1),
# mirroring the 1-prong ">= 2 pi0" entry.
dm_labels = [
    r"$h^{\pm}$",
    r"$h^{\pm}\pi^0$",
    r"$h^{\pm}+\geq2\pi^0$",
    r"$h^{\pm}h^{\mp}h^{\pm}$",
    r"$h^{\pm}h^{\mp}h^{\pm}+\geq1\pi^0$",
    "Rare",
]

```
#!/bin/bash
# Re-train the OmniParT decay-mode (dm_multiclass) head for 2 epochs on 1e5 jets,
# once per backbone initialisation, and collect the batch_*.pkl files that
# trainModel.py leaves in the working directory into embeds/<variant>/.
# Abort on the first failure so stale batch_*.pkl files are never moved into the
# wrong variant's directory.
set -euo pipefail

rm -Rf training-outputs/*

SIMG="$HOME/HEP-KBFI/singularity/pytorch.simg.2"
PYPATH="$(pwd):$(pwd)/enreg/omnijet_alpha/gabbro"

# train_variant NAME [EXTRA_OVERRIDES...]
# Runs trainModel.py with the settings shared by every variant plus the given
# overrides, writes training outputs to training-outputs/NAME, and moves the
# batch_*.pkl dumps into embeds/NAME/ (created if missing).
train_variant() {
    local name=$1
    shift
    singularity exec --env PYTHONPATH="$PYPATH" "$SIMG" python3 enreg/scripts/trainModel.py \
        output_dir="training-outputs/$name" trainSize=1e5 fraction_valid=0.1 training.num_epochs=2 \
        model_type=OmniParT training_type=dm_multiclass "$@"
    mkdir -p "embeds/$name"
    mv batch_*.pkl "embeds/$name/"
}

# From scratch
train_variant from_scratch models.OmniParT.version=from_scratch

# Fixed backbone: pre-trained generative OmniJet-alpha checkpoint
train_variant fixed_backbone models.OmniParT.version=fixed_backbone \
    models.OmniParT.bb_path=weights/OmniJet_generative_model_FiduciaryCagoule_254.ckpt

# Fixed backbones taken from OmniParT models previously fine-tuned on each task
for task in binary_classification dm_multiclass jet_regression; do
    train_variant "finetuned_$task" models.OmniParT.version=fixed_backbone \
        "models.OmniParT.bb_path=weights/trainfrac_1e6/$task/OmniParT_fine_tuning/model_best.pt"
done
```

In [None]:
def _load_batch(variant, batch=0, embeds_dir="../embeds"):
    """Return the unpickled contents of ``<embeds_dir>/<variant>/batch_<batch>.pkl``.

    The file is opened in a ``with`` block so the handle is closed after loading.
    NOTE: ``pickle.load`` can execute arbitrary code; only load dumps you produced.
    """
    with open(f"{embeds_dir}/{variant}/batch_{batch}.pkl", "rb") as fh:
        return pickle.load(fh)

# First embedding dump for each backbone variant (indexed below by keys such as
# "embeds", "reco_jet_pt", "binary_classification", "dm_multiclass").
d1 = _load_batch("from_scratch")
d2 = _load_batch("fixed_backbone")
d3 = _load_batch("finetuned_dm_multiclass")
d4 = _load_batch("finetuned_jet_regression")
d5 = _load_batch("finetuned_binary_classification")

In [None]:
# Pull the "embeds" entry out of each loaded dump, keeping the d1..d5 ordering.
d1e, d2e, d3e, d4e, d5e = [dump["embeds"] for dump in (d1, d2, d3, d4, d5)]

In [None]:
# Flatten every axis after the first (per-jet) axis so each jet becomes one row,
# as needed by the t-SNE fit below. The trailing size is computed from shape[1:]
# rather than shape[1]*shape[2], so re-running this cell on the already-flattened
# 2D arrays is a no-op instead of raising IndexError.
d1e, d2e, d3e, d4e, d5e = [
    e.reshape((e.shape[0], int(np.prod(e.shape[1:]))))
    for e in (d1e, d2e, d3e, d4e, d5e)
]

In [None]:
# Shapes of the flattened embedding matrices, one per model (d1e..d5e).
tuple(arr.shape for arr in (d1e, d2e, d3e, d4e, d5e))

In [None]:
# Distribution of all embedding values per model, overlaid on a log-count axis.
# Values outside the [-5, 5] binning range are not counted.
model_embeds = {
    "from scratch": d1e,
    "fixed backbone": d2e,
    "fine-tuned, decay mode": d3e,
    "fine-tuned, jet regression": d4e,
    "fine-tuned, binary classification": d5e,
}
plt.figure(figsize=(4,4))
b = np.linspace(-5, 5, 500)
for name, emb in model_embeds.items():
    plt.hist(emb.flatten(), bins=b, histtype="step", label=name)
plt.yscale('log')
plt.xlabel("embedding value")
plt.ylabel("count")
plt.legend(fontsize="small");

In [None]:
# First two columns of each model's flattened embedding, overlaid; marker area
# scales with that model's own reco jet pT.
plt.figure(figsize=(5, 5))
for name, dump, emb in (
    ("from scratch", d1, d1e),
    ("fixed backbone", d2, d2e),
    ("fine-tuned, decay mode", d3, d3e),
    ("fine-tuned, jet regression", d4, d4e),
    ("fine-tuned, binary classification", d5, d5e),
):
    plt.scatter(emb[:, 0], emb[:, 1], marker=".", s=dump["reco_jet_pt"]/5, alpha=0.5, label=name)
plt.xlabel("embedding column 0")
plt.ylabel("embedding column 1")
plt.legend(fontsize="small");

In [None]:
# Class balance of the binary label in this batch as (values, counts), instead of
# dumping the whole per-jet array.
np.unique(d3["binary_classification"], return_counts=True)

In [None]:
# t-SNE settings shared by every model so all projections are produced the same way.
params = {"random_state": 1, "verbose": 0, "perplexity": 40}

# Half-width of the square axis window used by the t-SNE plots.
lim = 90

# One mask built from d1 is applied to every model's embeddings, and the plots colour
# every projection with d3's labels and pT. That only makes sense if all dumps hold the
# same jets in the same order, so verify the label arrays agree before relying on it.
for key in ("binary_classification", "dm_multiclass"):
    for other in (d2, d3, d4, d5):
        assert np.array_equal(other[key], d1[key]), (
            f"embedding dumps disagree on {key!r}; jets are not aligned across models"
        )
# Keep jets with binary_classification == 1 (presumably the signal class — TODO confirm).
msk = d1["binary_classification"] == 1

# Fit an independent t-SNE per model on the selected jets; `tsne` ends up bound to the
# last fitted estimator.
tsne_outputs = []
for emb in (d1e, d2e, d3e, d4e, d5e):
    tsne = TSNE(**params)
    tsne_outputs.append(tsne.fit_transform(emb[msk]))
tsne_output1, tsne_output2, tsne_output3, tsne_output4, tsne_output5 = tsne_outputs

In [None]:
# t-SNE projection of the from-scratch model, one colour per decay mode, marker area
# proportional to reco jet pT. Labels and pT are read from d3, which assumes every
# dump lists the same jets in the same order.
fig, ax = plt.subplots(figsize=(5, 5))

colors = plt.cm.Paired(np.arange(6))
modes_sel = d3["dm_multiclass"][msk]
pt_sel = d3["reco_jet_pt"][msk]
for mode in np.unique(modes_sel):
    in_mode = modes_sel == mode
    ax.scatter(
        tsne_output1[:, 0][in_mode],
        tsne_output1[:, 1][in_mode],
        marker=".",
        s=pt_sel[in_mode] / 2,
        alpha=0.5,
        color=colors[mode],
        label=dm_labels[mode],
    )
ax.set_title("uninitialized")
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xticks([])
ax.set_yticks([])
ax.legend(ncols=2, loc=1)
fig.savefig("tsne_uninitialized.pdf", bbox_inches="tight")

In [None]:
# t-SNE projection of the fixed (pre-trained) backbone model, one colour per decay
# mode, marker area proportional to reco jet pT. Labels and pT are read from d3, which
# assumes every dump lists the same jets in the same order.
fig, ax = plt.subplots(figsize=(5, 5))

colors = plt.cm.Paired(np.arange(6))
modes_sel = d3["dm_multiclass"][msk]
pt_sel = d3["reco_jet_pt"][msk]
for mode in np.unique(modes_sel):
    in_mode = modes_sel == mode
    ax.scatter(
        tsne_output2[:, 0][in_mode],
        tsne_output2[:, 1][in_mode],
        marker=".",
        s=pt_sel[in_mode] / 2,
        alpha=0.5,
        color=colors[mode],
        label=dm_labels[mode],
    )
ax.set_title("pre-trained")
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xticks([])
ax.set_yticks([])
ax.legend(ncols=2, loc=1)
fig.savefig("tsne_pretrained.pdf", bbox_inches="tight")

In [None]:
# t-SNE projection of the backbone fine-tuned on decay-mode classification, one colour
# per decay mode, marker area proportional to reco jet pT. Labels and pT are read from
# d3, which assumes every dump lists the same jets in the same order.
fig, ax = plt.subplots(figsize=(5, 5))

colors = plt.cm.Paired(np.arange(6))
modes_sel = d3["dm_multiclass"][msk]
pt_sel = d3["reco_jet_pt"][msk]
for mode in np.unique(modes_sel):
    in_mode = modes_sel == mode
    # Both coordinates are negated. A t-SNE embedding is only meaningful up to
    # rotation/reflection, so this is purely cosmetic (presumably to line the
    # clusters up with the other panels — TODO confirm).
    ax.scatter(
        -tsne_output3[:, 0][in_mode],
        -tsne_output3[:, 1][in_mode],
        marker=".",
        s=pt_sel[in_mode] / 2,
        alpha=0.5,
        color=colors[mode],
        label=dm_labels[mode],
    )
ax.set_title("fine-tuned, decay mode")
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
ax.set_xticks([])
ax.set_yticks([])
ax.legend(ncols=2, loc=1)
fig.savefig("tsne_finetuned.pdf", bbox_inches="tight")