In [None]:
# imports

import os
import sys
import json
import joblib
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Make the directory three levels above the working directory importable, so the
# project-local `tools`, `plotting` and `studies` imports below resolve.
# NOTE(review): assumes the kernel's working directory is this notebook's own
# directory (Jupyter's default) and that three levels up is the repository root.
# The membership guard keeps re-runs of this cell from appending duplicate entries.
thisdir = os.getcwd()
topdir = os.path.abspath(os.path.join(thisdir, '../../../'))
if topdir not in sys.path:
    sys.path.append(topdir)

import tools.iotools as iotools
import tools.dftools as dftools
import plotting.plottools as plottools
from studies.clusters_2024.plotting.plot_cluster_occupancy import plot_cluster_occupancy
from studies.clusters_2024.nmf.modeldefs.nmf2d import NMF2D

In [None]:
# define eras and layers to plot
# (2024 data-taking eras; only uncommented entries are processed)

eras = [
    # 'A-v1', 'B-v1',
    'C-v1',
    # 'D-v1', 'E-v1', 'E-v2', 'F-v1', 'G-v1',
    # 'H-v1', 'I-v1', 'I-v2', 'J-v1',
]

# pixel layers BPix1 .. BPix4
layers = [f'BPix{n}' for n in range(1, 5)]

In [None]:
# load models
#
# Builds the expected model path for every (era, layer) pair, fails early with a
# complete list of any missing files, then loads each model into nmfs[era][layer].

# set model directory
modeldir = 'output_test'

# set path
# (the layer name is upper-cased in the file name, e.g. 'BPix1' -> 'BPIX1')
nmf_files = {}
for era in eras:
    nmf_files[era] = {}
    for layer in layers:
        nmf_files[era][layer] = os.path.join(modeldir, f'nmf_model_{layer.upper()}_{era}.pkl')

# existence check
# Collect missing files across ALL eras first, then raise once, so the error
# lists everything that still has to be produced (not just the first era's gaps).
missing = []
for era in eras:
    for layer, f in nmf_files[era].items():
        if not os.path.exists(f):
            missing.append(f)
if missing:
    # FileNotFoundError is a subclass of Exception, so existing handlers still catch it.
    raise FileNotFoundError(f'The following files do not exist: {missing}')

# load models
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
nmfs = {}
for era in eras:
    nmfs[era] = {}
    for layer in layers:
        nmf_file = nmf_files[era][layer]
        nmf = joblib.load(nmf_file)
        nmfs[era][layer] = nmf

In [None]:
# plot model components
#
# One figure per NMF component, for every (era, layer) model loaded above.
# Each figure is shown and then explicitly closed inside the loop, so open
# figures do not pile up (previously only the last figure per layer was closed).

for era in eras:
    for layer in layers:

        # NOTE(review): assumes `.components` is an indexable sequence whose items
        # are the per-component maps accepted by plot_cluster_occupancy — confirm against NMF2D.
        C = nmfs[era][layer].components
        # Same label for every component of this model, so build it once.
        conditions = f'(2024-{era} {layer} NMF model)'
        for idx, component in enumerate(C):
            fig, ax = plot_cluster_occupancy(component,
                   xaxtitlesize=12, yaxtitlesize=12,
                   ticklabelsize=12, colorticklabelsize=12,
                   docolorbar=True, caxtitle='Number of clusters\n(normalized)',
                   caxrange=(1e-6, 2),
                   caxtitlesize=15, caxtitleoffset=35)
            # Component numbering shown to the reader is 1-based.
            title = f'NMF component {idx+1}'
            ax.text(0.01, 1.05, title, fontsize=15, transform=ax.transAxes)
            ax.text(1., 1.05, conditions, fontsize=12, transform=ax.transAxes, ha='right')
            plt.show()
            plt.close(fig)