In [None]:
# imports

import os
import sys
import json
import joblib
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# make the repository top directory (three levels above the working directory) importable;
# only append it once so that re-running this cell does not keep growing sys.path
thisdir = os.getcwd()
topdir = os.path.abspath(os.path.join(thisdir, '../../../'))
if topdir not in sys.path:
    sys.path.append(topdir)

import tools.iotools as iotools
import tools.dftools as dftools
import plotting.plottools as plottools
from tools.dataloadertools import MEDataLoader

from studies.clusters_2024.preprocessing.preprocessor import make_default_preprocessor
from studies.clusters_2024.nmf.modeldefs.nmf2d import NMF2D

In [None]:
# set path to files
# (one parquet file per era/version of the dataset, for the monitoring element `me`)

# settings
datadir = '/eos/user/l/llambrec/dialstools-output'
year = '2024'
eras = {
    'A': ['v1'],
    'B': ['v1'],
    'C': ['v1'],
    'D': ['v1'],
    'E': ['v1', 'v2'],
    'F': ['v1'],
    'G': ['v1'],
    'H': ['v1'],
    'I': ['v1', 'v2'],
    'J': ['v1']
}
dataset = 'ZeroBias'
reco = 'PromptReco'
mebase = 'PixelPhase1-Phase1_MechanicalView-PXBarrel-clusters_per_SignedModuleCoord_per_SignedLadderCoord_PXLayer_{}'
layer = 1
me = mebase.format(layer)

# map '<era>-<version>' to the full path of the corresponding file
files = {
    f'{era_name}-{version}': os.path.join(
        datadir,
        f'{dataset}-Run{year}{era_name}-{reco}-{version}-DQMIO-{me}.parquet',
    )
    for era_name, versions in eras.items()
    for version in versions
}

# existence check: fail early, listing every file that is not found
missing = [path for path in files.values() if not os.path.exists(path)]
if missing:
    raise Exception(f'The following files do not exist: {missing}')
print(f'Found {len(files)} files.')

In [None]:
# define era to use in testing
# (must be one of the '<era>-<version>' keys of `files`, e.g. 'B-v1' or 'E-v2';
#  checked here so that a typo fails immediately with a readable message)

era = 'B-v1'
if era not in files:
    raise KeyError(f'Era {era!r} not found in files; available eras: {list(files)}')

In [None]:
# make the preprocessor for the selected era and layer
# (make_default_preprocessor is project code; it is given the era because the
#  preprocessing is era-dependent)

preprocessor = make_default_preprocessor(era, layer)

In [None]:
# make a dataloader for the file of the selected era
# (note: could make one big dataloader for all eras together,
#  but this makes correct preprocessing a bit harder as it depends on the era,
#  so keep one dataloader per era for now)

dataloader = MEDataLoader([files[era]])

In [None]:
# read the NMF model for the selected layer and era
# (the path is relative to the current working directory)
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.

modelfile = os.path.join('models', f'PXLayer_{layer}', f'nmf_model_era{era}.pkl')
nmf = joblib.load(modelfile)

In [None]:
# plot model components
# (one 2D map per NMF component, annotated with the monitoring element name)

C = nmf.components
title = me.split('-')[-1]
for idx, component in enumerate(C):
    fig, ax = plottools.plot_hist_2d(
        component,
        title=f'Component {idx+1}', titlesize=15,
        xaxtitle=None, xaxtitlesize=None,
        yaxtitle=None, yaxtitlesize=None,
        ticklabelsize=12, colorticklabelsize=12,
        extent=None, aspect=None,
        docolorbar=True,
        caxtitle='Number of clusters\n(normalized)',
        caxrange=(1e-6, 2),
        caxtitlesize=15, caxtitleoffset=35,
        origin='lower',
    )
    ax.text(0.01, 1.3, title, fontsize=15, transform=ax.transAxes)
    plt.show()
    plt.close()

In [None]:
# loop over batches for testing
# For every batch: filter lumisections, preprocess, reconstruct with the NMF model,
# apply a (preliminary) time correction to the squared-error loss, and count per
# lumisection the number of bins whose corrected loss exceeds each threshold.

batch_size = 3000
min_entries = 0.5e6  # minimum number of entries for a lumisection to be evaluated
thresholds = [0.02, 0.03, 0.05, 0.1, 0.2]
threshold_counts = {threshold: [] for threshold in thresholds}
run_numbers = []
ls_numbers = []

