In [None]:
# imports

import os
import sys
import json
import joblib
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# make the top-level project directory importable (for the tools, plotting and studies packages)
thisdir = os.getcwd()
topdir = os.path.abspath(os.path.join(thisdir, '../../../'))
# guard against duplicate sys.path entries when this cell is re-run in the same kernel
if topdir not in sys.path:
    sys.path.append(topdir)

import tools.iotools as iotools
import tools.dftools as dftools
import plotting.plottools as plottools
from tools.dataloadertools import MEDataLoader

from studies.clusters_2024.preprocessing.preprocessor import make_default_preprocessor
from studies.clusters_2024.nmf.modeldefs.nmf2d import NMF2D
from studies.clusters_2024.nmf.nmf_training import find_files, train

In [None]:
# set path to files
# (`layer` is also used further down, in the model output directory name
#  and in the plot titles; presumably the pixel layer, cf. 'PXLayer' in the output path)

layer = 1
# presumably a mapping from era name to the input file for that era,
# since it is indexed as files[era] in the cells below - TODO confirm against find_files
files = find_files(layer)

In [None]:
# define eras to use in training
# (an era name with a '-partN' suffix denotes a part of an era;
#  it is given the preprocessor of the parent era further below)
# every entry carries a trailing comma, so that (un)commenting entries can never
# silently concatenate two adjacent string literals into one bogus era name

eras = [
    #'A-v1',
    #'B-v1',
    #'C-v1',
    #'D-v1',
    #'E-v1',
    #'E-v2',
    #'F-v1',
    'F-v1-part3',
    #'G-v1',
    #'H-v1',
    #'I-v1',
    #'I-v2',
    #'J-v1',
]

In [None]:
# make preprocessors for the corresponding eras
# (a split-off part of an era, named '<era>-partN', uses the preprocessor of its
#  parent era; for a name without '-part', the split simply returns the name itself)

preprocessors = {
    era: make_default_preprocessor(era.split('-part')[0], layer)
    for era in eras
}

In [None]:
# make dataloaders for the corresponding eras
# one dataloader per era rather than a single combined one:
# the preprocessing depends on the era, which is easier to get right
# when every dataloader only serves the data of a single era

dataloaders = {era: MEDataLoader([files[era]]) for era in eras}

In [None]:
# loop over eras for training
# (only eras for which at least one batch was actually fitted end up in `nmfs`)

nmfs = {}
batch_size = 3000

do_plot_components = False

# training settings (shared by all eras)
verbose = True
min_entries = 0.5e6  # only train on instances with more than this many entries

for era in eras:
    print(f'Now running on era {era}...')
    dataloader = dataloaders[era]
    preprocessor = preprocessors[era]

    # number of random batches: 3x the number of full batches in the data,
    # clipped to the range [1, 30]
    nrows = sum(dataloader.nrows)
    nbatches = min(30, max(1, 3*int(nrows/batch_size)))
    print(f'Will train on {nbatches} batches of size {batch_size}.')

    # make the NMF model for this era
    nmf = NMF2D(n_components=15, forget_factor=1, batch_size=batch_size, verbose=True,
                tol=0.0, max_no_improvement=100, max_iter=1000,
                alpha_H=0.1)

    # loop over random batches
    # NOTE(review): this assumes that repeated nmf.fit calls keep updating the same
    # model rather than re-initializing it - confirm in NMF2D
    nfitted = 0
    for batchidx in range(nbatches):

        # load batch
        if verbose: print(f'Now processing batch {batchidx+1} / {nbatches}...')
        df = dataloader.read_random_batch(batch_size=batch_size, mode='subbatched', num_subbatches=100)
        ndf = len(df)

        # filtering
        if min_entries is not None: df = df[df['entries'] > min_entries]
        if verbose: print(f'  Found {len(df)} / {ndf} instances passing filters.')
        if len(df)==0: continue

        # do preprocessing
        if preprocessor is not None:
            if verbose: print('  Preprocessing...')
            mes_preprocessed = preprocessor.preprocess(df)
        else:
            mes_preprocessed, _, _ = dftools.get_mes(df,
                                       xbinscolumn='x_bin', ybinscolumn='y_bin',
                                       runcolumn='run_number', lumicolumn='ls_number')

        # experimental: set zero-occupancy to 1 (average expected value after preprocessing)
        mes_preprocessed[mes_preprocessed==0] = 1

        # fit NMF
        if verbose: print('  Training NMF...')
        nmf.fit(mes_preprocessed)
        nfitted += 1

    # only keep a model that was trained on at least one batch, so that the
    # saving and plotting cells below never pick up an untrained model
    if nfitted == 0:
        print(f'WARNING: no batch passed the filters for era {era}; no model is stored for this era.')
        continue
    nmfs[era] = nmf

    # plot components (one figure per component)
    if do_plot_components:
        C = nmf.components
        for idx in range(len(C)):
            fig, ax = plottools.plot_hist_2d(C[idx],
                   title=f'Component {idx+1}', titlesize=15,
                   xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
                   ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6, 2),
                   caxtitlesize=15, caxtitleoffset=35,
                   origin='lower')
            title = f'Era {era}, layer {layer}'
            ax.text(0.01, 1.3, title, fontsize=15, transform=ax.transAxes)
            plt.show()
            # close each figure explicitly to avoid accumulating open figures
            plt.close(fig)

