In [None]:
# standard library and third-party modules
import os
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
import importlib

import torch
import torch.optim as optim
import torch.nn.functional as F
# local model definition; importlib.reload re-executes the module so that
# the names imported from it below come from the freshly reloaded version
import ecalendcapmodel
importlib.reload(ecalendcapmodel)
from ecalendcapmodel import ResNetAE, ResNetAEPixel, training_loop

# framework modules
# (the parent directory is appended to the module search path so that the
#  `plotting` and `training` packages imported below can be resolved)
sys.path.append('../')
# same import / reload / from-import pattern as above for each framework module
import plotting.plottools
importlib.reload(plotting.plottools)
from plotting.plottools import plot_histogram
import training.prepare_training_set
importlib.reload(training.prepare_training_set)
from training.prepare_training_set import prepare_training_data_from_files

In [None]:
# smoke test: push one dummy (20x20) image, the size originally used by ECAL,
# through the model to check that the forward pass runs

ae = ResNetAE(1, 3, [16, 32], debug=True)
# single all-ones float32 input with four axes: 1 x 1 x 20 x 20
x = torch.ones((1, 1, 20, 20), dtype=torch.float32)
_ = ae(x)

In [None]:
# smoke test: push one dummy (32x32) image, the size used here,
# through the model to check that the forward pass runs

ae = ResNetAEPixel(1, 3, [16, 32], debug=True)
# single all-ones float32 input with four axes: 1 x 1 x 32 x 32
x = torch.ones((1, 1, 32, 32), dtype=torch.float32)
_ = ae(x)

In [None]:
# load some example data

file = '../data/data/ZeroBias-Run2023C-PromptReco-v1-DQMIO-PixelPhase1-Tracks-PXForward-clusterposition_xy_ontrack_PXDisk_+1_preprocessed.parquet'

# keyword options forwarded unchanged to the loading function
kwargs = dict(
    verbose=True,
    entries_threshold=10000,
    skip_first_lumisections=5,
    # all-zero patterns of shape (2,2), (3,1) and (1,3)
    veto_patterns=[np.zeros(shape) for shape in ((2,2), (3,1), (1,3))],
)
train_data, training_runs, training_lumis = prepare_training_data_from_files([file], **kwargs)

In [None]:
# optionally crop the 32x32 images to their first 20x20 entries (currently commented out)

#train_data = train_data[:,:20,:20,:]

In [None]:
# keep only the first 1000 training instances
# (slicing the leading axis only; all other axes are left untouched)

train_data = train_data[:1000]

In [None]:
# convert to pytorch tensor

# insert a length-1 axis at position 1 and keep only index 0 of the last axis,
# then cast to float32 before handing the array to torch
train_data_tensor = torch.tensor(train_data[:, np.newaxis, :, :, 0].astype(np.float32))
print(train_data_tensor.shape)

In [None]:
# training loop

# training hyper-parameters
epochs = 1
batch_size = 50

# fresh model instance and an Adam optimizer acting on its parameters
ae = ResNetAEPixel(1, 1, [16, 32])
optimizer = torch.optim.Adam(ae.parameters(), lr=5e-4)

training_loop(
    ae,
    train_data_tensor,
    optimizer,
    epochs=epochs,
    batch_size=batch_size,
)

In [None]:
# plot examples: each selected input image next to its reconstruction by `ae`

nplots = 5
# sample without replacement so that the same instance is never plotted twice,
# and never request more instances than are available
plotids = np.random.choice(len(train_data), size=min(nplots, len(train_data)), replace=False)

# run inference in evaluation mode; this matters if the model contains layers
# that behave differently during training (e.g. batch normalization, dropout)
# NOTE(review): the model definition is not visible here - confirm.
# The previous mode is remembered and restored afterwards, so a later
# training call sees the model in the same state as before this cell.
was_training = ae.training
ae.eval()
try:
    for i in plotids:
        orig = train_data[i,:,:,0]
        # add two leading length-1 axes: 2D image -> (1, 1, rows, cols)
        reco = np.expand_dims(np.expand_dims(orig, axis=0), axis=0)
        # gradients are not needed for plotting, so skip autograd bookkeeping
        with torch.no_grad():
            reco = ae(torch.tensor(reco.astype(np.float32)))
        # drop the two leading axes again and convert back to numpy
        reco = reco[0,0,:,:]
        reco = reco.detach().numpy()
        fig,axs = plt.subplots(figsize=(12,6), ncols=2)
        plot_histogram(orig, fig=fig, ax=axs[0])
        plot_histogram(reco, fig=fig, ax=axs[1])
        # label the original with the run and lumisection it was taken from
        axs[0].text(0.02, 1.02, 'Run: {}, lumi: {}'.format(training_runs[i], training_lumis[i]), transform=axs[0].transAxes, fontsize=12)
finally:
    ae.train(was_training)