for df in dataloader.read_sequential_batches(batch_size=batch_size):

    # filtering
    ndf = len(df)
    df = df[df['entries'] > min_entries]
    print(f'Found {len(df)} / {ndf} instances passing filters.')
    if len(df) == 0: continue

    # do preprocessing
    print('Preprocessing...')
    mes_preprocessed = preprocessor.preprocess(df)

    # do evaluation: per-bin squared reconstruction error
    mes_pred = nmf.predict(mes_preprocessed)
    losses = np.square(mes_preprocessed - mes_pred)

    # do time correction (preliminary implementation):
    # the corrected loss of row i is the product of the losses of rows i-2, i-1 and i,
    # so only rows 2, 3, ... of each batch get a value.
    # NOTE(review): rows are consecutive in the *filtered* batch, so the three factors
    # may span filtered-out lumisections or a run boundary, and the first two
    # lumisections of every batch are dropped.
    losses = np.multiply(np.multiply(losses[2:, :, :], losses[1:-1, :, :]), losses[0:-2, :, :])

    # count cells above thresholds
    # (run/ls numbers are aligned with the corrected losses, i.e. rows 2, 3, ...)
    run_numbers.append(df['run_number'].values[2:])
    ls_numbers.append(df['ls_number'].values[2:])
    for threshold in thresholds:
        counts = np.count_nonzero(losses > threshold, axis=(1, 2))
        threshold_counts[threshold].append(counts)

# np.concatenate fails on an empty list with an unhelpful message; report the real cause
if len(run_numbers) == 0:
    raise Exception('No lumisections passed the filters in any batch; nothing to evaluate.')
threshold_counts = {threshold: np.concatenate(counts) for threshold, counts in threshold_counts.items()}
run_numbers = np.concatenate(run_numbers)
ls_numbers = np.concatenate(ls_numbers)

In [None]:
# make a plot of thresholds
# (normalized distribution, per threshold, of the number of bins per lumisection
#  whose time-corrected loss exceeds that threshold)

fig, ax = plt.subplots(figsize=(8,6))
bins = np.linspace(0, 300, num=51)
cmap = plt.get_cmap('cool')
cids = np.linspace(0, 1, num=len(threshold_counts))
for cid, (threshold_value, bin_counts) in zip(cids, threshold_counts.items()):
    ax.hist(
        bin_counts, bins=bins, density=True,
        histtype='step', linewidth=2,
        label=f'Threshold: {threshold_value}',
        color=cmap(cid),
    )
ax.set_yscale('log')
ax.grid(which='both')
ax.set_xlabel('Number of bins above threshold', fontsize=15)
ax.set_ylabel('Number of lumisections (normalized)', fontsize=15)
ax.legend()
_ = ax.set_title(me.split('-')[-1], fontsize=15)

In [None]:
# plot some random examples
# Select lumisections with many bins above a loss threshold (using the time-corrected
# loss from the batch loop), pick a random subset, and show raw / preprocessed /
# reconstructed / loss maps for each of them.
# Note: the per-lumisection loss plotted here is the plain (not time-corrected) loss.

select_threshold = 0.2  # key of threshold_counts used for the selection
select_min_bins = 100   # minimum number of bins above select_threshold
mask = threshold_counts[select_threshold] > select_min_bins
print(f'Found {np.count_nonzero(mask)} / {len(mask)} instances')
masked_run_numbers = run_numbers[mask]
masked_ls_numbers = ls_numbers[mask]

# number of examples to plot (0 disables plotting);
# capped at the number of candidates because sampling is without replacement
nplot = 0
nplot = min(nplot, len(masked_run_numbers))

if nplot > 0:

    # calculate random indices and load data
    print('Loading data...')
    random_ids = np.random.choice(len(masked_run_numbers), size=nplot, replace=False)
    masked_run_numbers = masked_run_numbers[random_ids]
    masked_ls_numbers = masked_ls_numbers[random_ids]
    df = iotools.read_lumisections(files[era], masked_run_numbers, masked_ls_numbers)
    mes, runs, lumis = dftools.get_mes(df, xbinscolumn='x_bin', ybinscolumn='y_bin', runcolumn='run_number', lumicolumn='ls_number')

    # preprocess and predict
    print('Processing...')
    mes_preprocessed = preprocessor.preprocess(df)
    mes_pred = nmf.predict(mes_preprocessed)
    losses = np.square(mes_preprocessed - mes_pred)

    # loop over what was actually loaded rather than over nplot
    # (read_lumisections may return a different number of lumisections than requested)
    print('Plotting...')
    for idx in range(len(runs)):
        run = runs[idx]
        lumi = lumis[idx]
        me_orig = mes[idx, :, :]
        me_preprocessed = mes_preprocessed[idx, :, :]
        me_pred = mes_pred[idx, :, :]
        loss = losses[idx, :, :]

        fig, axs = plt.subplots(ncols=4, figsize=(24, 6))
        fig, axs[0] = plottools.plot_hist_2d(me_orig, fig=fig, ax=axs[0],
                   title='Raw', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Number of clusters',
                   caxtitlesize=15, caxtitleoffset=15,
                   origin='lower')
        fig, axs[1] = plottools.plot_hist_2d(me_preprocessed, fig=fig, ax=axs[1],
                   title='Input', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6,2),
                   caxtitlesize=15, caxtitleoffset=30,
                   origin='lower')
        fig, axs[2] = plottools.plot_hist_2d(me_pred, fig=fig, ax=axs[2],
                   title='Reconstructed', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6,2),
                   caxtitlesize=15, caxtitleoffset=30,
                   origin='lower')
        fig, axs[3] = plottools.plot_hist_2d(loss, fig=fig, ax=axs[3],
                   title='Loss', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Loss',
                   caxrange=(0.0, 0.1),
                   caxtitlesize=15, caxtitleoffset=30,
                   origin='lower')
        plt.subplots_adjust(wspace=0.5)
        title = me.split('-')[-1] + f', Run {run}, LS {lumi}'
        axs[0].text(0.01, 1.3, title, fontsize=15, transform=axs[0].transAxes)
        plt.show()
        plt.close()

