In [None]:
# import external modules
import os
import sys
import json
import joblib
import numpy as np
import matplotlib.pyplot as plt
import importlib

# import PixelNMF
# (the 'mlserver-model' subdirectory of the current working directory is added
#  to the module search path before importing pixelnmf, so the notebook is
#  expected to be run from the directory that contains that folder)
thisdir = os.getcwd()
mlserverdir = os.path.join(thisdir, 'mlserver-model')
sys.path.append(mlserverdir)
import pixelnmf
# re-execute the module so that edits to it are picked up without restarting the kernel
importlib.reload(pixelnmf)
# NOTE(review): PixelNMF is not referenced by name further down; presumably it has
# to be importable for joblib.load to unpickle the stored model - confirm.
from pixelnmf import PixelNMF

# import other tools (local)
# (the directory four levels above the current working directory is added to the
#  module search path; the 'tools' and 'studies' packages are presumably located
#  there - verify against the repository layout)
topdir = os.path.abspath(os.path.join(thisdir, '../../../..'))
sys.path.append(topdir)
import tools.omstools as omstools
from tools.iotools import read_parquet
from tools.dftools import get_mes, filter_dfs
from studies.pixel_clusters_2024.plotting.plot_cluster_occupancy import plot_cluster_occupancy

In [None]:
# set monitoring element names
# todo: make a clearer distinction between full names and short tags
# update: skip BPix1 because it is too degraded (since 2025F).
# note: the last character of each name is used as the layer number
#       when the input file names are built further below.

#menames = ['BPix1', 'BPix2', 'BPix3', 'BPix4']
menames = ['BPix2', 'BPix3', 'BPix4']

In [None]:
# load a PixelNMF instance
# (the path is relative to the current working directory;
#  joblib.load unpickles the file, which can execute arbitrary code,
#  so only load model files that come from a trusted source)

pnmf = joblib.load('mlserver-model/pixelnmf.joblib')

In [None]:
# configuration of the example data to load
datadir = '/eos/user/l/llambrec/dialstools-output'
dataset = 'ZeroBias'
reco = 'PromptReco'
era = '2025F-v1'
mebase = 'PixelPhase1-Phase1_MechanicalView-PXBarrel-clusters_per_SignedModuleCoord_per_SignedLadderCoord_PXLayer_{}'

# build the parquet file path for every monitoring element
# (the era is split at its first dash into main era and version;
#  the layer number is the last character of the monitoring element name)
mainera, version = era.split('-', 1)
files = {}
for mename in menames:
    me = mebase.format(mename[-1])
    filename = '{}-Run{}-{}-{}-DQMIO-{}.parquet'.format(dataset, mainera, reco, version, me)
    files[mename] = os.path.join(datadir, filename)

In [None]:
# list the first and last lumisection of every batch
# (so a convenient batch can be chosen to load)

batch_size = 1000
f = files[menames[0]]
lumi_index = read_parquet(f, verbose=False, columns=['run_number', 'ls_number'])
runs = lumi_index['run_number'].values
lumis = lumi_index['ls_number'].values
nentries = len(runs)
for batch_idx, first in enumerate(range(0, nentries, batch_size)):
    # position of the last entry of this batch (the final batch may be shorter)
    last = min(nentries, first + batch_size) - 1
    print(f'Batch {batch_idx}: Run {runs[first]} (LS {lumis[first]}) - Run {runs[last]} (LS {lumis[last]})')

In [None]:
# load one batch of data for every monitoring element

batch_idx = 10
X = {
    mename: read_parquet(files[mename], verbose=False, batch_size=batch_size, batch_ids=[batch_idx])
    for mename in menames
}

# print the run numbers present in the loaded batch
# (run and lumisection numbers are taken from the first monitoring element)
first_df = X[menames[0]]
run_numbers = first_df['run_number'].values
ls_numbers = first_df['ls_number'].values
runs = np.unique(run_numbers)
print(runs)

In [None]:
# convert the dataframes to np arrays
# (only the first of the values returned by get_mes is kept)

X_data = {
    mename: get_mes(X[mename],
                    xbinscolumn='x_bin', ybinscolumn='y_bin',
                    runcolumn='run_number', lumicolumn='ls_number')[0]
    for mename in menames
}

In [None]:
# collect the per-lumisection information used for filtering

# OMS attribute filters
oms_attrs = [
    "beams_stable",
    "cms_active",
    "bpix_ready",
    "fpix_ready",
    "tibtid_ready",
    "tob_ready",
    "tecp_ready",
    "tecm_ready",
    "pileup"
]
oms_info_file = f'/eos/user/l/llambrec/pixelae/studies/pixel_clusters_2024/omsdata/omsdata_Run{era}.json'
with open(oms_info_file, 'r') as f:
    oms_raw = json.load(f)
# keep only the requested attributes that are present in the file,
# looked up for the loaded lumisections and stored under an 'oms__' prefix
oms_info = {
    'oms__' + key: omstools.find_oms_attr_for_lumisections(run_numbers, ls_numbers, oms_raw, key)
    for key in oms_raw
    if key in oms_attrs
}

