In [None]:
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from matplotlib.lines import Line2D
from scipy import stats

# Presumably provides JetDataModule, ManifoldEmbedder and PrintCallbacks used below.
from manifoldembedder import *

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6,6)
plt.rcParams.update({'font.size': 14})

In [None]:
# Directory holding the jet-pair tensor files. Set JETPAIRS_DIR to point the
# notebook at a different copy; the original cluster path is the default.
basedir = os.environ.get('JETPAIRS_DIR', '/nobackup/users/sangeon/datasets/jetGridpack/jetpairs')

train_file_list = []
val_file_list = []
test_file_list = []
predict_file_list = []

# Each file goes into the FIRST split whose keyword appears in its name, in
# this priority order (e.g. a name containing both "val" and "test" is val).
split_buckets = (
    ("train", train_file_list),
    ("val", val_file_list),
    ("predict", predict_file_list),
    ("test", test_file_list),
)

# os.listdir order is arbitrary; sort so the lists are reproducible.
# NOTE(review): none of the *_file_list variables is used by a later visible
# cell -- file_dict names the merged files explicitly.
for fname in sorted(os.listdir(basedir)):
    for keyword, bucket in split_buckets:
        if keyword in fname:
            bucket.append(os.path.join(basedir, fname))
            break

In [None]:
# Merged tensor file for each data split; this dict is handed to JetDataModule.
# NOTE(review): 'test' points at the same *jetpredict* file as 'predict'
# rather than a jettest file, although the directory scan looks for "test"
# files -- confirm this is intentional.
file_dict = {
    split: os.path.join(basedir, fname)
    for split, fname in (
        ('train', 'Gridpack_jettrain_16part_merged.pt'),
        ('val', 'Gridpack_jetval_16part_merged.pt'),
        ('test', 'Gridpack_jetpredict_16part_merged.pt'),
        ('predict', 'Gridpack_jetpredict_16part_merged.pt'),
    )
}

# Paper configuration: model setup, training, and embedding plots

In [None]:
# Seed Python's `random`, NumPy and torch RNGs so RNG-dependent steps
# (e.g. weight initialisation) repeat across runs of this notebook.
# NOTE(review): 42 is an arbitrary choice; any fixed value works.
SEED = 42
pl.seed_everything(SEED)

# NOTE(review): argument meanings are defined in `manifoldembedder` (not
# visible here). Presumably 2000 is the batch size, 2 the embedding dimension
# and 0.0005 the learning rate -- confirm against the class signatures.
jet_dm = JetDataModule(file_dict, 2000)
model = ManifoldEmbedder("jets", 2, "Transformer", 0.0005, [32, 4, 3, 2, 2, 16, 0.2, 0.20, [1000, 400, 20]])

# For Hyperbolic Embedding (alternative configuration):
#model = HyperbolicEmbedder("jets",2,"Transformer", 0.00054607179632484, 1e-8, 1e-4, [32,4,3,2,2,16,0.25,0.25,[1000,500,20]])


In [None]:
# Stop training once val_loss stops improving for 5 validation checks
# (min_delta of 0 means any decrease counts as an improvement).
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    min_delta=0.00,
    verbose=False,
)

# Keep the 4 checkpoints with the lowest val_loss. Lightning fills in the
# {epoch}/{val_loss} placeholders, giving names like
# "...-epoch=03-val_loss=0.27.ckpt".
checkpoint_callback = ModelCheckpoint(
    dirpath="/home/sangeon/ToyJetGenerator/training/simulatedtoyjets_checkpoints",
    filename="Transformer-paper-finaltuning-trywithp-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=4,
)

In [None]:
trainer = Trainer(
    gpus=1,
    callbacks=[
        PrintCallbacks(),
        early_stop_callback,
        checkpoint_callback,
    ],
    # NOTE(review): auto_lr_find only takes effect when trainer.tune(...) is
    # called; trainer.fit() alone ignores it, so the model's configured
    # learning rate is used unchanged. Both `gpus` and `auto_lr_find` were
    # removed in Lightning 2.0.
    auto_lr_find=True,
)

In [None]:
# Train using the datamodule's train/val splits.
trainer.fit(model, datamodule=jet_dm)

In [None]:
# Optional: uncomment to replace `model` with weights restored from a saved checkpoint.
# NOTE(review): this path is relative and names an older run ("try6"), whereas
# checkpoint_callback writes to an absolute dirpath with a "trywithp" prefix --
# update the path before using it.
#model = ManifoldEmbedder.load_from_checkpoint("./simulatedtoyjets_checkpoints/Transformer-paper-finaltuning-try6-epoch=00-val_loss=0.28.ckpt")

In [None]:
# Run inference over the datamodule's predict split.
# NOTE(review): because `model` is passed explicitly, predict uses the
# model's weights as they are at the end of fit(), not the best checkpoint
# kept by checkpoint_callback. Load that checkpoint first if the best
# weights are wanted.
model.eval()
a = trainer.predict(model, datamodule=jet_dm)

In [None]:
# Flatten the per-batch predictions into one embedding matrix and one label
# vector. Each element of `a` is indexed as (embedding_tensor, label_tensor).
# Gather all chunks first and concatenate once: growing an array with
# np.vstack / np.concatenate inside the loop copies everything per batch.
embedding_chunks = [batch[0].cpu().numpy() for batch in a]
label_chunks = [batch[1].cpu().numpy() for batch in a]

