In [None]:
# imports

import importlib
import json
import os
import sys

import joblib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Put the directory three levels above the current working directory on the
# module search path (the project imports that follow rely on sys.path).
thisdir = os.getcwd()
topdir = os.path.abspath(
    os.path.join(thisdir, '../../../')
)
sys.path += [topdir]

import tools.iotools as iotools
import tools.dftools as dftools
import tools.patternfiltering as patternfiltering
import tools.rebinning as rebinning
import plotting.plottools as plottools
from tools.dataloadertools import MEDataLoader

from studies.clusters_2024.preprocessing.preprocessor import make_default_preprocessor
from studies.clusters_2024.preprocessing.preprocessor import PreProcessor
from studies.clusters_2024.nmf.modeldefs.nmf2d import NMF2D
from studies.clusters_2024.nmf.nmf_training import find_files
from studies.clusters_2024.nmf.nmf_testing_pattern import run_evaluation

In [None]:
# layers under study and the input files that belong to them

layers = [1, 2]
# find_files is called once per layer, in order; the results are keyed by layer
# (later cells index each result by era name)
input_files = dict(zip(layers, map(find_files, layers)))

In [None]:
# choose the era and the runs to train on

era = 'F-v1'

# only the run number and entries columns are needed to list candidate runs
dftemp = iotools.read_parquet(
    input_files[layers[0]][era],
    columns=['run_number', 'entries'],
)
# keep the rows with more than 0.5e6 entries
has_enough_entries = dftemp['entries'] > 0.5e6
dftemp = dftemp[has_enough_entries]
available_runs = np.unique(dftemp['run_number'].values)
print('Available runs:', available_runs, sep='\n')

training_runs = [382649, 382650]
print('Chosen training runs:', training_runs, sep='\n')

# stop at the first chosen run that is not among the available ones
for training_run in training_runs:
    if training_run in available_runs:
        continue
    raise Exception(f'Run {training_run} not in available runs.')

In [None]:
# build one default preprocessor per layer for the chosen era

# Strip everything from '-part' onwards off the era name; str.split returns
# the whole string as its only element when the separator does not occur,
# so era names without '-part' pass through unchanged.
preprocessor_era = era.split('-part')[0]

preprocessors = {}
for layer in layers:
    layer_preprocessor = make_default_preprocessor(preprocessor_era, layer)
    preprocessors[layer] = layer_preprocessor

In [None]:
# read the training runs for every layer

dfs_training = {}
for layer in layers:
    print(f'Loading training data for layer {layer}...')
    training_file = input_files[layer][era]
    dfs_training[layer] = iotools.read_runs(training_file, training_runs, verbose=True)

# the instance count is taken from the first layer only
first_layer = layers[0]
ndf = len(dfs_training[first_layer])
print(f'Found {ndf} instances.')

In [None]:
# do training
#
# For every layer an NMF2D model is created and fitted on `nbatches` random
# batches of the preprocessed training data. The fitted models are collected
# in `nmfs`, keyed by layer.

nmfs = {}
batch_size = 300
nbatches = 10

do_plot_components = True

