In [None]:
# imports

import os
import sys
import json
import joblib
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Make the directory three levels above the current working directory
# importable; the project-local imports below are resolved through it.
thisdir = os.getcwd()
topdir = os.path.abspath(os.path.join(thisdir, '../../../'))
# guard against duplicate sys.path entries when this cell is re-run
# in the same kernel
if topdir not in sys.path:
    sys.path.append(topdir)

import tools.iotools as iotools
import tools.dftools as dftools
import plotting.plottools as plottools
from studies.pixel_clusters_2024.plotting.plot_cluster_occupancy import plot_cluster_occupancy
from studies.pixel_clusters_2024.nmf.modeldefs.nmf2d import NMF2D

In [None]:
# Selection of eras and pixel barrel layers used by the cells below.

# Only the era tags listed here are processed.
# Other era tags that can be enabled by adding them to the list:
#   'A-v1', 'B-v1', 'D-v1', 'E-v1', 'E-v2', 'F-v1',
#   'G-v1', 'H-v1', 'I-v1', 'I-v2', 'J-v1'
eras = ['C-v1']

# layer names 'BPix1' ... 'BPix4'
layers = [f'BPix{number}' for number in range(1, 5)]

In [None]:
# load models
# (builds `nmf_files[era][layer]` with the model file paths and
# `nmfs[era][layer]` with the loaded model objects)

# set model directory
modeldir = 'output_20250714_consolidation/models'

# set path
# (the file name uses the upper-cased layer name)
nmf_files = {}
for era in eras:
    nmf_files[era] = {}
    for layer in layers:
        nmf_files[era][layer] = os.path.join(modeldir, f'nmf_model_{layer.upper()}_{era}.pkl')

# existence check
# (the check is done only after scanning all eras, so that the error message
# lists every missing file instead of stopping at the first incomplete era)
missing = []
for era in eras:
    for layer, f in nmf_files[era].items():
        if not os.path.exists(f): missing.append(f)
if len(missing) > 0:
    # FileNotFoundError is a subclass of Exception,
    # so existing `except Exception` handling keeps working
    raise FileNotFoundError(f'The following files do not exist: {missing}')

# load models
# NOTE(review): joblib.load is pickle-based and can execute arbitrary code
# while loading; only point `modeldir` at trusted model files.
nmfs = {}
for era in eras:
    nmfs[era] = {}
    for layer in layers:
        nmf_file = nmf_files[era][layer]
        nmf = joblib.load(nmf_file)
        nmfs[era][layer] = nmf

In [None]:
# plot model components
# (one separate figure per NMF component of the selected layer)

for era in eras:
    for layer in layers:
        if layer!='BPix1': continue # only need plots of BPix1 for now

        C = nmfs[era][layer].components

        # change the order for more convenient description
        #ids = [0, 2, 4, 1, 3]
        #C = C[ids]

        # keep track of all figures created for this layer,
        # so that each of them can be closed explicitly after showing
        figs = []

        # `counter` is kept separate from `idx` so that the titles stay
        # consecutively numbered when components are skipped (see below)
        counter = 0
        for idx in range(len(C)):
            #if idx==1 or idx==3: continue # pick the most interesting components
            counter += 1
            fig, ax = plot_cluster_occupancy(C[idx],
                   xaxtitlesize=12, yaxtitlesize=12,
                   ticklabelsize=12, colorticklabelsize=12,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6, 2),
                   caxtitlesize=15, caxtitleoffset=35)
            figs.append(fig)
            title = f'NMF component {counter}'
            ax.text(0.01, 1.05, title, fontsize=15, transform=ax.transAxes)
            #conditions = f'(2024-{era} {layer} NMF model)'
            conditions = f'({layer} NMF model)'
            ax.text(1., 1.05, conditions, fontsize=12, transform=ax.transAxes, ha='right')
        plt.show()
        # a bare plt.close() only closes the current (i.e. last) figure;
        # close every figure of this layer to avoid accumulating open figures
        for fig in figs:
            plt.close(fig)

In [None]:
# plot model components
# (prettified version for DP note)
# All components of the selected layer are drawn as panels of a single figure
# (3 panels per row), with a common header above the top row.

# number of panels per row
ncols = 3

# keyword arguments shared by all component panels
panel_kwargs = dict(
    xaxtitlesize=15, yaxtitlesize=15,
    ticklabelsize=12, colorticklabelsize=12,
    docolorbar=True, caxtitle='Number of clusters\n(normalized)',
    caxrange=(1e-6, 2),
    caxtitlesize=15, caxtitleoffset=35)

for era in eras:
    for layer in layers:
        if layer!='BPix1': continue # only need plots of BPix1 for now

        # work on a copy: the clamping below assigns in place, and
        # `.components` may hand out the model's own array (not verifiable
        # from here), in which case the stored model would be modified
        # and re-running cells would give different results
        C = np.array(nmfs[era][layer].components, copy=True)

        # set zeros to small values
        # (in order to plot them as the bottom of the color scale
        # rather than white)
        C[C < 1e-6] = 1e-6

        # re-insert empty modules in the middle cross
        # (2 zero-filled slices along axis 1 and 8 along axis 2, inserted at
        # the middle index; done after the clamping above, so that these
        # entries stay exactly zero)
        middle_1 = C.shape[1] // 2
        middle_2 = C.shape[2] // 2
        C = np.insert(C, [middle_1]*2, 0, axis=1)
        C = np.insert(C, [middle_2]*8, 0, axis=2)

        # initialize figure
        # (grid size follows the number of components instead of assuming
        # exactly 5; for 5 components this gives 2 rows and figsize (18, 12))
        ncomponents = len(C)
        nrows = max(1, (ncomponents + ncols - 1) // ncols)
        figheight = 6 * nrows
        fig, axs = plt.subplots(ncols=ncols, nrows=nrows,
                                figsize=(6 * ncols, figheight), squeeze=False)

        # draw one panel per component, filling the grid row by row
        for idx in range(ncomponents):
            row, col = divmod(idx, ncols)
            fig, axs[row, col] = plot_cluster_occupancy(C[idx], fig=fig, ax=axs[row, col],
                   **panel_kwargs)
            title = f'NMF component {idx + 1}'
            axs[row, col].text(0.5, 1.05, title, fontsize=15,
                               transform=axs[row, col].transAxes, ha='center')

        # blank out the unused panels
        # (hidden rather than deleted, since the top-row panels also serve
        # as anchors for the header text below)
        for idx in range(ncomponents, nrows * ncols):
            row, col = divmod(idx, ncols)
            axs[row, col].set_axis_off()

        # plot aesthetics
        # (spacing values were chosen for the BPix1 figure with 2 rows)
        if str(layer)=='BPix1':
            plt.subplots_adjust(hspace=-0.7)
            plt.subplots_adjust(wspace=0.4)
        title = r'$\bf{CMS}$ ' + r'$\it{Preliminary}$'
        year = era[:4]
        if not year.startswith('202'): year = '2024' # older convention
        conditions = f'{year} (13.6 TeV)'
        if str(layer)=='BPix1':
            axs[0, 0].text(0.01, 1.3, title, fontsize=15, transform=axs[0, 0].transAxes)
            axs[0, 1].text(0.5, 1.3, layer, fontsize=15, ha='center', transform=axs[0,1].transAxes)
            axs[0, 2].text(0.99, 1.3, conditions, fontsize=15, ha='right', transform=axs[0,2].transAxes)

        plt.show()
        # close this specific figure rather than whichever one is current
        plt.close(fig)