# float64 matches the dtype of the empty-case arrays below, so downstream
# code (KDE fitting, `label == i` masks) sees the same dtype either way.
if embedding_chunks:
    embedding = np.concatenate(embedding_chunks, axis=0).astype(np.float64)
    label = np.concatenate(label_chunks).astype(np.float64)
else:
    embedding = np.empty((0, 2))
    label = np.array([])


In [None]:
# This cell duplicated the plotting-setup cell at the top of the notebook
# (same %matplotlib magic, pyplot import and rcParams defaults), so the
# repeat was removed to keep a single source of plot configuration.
# The now-empty cell can be deleted.

In [None]:
# Display names for the jet classes, indexed by integer label
# (label 0 -> 'QCD', label 1 -> '2p25', ...).
namelist = [
    'QCD',
    '2p25', '2p170',
    '3p25', '3p170',
    '4p170', '4p400',
]

In [None]:
# Overview scatter of the 2-D embedding: up to 10,000 points per class.
fig, ax = plt.subplots()
for i, name in enumerate(namelist):
    class_points = embedding[label == i][:10000]  # mask once per class
    ax.scatter(class_points[:, 0], class_points[:, 1], s=10, alpha=0.10, label=name)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_xlim([0, 0.8])
ax.set_ylim([-0.9, 0.25])

# Points are drawn nearly transparent; make the legend markers opaque so the
# class colours are readable. `legendHandles` was renamed `legend_handles` in
# Matplotlib 3.7 and the old name was removed in 3.9.
leg = ax.legend()
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for handle in legend_handles:
    handle.set_alpha(1)

ax.set_title('Simulated Jet Embedding');


In [None]:
# Zoomed scatter of the dense region: up to 1,000 points per class.
fig, ax = plt.subplots()
for i, name in enumerate(namelist):
    class_points = embedding[label == i][:1000]  # mask once per class
    ax.scatter(class_points[:, 0], class_points[:, 1], s=10, alpha=0.20, label=name)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_xlim([-0.4, 0.0])
ax.set_ylim([0.12, 0.22])

# Make the legend markers opaque despite the translucent points.
# `legendHandles` was renamed `legend_handles` in Matplotlib 3.7 and the old
# name was removed in 3.9.
leg = ax.legend()
legend_handles = getattr(leg, 'legend_handles', None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for handle in legend_handles:
    handle.set_alpha(1)

ax.set_title('Simulated Jet Embedding');

In [None]:
def plot_kde(ax, whichlabel, color, points=None, labels=None,
             xlim=(-0.5, 0.3), ylim=(0, 0.35), grid_size=100,
             max_samples=10000, levels=(0.3, 0.6)):
    """Draw peak-normalised Gaussian-KDE contours for one class.

    A ``scipy.stats.gaussian_kde`` is fitted to at most ``max_samples``
    points whose label equals ``whichlabel``. It is evaluated on a
    ``grid_size`` x ``grid_size`` grid spanning ``xlim`` x ``ylim``, divided
    by its maximum on that grid, and drawn with ``ax.contour`` at the
    relative ``levels``.

    Args:
        ax: Matplotlib axes to draw on.
        whichlabel: Class label whose points are used.
        color: Contour colour passed to ``ax.contour``.
        points: (N, 2) array of embedded points; defaults to the
            notebook-level ``embedding``.
        labels: Length-N label array; defaults to the notebook-level ``label``.
        xlim, ylim: (min, max) bounds of the evaluation grid.
        grid_size: Number of grid points along each axis (endpoints included).
        max_samples: Maximum number of the class's points used for the fit.
        levels: Contour levels as fractions of the grid maximum.

    Returns:
        Whatever ``ax.contour`` returns (a ContourSet for Matplotlib axes).
    """
    if points is None:
        points = embedding
    if labels is None:
        labels = label

    xmin, xmax = xlim
    ymin, ymax = ylim
    steps = complex(0, grid_size)  # complex step => point count, not stride
    X, Y = np.mgrid[xmin:xmax:steps, ymin:ymax:steps]
    positions = np.vstack([X.ravel(), Y.ravel()])

    # gaussian_kde expects data shaped (n_dims, n_samples), hence the transpose.
    samples = points[labels == whichlabel][:max_samples]
    kernel = stats.gaussian_kde(samples.T)
    Z = kernel(positions).reshape(X.shape)

    # Normalise so `levels` are fractions of the peak. If the density
    # underflows to zero everywhere on the grid, dividing would give NaNs.
    peak = Z.max()
    if peak > 0:
        Z /= peak
    return ax.contour(X, Y, Z, levels=list(levels), colors=color, alpha=0.8)


In [None]:
# Per-class density contours (drawn at 30% and 60% of each class's peak KDE
# value on the evaluation grid) overlaid in one figure.
fig, ax = plt.subplots()

lines = []
for i, name in enumerate(namelist):
    color = f'C{i}'
    plot_kde(ax, i, color)
    # Proxy legend handle: ContourSet.collections was deprecated in
    # Matplotlib 3.8 and removed in 3.10, so don't pull handles from it.
    lines.append(Line2D([], [], color=color, alpha=0.8))

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_xlim([-0.4, 0.0])
ax.set_ylim([0.12, 0.21])
ax.legend(lines, namelist)
ax.set_title('Simulated Jet Embedding');