# loop over layers
for layer in layers:
    print(f'Now running on layer {layer}...')
    print(f'Will train on {nbatches} batches of size {batch_size}.')

    # make the NMF model for this layer
    nmf = NMF2D(n_components=5, forget_factor=1, batch_size=batch_size, verbose=True,
                tol=0.0, max_no_improvement=100, max_iter=1000,
                alpha_H=0.1)

    # load the data
    df = dfs_training[layer]

    # filtering (the entries threshold is divided by the layer number)
    df = df[df['entries'] > 0.5e6/layer]
    print(f'  Found {len(df)} / {ndf} instances passing filters.')
    # NOTE: a layer without any passing instance is skipped and
    # therefore gets no entry in nmfs
    if len(df)==0: continue

    # preprocessing
    mes_preprocessed = preprocessors[layer].preprocess(df)

    # experimental: set zero-occupancy to 1 (average expected value after preprocessing)
    mes_preprocessed[mes_preprocessed==0] = 1

    # A batch is drawn without replacement, so it cannot hold more instances
    # than are available; np.random.choice raises a ValueError otherwise.
    n_available = len(mes_preprocessed)
    effective_batch_size = min(batch_size, n_available)
    if effective_batch_size < batch_size:
        print(f'  WARNING: only {n_available} instances available,'
              f' using batches of size {effective_batch_size}.')

    # loop over random batches
    for batchidx in range(nbatches):
        print(f'Now processing batch {batchidx+1} / {nbatches}...')

        # make random indices
        random_ids = np.random.choice(np.arange(n_available), size=effective_batch_size, replace=False)
        batch = mes_preprocessed[random_ids, :, :]

        # fit NMF
        nmf.fit(batch)

    nmfs[layer] = nmf

    # plot components
    if do_plot_components:
        C = nmf.components
        for idx in range(len(C)):
            fig, ax = plottools.plot_hist_2d(C[idx],
                   title=f'Component {idx+1}', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6, 2),
                   caxtitlesize=15, caxtitleoffset=35,
                   origin='lower')
        plt.show()
        # one figure was made per component: close all of them, not only the last
        plt.close('all')

In [None]:
# save the models

dosave = False


def save_nmf_models(models, era_name, topdir='models', dump=None):
    """Write one model file per layer and return the written paths.

    Parameters
    ----------
    models : dict
        Models keyed by layer (the layout of `nmfs` filled by the training cell).
    era_name : str
        Era tag that is put into the output file name.
    topdir : str
        Directory under which one `PXLayer_<layer>` subdirectory is created
        per layer.
    dump : callable, optional
        Called as `dump(model, path)` for every model; defaults to `joblib.dump`.

    Returns
    -------
    list of str
        The output file paths, in the iteration order of `models`.
    """
    if dump is None: dump = joblib.dump
    written = []
    for layer_key, model in models.items():
        outputdir = os.path.join(topdir, f'PXLayer_{layer_key}')
        # exist_ok avoids the separate existence check
        os.makedirs(outputdir, exist_ok=True)
        outputfile = os.path.join(outputdir, f'nmf_model_era{era_name}.pkl')
        dump(model, outputfile)
        written.append(outputfile)
    return written


if dosave:
    # models are keyed by layer; the single era chosen above only enters the
    # file name, and the global `era` is left untouched for the cells below
    save_nmf_models(nmfs, era)

In [None]:
# runs to evaluate the trained models on
# (whole runs only for now; selecting a specific lumisection range
#  is still to be implemented)

testing_runs = [
    382654,
]

In [None]:
# settings for automasking
# This flag gates the automask branches of the example-plotting cell.
# NOTE(review): the run_evaluation call passes do_automasking = False
# explicitly rather than reading this flag, and the automask reader and
# automask map preprocessors used by the automask branches are not defined
# in the visible part of this notebook; enabling the flag presumably
# requires defining them first - confirm.

do_automasking = False

In [None]:
# settings for loss masking

do_loss_masking = True

# Defaults for the disabled case: both names are handed on by later cells
# regardless of the flag, so they must always be defined.
loss_masks = None
loss_mask_preprocessors = None

if do_loss_masking:
    # era name with everything from '-part' onwards removed
    loss_mask_era = era
    if '-part' in era: loss_mask_era = era.split('-part')[0]
    loss_masks = {}
    loss_mask_preprocessors = {}
    for layer in layers:
        zerofrac_file = f'../preprocessing/normdata/zerofrac_Run2024{loss_mask_era}_PXLayer_{layer}.npy'
        zerofrac = np.load(zerofrac_file)
        # the mask is True where the stored value is below 0.9
        # (presumably the fraction of zero-occupancy instances per bin - confirm)
        loss_mask = (zerofrac < 0.9)
        loss_masks[layer] = loss_mask
        loss_mask_preprocessors[layer] = PreProcessor(f'PXLayer_{layer}')