# HLT rate filter
hltrate_attrs = [
    "HLT_ZeroBias_v*"
]
hltrate_info_file = f'/eos/user/l/llambrec/pixelae/studies/pixel_clusters_2024/omsdata/hltrate_Run{era}.json'
with open(hltrate_info_file, 'r') as f:
    hltrate_raw = json.load(f)
hltrate_info = {
    'oms__hlt_zerobias_rate': omstools.find_hlt_rate_for_lumisections(
        run_numbers, ls_numbers, hltrate_raw, hltrate_attrs[0])
}

# merge everything, together with the run and lumisection numbers,
# and add the result to the input data
filter_info = dict(oms_info)
filter_info.update(hltrate_info)
filter_info['oms__run_number'] = run_numbers
filter_info['oms__lumisection_number'] = ls_numbers
print('Filter info:')
for key, val in filter_info.items():
    print(f'  - {key}: {val.shape}')
X_data.update(filter_info)
print('Input data keys:')
print(X_data.keys())

In [None]:
# apply the full model to the loaded data

flags = pnmf.predict(X_data, verbose=True)
# print the number of flags, followed by the sum of the flags cast to int
# (i.e. the number of raised flags, if the flags are boolean)
nflags = len(flags)
flagsum = np.sum(flags.astype(int))
print(nflags)
print(flagsum)

In [None]:
# run the model step by step and plot intermediate outputs
# (for debugging)

# select small data (single instance);
# instance_idx is the row position within the loaded batch.
# an out-of-range position would silently give an empty selection
# and an unclear error further down, so it is checked explicitly.
instance_idx = 500
for mename in menames:
    nrows = len(X[mename])
    if not 0 <= instance_idx < nrows:
        raise IndexError(f'instance_idx {instance_idx} is out of range for {mename} ({nrows} rows loaded)')
X_small = {mename: X[mename].iloc[instance_idx:instance_idx+1] for mename in menames}
meidx = 0
print(X_small[menames[meidx]])

# convert from dataframe to np arrays
X_small_data = {}
for mename in menames:
    mes, _, _ = get_mes(X_small[mename],
                    xbinscolumn='x_bin', ybinscolumn='y_bin',
                    runcolumn='run_number', lumicolumn='ls_number')
    X_small_data[mename] = mes

# run the individual steps of the model,
# keeping both the thresholded and the non-thresholded versions of the losses
mes_preprocessed = pnmf.preprocess(X_small_data)
mes_reco = pnmf.infer(mes_preprocessed)
losses = pnmf.loss(mes_preprocessed, mes_reco, do_thresholding=False)
losses_binary = pnmf.loss(mes_preprocessed, mes_reco, do_thresholding=True)
losses_combined = pnmf.combine(losses_binary, do_masking=False, do_thresholding=False)
losses_combined_binary = pnmf.combine(losses_binary, do_masking=True, do_thresholding=True)

# plot settings shared by all plots below
common_plot_args = dict(
    titlesize=15,
    xaxtitlesize=15, yaxtitlesize=15,
    ticklabelsize=12, colorticklabelsize=12,
    docolorbar=True,
    caxtitlesize=15, caxtitleoffset=30
)

# one entry per plot: (data, title, color axis title, color axis range);
# the per-layer plots show the monitoring element selected with meidx
plot_mename = menames[meidx]
normalized_caxtitle = 'Number of clusters\n(normalized)'
plots = [
    (mes_preprocessed[plot_mename][0], 'Input', normalized_caxtitle, (1e-6,2)),
    (mes_reco[plot_mename][0], 'Reconstructed', normalized_caxtitle, (1e-6,2)),
    (losses[plot_mename][0], 'Loss', 'Loss', (0, 0.1)),
    (losses_binary[plot_mename][0], 'Loss (binary)', 'Loss', (0, 1)),
    (losses_combined[0], 'Combined loss (before masking)', 'Loss', (0, 4)),
    (pnmf.loss_mask[0], 'Combined loss mask', 'Value', (0, 4)),
    (losses_combined_binary[0], 'Combined loss (after masking)', 'Loss', (0, 1)),
]
for plot_data, title, caxtitle, caxrange in plots:
    plot_cluster_occupancy(plot_data,
                   title=title, caxtitle=caxtitle, caxrange=caxrange,
                   **common_plot_args)

In [None]:
# extra: store some data in a format usable for the test_predictions script

dosave = True

if dosave:

    import pickle

    # drop the OMS info (will be retrieved on the fly by DIALS),
    # then add the run and lumisection numbers to the data to be stored
    X_data_towrite = {
        key: val for key, val in X_data.items()
        if not key.startswith('oms__')
    }
    X_data_towrite.update(run_number=run_numbers, ls_number=ls_numbers)

    # printouts for checking
    print('Will write following arrays:')
    for key, val in X_data_towrite.items():
        print(f'  - {key}: {val.shape}')

    # write to a pickle file in the current working directory
    with open('test_data.pkl', 'wb') as f:
        pickle.dump(X_data_towrite, f)