In [None]:
# plot an example with time correction
# Pick a random lumisection flagged by the time-corrected loss, load it together with
# the preceding lumisections that enter the time correction, plot each of them, and
# finally plot the product of their losses (the time-corrected loss).

select_threshold = 0.2  # key of threshold_counts used for the selection
select_min_bins = 100   # minimum number of bins above select_threshold
n_time_ls = 3           # lumisections combined in the time correction (as in the batch loop)
mask = threshold_counts[select_threshold] > select_min_bins
print(f'Found {np.count_nonzero(mask)} / {len(mask)} instances')
masked_run_numbers = run_numbers[mask]
masked_ls_numbers = ls_numbers[mask]

# only draw among candidates that have n_time_ls - 1 preceding lumisections
# (assumes lumisection numbering starts at 1), instead of failing after the draw
valid_ids = np.flatnonzero(masked_ls_numbers >= n_time_ls)
if len(valid_ids) == 0:
    raise Exception('No flagged lumisection has enough preceding lumisections for the time correction.')

# calculate random index and load data in range
print('Loading data...')
random_idx = np.random.choice(valid_ids)
example_run_number = masked_run_numbers[random_idx]
example_ls_number = masked_ls_numbers[random_idx]
print(example_run_number)
print(example_ls_number)
example_run_numbers = np.full(n_time_ls, example_run_number)
example_ls_numbers = np.arange(example_ls_number - n_time_ls + 1, example_ls_number + 1)
print(example_run_numbers)
print(example_ls_numbers)
df = iotools.read_lumisections(files[era], example_run_numbers, example_ls_numbers)
mes, runs, lumis = dftools.get_mes(df, xbinscolumn='x_bin', ybinscolumn='y_bin', runcolumn='run_number', lumicolumn='ls_number')
if len(runs) != n_time_ls:
    print(f'WARNING: expected {n_time_ls} lumisections but loaded {len(runs)};'
          + ' the time-corrected loss below only combines the loaded ones.')

# preprocess and predict
print('Processing...')
mes_preprocessed = preprocessor.preprocess(df)
mes_pred = nmf.predict(mes_preprocessed)
losses = np.square(mes_preprocessed - mes_pred)

# plot each loaded lumisection separately
print('Plotting...')
for idx in range(len(runs)):
    run = runs[idx]
    lumi = lumis[idx]
    me_orig = mes[idx, :, :]
    me_preprocessed = mes_preprocessed[idx, :, :]
    me_pred = mes_pred[idx, :, :]
    loss = losses[idx, :, :]

    fig, axs = plt.subplots(ncols=4, figsize=(24, 6))
    fig, axs[0] = plottools.plot_hist_2d(me_orig, fig=fig, ax=axs[0],
               title='Raw', titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle='Number of clusters',
               caxtitlesize=15, caxtitleoffset=15,
               origin='lower')
    fig, axs[1] = plottools.plot_hist_2d(me_preprocessed, fig=fig, ax=axs[1],
               title='Input', titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle='Number of clusters\n(normalized)',
               caxrange=(1e-6,2),
               caxtitlesize=15, caxtitleoffset=30,
               origin='lower')
    fig, axs[2] = plottools.plot_hist_2d(me_pred, fig=fig, ax=axs[2],
               title='Reconstructed', titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle='Number of clusters\n(normalized)',
               caxrange=(1e-6,2),
               caxtitlesize=15, caxtitleoffset=30,
               origin='lower')
    fig, axs[3] = plottools.plot_hist_2d(loss, fig=fig, ax=axs[3],
               title='Loss', titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle='Loss',
               caxrange=(0.0, 0.1),
               caxtitlesize=15, caxtitleoffset=30,
               origin='lower')
    plt.subplots_adjust(wspace=0.5)
    title = me.split('-')[-1] + f', Run {run}, LS {lumi}'
    axs[0].text(0.01, 1.3, title, fontsize=15, transform=axs[0].transAxes)
    plt.show()
    plt.close()

# time-corrected loss: product of the per-lumisection losses.
# Use the figure/axes returned here (not a stale `axs`/`ax` from an earlier plot).
loss = np.prod(losses, axis=0)
fig, ax = plottools.plot_hist_2d(loss,
               title='Loss with time correction', titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle='Loss',
               caxrange=(0.0, 0.1),
               caxtitlesize=15, caxtitleoffset=15,
               origin='lower')
title = me.split('-')[-1] + f', Run {example_run_number}, LS {example_ls_number}'
ax.text(0.01, 1.3, title, fontsize=15, transform=ax.transAxes)
plt.show()
plt.close()