In [None]:
# remaining evaluation settings

# loss values above this threshold are marked in the binary loss maps
threshold = 0.1
flag_patterns = [np.ones(shape=(1, 4))]

# per-layer cleaning of the binary loss maps
do_per_layer_cleaning = True
cleaning_threshold = 1.5
cleaning_patterns = [np.ones(shape=(2, 8))]

In [None]:
# read the testing runs for every layer

dfs_testing = {}
for layer in layers:
    print(f'Loading testing data for layer {layer}...')
    testing_file = input_files[layer][era]
    dfs_testing[layer] = iotools.read_runs(testing_file, testing_runs, verbose=True)

# the instance count is taken from the first layer only
ndf = len(dfs_testing[layers[0]])
print(f'Found {ndf} instances.')

In [None]:
# run the evaluation on the testing data

# keyword settings handed to run_evaluation;
# automasking is switched off for this call, so no reader or
# automask preprocessors are passed
evaluation_settings = dict(
    preprocessors = preprocessors,
    threshold = threshold,
    flag_patterns = flag_patterns,
    do_per_layer_cleaning = do_per_layer_cleaning,
    cleaning_patterns = cleaning_patterns,
    cleaning_threshold = cleaning_threshold,
    do_automasking = False,
    automask_reader = None,
    automask_map_preprocessors = None,
    do_loss_masking = do_loss_masking,
    loss_masks = loss_masks,
    loss_mask_preprocessors = loss_mask_preprocessors,
)

flagged_run_numbers, flagged_ls_numbers = run_evaluation(dfs_testing, nmfs, **evaluation_settings)

In [None]:
# list the flagged lumisections, one (run, LS) pair per line
print(f'Found {len(flagged_run_numbers)} flagged lumisections:')
flagged_pairs = zip(flagged_run_numbers, flagged_ls_numbers)
for run_number, ls_number in flagged_pairs:
    print(f'  - Run {run_number}, LS {ls_number}')

In [None]:
# choose the lumisections to plot in detail

# which plots to make
do_extended_loss_plots = True
do_combined_loss_plot = True

# draw up to nplot flagged lumisections at random, without repetition
nplot = 3
#random_ids = np.random.choice(len(available_run_numbers), size=min(nplot, len(available_run_numbers)), replace=False)
#selected_run_numbers = available_run_numbers[random_ids]
#selected_ls_numbers = available_ls_numbers[random_ids]
n_flagged = len(flagged_run_numbers)
random_ids = np.random.choice(n_flagged, size=min(nplot, n_flagged), replace=False)
selected_run_numbers = flagged_run_numbers[random_ids]
selected_ls_numbers = flagged_ls_numbers[random_ids]

# alternative: specific selected lumisections
#selected_runlumis = [(385443, 1566), (385443, 1578), (385443, 1579), (385443, 1592)]
#selected_run_numbers = [el[0] for el in selected_runlumis]
#selected_ls_numbers = [el[1] for el in selected_runlumis]

# Recompute the loss maps for the selected lumisections, apply the configured
# masking and cleaning steps, and plot every intermediate stage per layer.

def _plot_panel(hist, fig, ax, title, caxtitle, caxtitleoffset, **extra):
    """Draw one 2D histogram with the styling shared by all panels below.

    Thin wrapper around `plottools.plot_hist_2d`. `extra` is forwarded
    unchanged; it is used for `caxrange`, which only some panels set.
    Returns what `plottools.plot_hist_2d` returns (unpacked below as fig, ax).
    """
    return plottools.plot_hist_2d(hist, fig=fig, ax=ax,
               title=title, titlesize=15,
               xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
               ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
               docolorbar=True, caxtitle=caxtitle,
               caxtitlesize=15, caxtitleoffset=caxtitleoffset,
               origin='lower', **extra)