In [None]:
# save the models
# (one joblib pickle per era, in models/PXLayer_<layer>/nmf_model_era<era>.pkl)

dosave = False

if dosave:
    outputdir = f'models/PXLayer_{layer}'
    os.makedirs(outputdir, exist_ok=True)

    for era in eras:
        # eras without a trained model (e.g. no batch passed the filters) cannot be saved
        if era not in nmfs:
            print(f'WARNING: no trained model found for era {era}, not saving it.')
            continue
        outputfile = os.path.join(outputdir, f'nmf_model_era{era}.pkl')
        joblib.dump(nmfs[era], outputfile)
        print(f'Saved model for era {era} to {outputfile}.')

In [None]:
# loop over eras for plotting random examples
# (for each selected lumisection: raw data, model input, reconstruction and squared-error loss)

nplot = 1  # number of random lumisections to plot per era
min_entries_plot = 0.5e6  # random lumisections must have more than this many entries
# to plot specific lumisections instead of random ones, set this to a list of
# (run number, lumisection number) tuples; set it to None for a random selection
# (note: a given lumisection is only present in the file of its own era)
selected_lumis = [(383155, 1700)]
rng = np.random.default_rng()  # pass a seed here for a reproducible random selection

# plotting options shared by all panels
common_plot_kwargs = dict(
    titlesize=15,
    xaxtitle=None, xaxtitlesize=None, yaxtitle=None, yaxtitlesize=None,
    ticklabelsize=12, colorticklabelsize=12, extent=None, aspect=None,
    docolorbar=True, caxtitlesize=15, origin='lower')

# panel-specific plotting options, in the order raw / input / reconstructed / loss
panel_plot_kwargs = [
    dict(title='Raw', caxtitle='Number of clusters', caxtitleoffset=15),
    dict(title='Input', caxtitle='Number of clusters\n(normalized)',
         caxrange=(1e-6,2), caxtitleoffset=30),
    dict(title='Reconstructed', caxtitle='Number of clusters\n(normalized)',
         caxrange=(1e-6,2), caxtitleoffset=30),
    dict(title='Loss', caxtitle='Loss', caxrange=(0.0, 0.1), caxtitleoffset=30),
]

for era in eras:
    print(f'Now running on era {era}...')
    if era not in nmfs:
        print(f'No trained model found for era {era}, skipping.')
        continue

    if selected_lumis is None:
        # select random lumisections among those passing the entries selection
        print('Loading data...')
        df_lumis = iotools.read_parquet(files[era], columns=['run_number', 'ls_number', 'entries'])
        ids = np.nonzero(df_lumis['entries'].values > min_entries_plot)[0]
        print(f'Found {len(df_lumis)} lumisections in this era, of which {len(ids)} pass the selection.')
        if len(ids) == 0: continue
        # never ask for more lumisections than are available
        random_ids = rng.choice(ids, size=min(nplot, len(ids)), replace=False)
        selected_run_numbers = df_lumis['run_number'].values[random_ids]
        selected_ls_numbers = df_lumis['ls_number'].values[random_ids]
    else:
        # use the explicitly selected lumisections
        selected_run_numbers = [run for run, _ in selected_lumis]
        selected_ls_numbers = [ls for _, ls in selected_lumis]

    # load the data
    df = iotools.read_lumisections(files[era], selected_run_numbers, selected_ls_numbers)
    if len(df) == 0:
        print('None of the selected lumisections were found for this era, skipping.')
        continue
    mes, runs, lumis = dftools.get_mes(df, xbinscolumn='x_bin', ybinscolumn='y_bin', runcolumn='run_number', lumicolumn='ls_number')

    # preprocess and predict
    # NOTE(review): unlike in the training cell, zero-occupancy bins are not set to 1 here;
    # presumably intentional so that empty regions contribute to the loss - confirm
    print('Processing...')
    preprocessor = preprocessors[era]
    if preprocessor is not None:
        mes_preprocessed = preprocessor.preprocess(df)
    else:
        # no preprocessor: use the raw monitoring elements, as in the training cell
        mes_preprocessed = mes
    mes_pred = nmfs[era].predict(mes_preprocessed)
    losses = np.square(mes_preprocessed - mes_pred)

    print('Plotting...')
    for idx in range(len(runs)):
        hists = [mes[idx, :, :], mes_preprocessed[idx, :, :], mes_pred[idx, :, :], losses[idx, :, :]]
        fig, axs = plt.subplots(ncols=len(hists), figsize=(24, 6))
        for axidx, (hist, panel_kwargs) in enumerate(zip(hists, panel_plot_kwargs)):
            fig, axs[axidx] = plottools.plot_hist_2d(hist, fig=fig, ax=axs[axidx],
                                                     **common_plot_kwargs, **panel_kwargs)
        plt.subplots_adjust(wspace=0.5)
        title = f'Run {runs[idx]}, LS {lumis[idx]}, layer {layer}'
        axs[0].text(0.01, 1.3, title, fontsize=15, transform=axs[0].transAxes)
        plt.show()
        # close each figure explicitly to avoid accumulating open figures
        plt.close(fig)