if len(selected_run_numbers) > 0:

    # load the selected lumisections for every layer
    # (runs and lumis are taken from the last layer that is read)
    print('Loading data...')
    dfs = {}
    mes = {}
    for layer in layers:
        dfs[layer] = iotools.read_lumisections(input_files[layer][era], selected_run_numbers, selected_ls_numbers)
        mes[layer], runs, lumis = dftools.get_mes(dfs[layer], xbinscolumn='x_bin', ybinscolumn='y_bin', runcolumn='run_number', lumicolumn='ls_number')

    # preprocess and predict
    print('Processing...')
    mes_preprocessed = {}
    mes_pred = {}
    losses = {}
    losses_binary = {}
    for layer in layers:
        mes_preprocessed[layer] = preprocessors[layer].preprocess(dfs[layer])
        mes_pred[layer] = nmfs[layer].predict(mes_preprocessed[layer])
        # loss is the squared difference between input and reconstruction,
        # the binary loss marks where it exceeds the threshold
        losses[layer] = np.square(mes_preprocessed[layer] - mes_pred[layer])
        losses_binary[layer] = (losses[layer] > threshold).astype(int)

    # automasking
    # NOTE: needs automask_reader and automask_map_preprocessors to be
    # defined in an earlier cell before this option is switched on
    if do_automasking:
        print('Applying automasks...')
        for layer in layers:
            subsystem = f'BPix{layer}'
            automask_maps = automask_reader.get_automask_maps_for_ls(selected_run_numbers, selected_ls_numbers, subsystem, invert=True)
            automask_maps = automask_map_preprocessors[layer].preprocess_mes(automask_maps, None, None)
            losses[layer] = np.multiply(losses[layer], automask_maps)
            losses_binary[layer] = np.multiply(losses_binary[layer], automask_maps)

    # manual masking
    if do_loss_masking:
        print('Applying loss mask...')
        for layer in layers:
            mask = loss_masks[layer]
            # add a leading axis before passing the mask through the preprocessor
            mask = np.expand_dims(mask, 0)
            mask = loss_mask_preprocessors[layer].preprocess_mes(mask, None, None)
            losses[layer] = np.multiply(losses[layer], mask)
            losses_binary[layer] = np.multiply(losses_binary[layer], mask)

    # cleaning (the dict stays empty when cleaning is switched off)
    losses_binary_cleaned = {}
    if do_per_layer_cleaning:
        print('Cleaning loss maps')
        for layer in layers:
            losses_binary_cleaned[layer] = patternfiltering.filter_any_pattern(losses_binary[layer], cleaning_patterns, threshold=cleaning_threshold)

    # make rebinned and overlayed binary loss map
    # (all layers are rebinned to the shape of the first layer)
    target_shape = losses[layers[0]].shape[1:3]
    losses_binary_rebinned = {}
    losses_binary_combined = np.zeros(losses[layers[0]].shape)
    for layer in layers:
        source = losses_binary[layer]
        if do_per_layer_cleaning: source = losses_binary_cleaned[layer]
        losses_binary_rebinned[layer] = rebinning.rebin_keep_clip(source, target_shape, 1, mode='cv2')
        losses_binary_combined += losses_binary_rebinned[layer]
    # combined map is set where the summed rebinned maps reach 2
    losses_binary_combined = (losses_binary_combined >= 2).astype(int)

    # make the plots
    print('Plotting...')
    for idx in range(len(selected_run_numbers)):
        run = runs[idx]
        lumi = lumis[idx]
        for layer in layers:
            me_orig = mes[layer][idx, :, :]
            me_preprocessed = mes_preprocessed[layer][idx, :, :]
            me_pred = mes_pred[layer][idx, :, :]
            loss = losses[layer][idx, :, :]
            loss_binary = losses_binary[layer][idx, :, :]
            # the cleaned map only exists when per-layer cleaning is enabled
            loss_binary_cleaned = None
            if do_per_layer_cleaning:
                loss_binary_cleaned = losses_binary_cleaned[layer][idx, :, :]
            loss_binary_rebinned = losses_binary_rebinned[layer][idx, :, :]

            # initialize figure
            nrows = 1
            figheight = 6
            if do_extended_loss_plots:
                nrows = 2
                figheight = 12
            fig, axs = plt.subplots(ncols=4, nrows=nrows, figsize=(24, figheight), squeeze=False)

            # plot raw data
            fig, axs[0, 0] = _plot_panel(me_orig, fig, axs[0, 0],
                   'Raw', 'Number of clusters', 15)

            # overlay automask: one red rectangle per nonzero bin
            if do_automasking:
                subsystem = f'BPix{layer}'
                automask_map = automask_reader.get_automask_map_for_ls(run, lumi, subsystem)
                ids = np.nonzero(automask_map.astype(int))
                for yidx, xidx in zip(ids[0], ids[1]):
                    linewidth = 1 if layer>=3 else 2
                    patch = mpl.patches.Rectangle((xidx-0.5, yidx-0.5), 1, 1,
                                      edgecolor='red', linewidth=linewidth,
                                      facecolor='none')
                    axs[0, 0].add_patch(patch)

            # plot preprocessed, reconstructed and loss
            fig, axs[0, 1] = _plot_panel(me_preprocessed, fig, axs[0, 1],
                   'Input', 'Number of clusters\n(normalized)', 30,
                   caxrange=(1e-6,2))
            fig, axs[0, 2] = _plot_panel(me_pred, fig, axs[0, 2],
                   'Reconstructed', 'Number of clusters\n(normalized)', 30,
                   caxrange=(1e-6,2))
            fig, axs[0, 3] = _plot_panel(loss, fig, axs[0, 3],
                   'Loss', 'Loss', 30,
                   caxrange=(0, 0.1))

            # optional: plot more post-processing steps with the loss map
            if do_extended_loss_plots:
                fig, axs[1, 0] = _plot_panel(loss_binary, fig, axs[1, 0],
                   'Binary loss', 'Loss', 15,
                   caxrange=(0, 1))
                if loss_binary_cleaned is not None:
                    fig, axs[1, 1] = _plot_panel(loss_binary_cleaned, fig, axs[1, 1],
                       'Cleaned loss', 'Loss', 15,
                       caxrange=(0, 1))
                else:
                    # nothing to show in this slot without cleaning
                    fig.delaxes(axs[1, 1])
                fig, axs[1, 2] = _plot_panel(loss_binary_rebinned, fig, axs[1, 2],
                   'Rebinned loss', 'Loss', 15,
                   caxrange=(0, 1))
                # the last slot of the second row is unused
                fig.delaxes(axs[1, 3])

            # plot aesthetics (row spacing is tuned per layer)
            plt.subplots_adjust(wspace=0.5)
            if str(layer)=='1': plt.subplots_adjust(hspace=-0.75)
            if str(layer)=='2': plt.subplots_adjust(hspace=-0.4)
            title = f'Run {run}, LS {lumi}, layer {layer}'
            axs[0, 0].text(0.01, 1.3, title, fontsize=15, transform=axs[0, 0].transAxes)
            plt.show()
            plt.close()

        # plot the combined loss map
        if do_combined_loss_plot:
            loss_binary_combined = losses_binary_combined[idx, :, :]
            fig, ax = plt.subplots()
            fig, ax = _plot_panel(loss_binary_combined, fig, ax,
                   'Combined binary loss', 'Loss', 15,
                   caxrange=(0, 1))
            title = f'Run {run}, LS {lumi}'
            ax.text(0.01, 1.3, title, fontsize=15, transform=ax.transAxes)
            plt.show()
            plt.close()