In [2]:
import pickle
import warnings
from pathlib import Path
from random import choices
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
warnings.filterwarnings("ignore") # Ignore all warnings

In [None]:
def get_Kronecker_Factors(N:int, layer_index:int, max_steps:int, choices_list:Optional[List[int]] = None) -> Tuple:
    """
    Retrieves the Kronecker factors of one layer at several training steps from pickled files.

    For every step ``i`` the files ``H_bar_resnet/H_bar_{i}.pkl`` and ``S_resnet/S_{i}.pkl``
    are unpickled; each is expected to hold a dictionary whose values are indexed by position.

    Parameters:
        N (int): Number of steps to sample when ``choices_list`` is not given.
        layer_index (int): Position of the layer inside each unpickled dictionary.
        max_steps (int): Largest step that may be sampled. Steps are drawn with replacement
            from 0, 100, 200, ... up to and including ``max_steps``. Ignored when
            ``choices_list`` is given.
        choices_list (List[int], optional): Explicit steps to load. If None, ``N`` steps are sampled.

    Returns:
        tuple: ``(H_bar, S, choices_list, layer_name)`` where ``H_bar`` and ``S`` are lists with one
        entry per step (the result of ``.cpu().numpy()`` on the stored value), ``choices_list`` is
        the list of steps actually used, and ``layer_name`` is the class name of the key found at
        ``layer_index`` in the S dictionary (``None`` when ``choices_list`` is empty).

    Raises:
        FileNotFoundError: If either the ``H_bar_resnet`` or the ``S_resnet`` directory is missing,
            or if the pickle file of a requested step does not exist.
    """
    H_bar_dir = Path("H_bar_resnet").expanduser()
    S_dir = Path("S_resnet").expanduser()
    # `exists` has to be called: the bare bound method is always truthy, which
    # previously meant a missing S directory was never detected here.
    if not H_bar_dir.exists() or not S_dir.exists():
        raise FileNotFoundError('H_bar_resnet and S_resnet directories not found')

    if choices_list is None:
        # `max_steps + 1` keeps max_steps itself eligible without ever yielding
        # a step beyond it when max_steps is not a multiple of 100.
        choices_list = choices(range(0, max_steps + 1, 100), k=N)

    H_bar, S = [], []
    # Stays None when there is no step to load, so the return below is always defined.
    layer_name = None
    for step in choices_list:
        # SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
        with open(H_bar_dir / f'H_bar_{step}.pkl', 'rb') as f:
            dict_H_bar = pickle.load(f)
        H_bar_l = list(dict_H_bar.values())[layer_index]
        # presumably a torch tensor (it exposes .cpu().numpy()) -- TODO confirm
        H_bar.append(H_bar_l.cpu().numpy())

        with open(S_dir / f'S_{step}.pkl', 'rb') as f:
            dict_S = pickle.load(f)
        S_l = list(dict_S.values())[layer_index]
        # The dictionary key (presumably the layer object itself) provides the name.
        layer_name = list(dict_S.keys())[layer_index].__class__.__name__
        S.append(S_l.cpu().numpy())
    return H_bar, S, choices_list, layer_name


def plot_Kronecker_Factors(H_bar:List, S:List, choices_list:List, N:int, layer_name:str, 
                           layer_index:int, verbose:bool = False):
    """
    Shows the sampled H_bar and S matrices of one layer as grayscale images in a single row.

    Parameters:
        H_bar (List): H_bar matrices, one per sampled step.
        S (List): S matrices, one per sampled step.
        choices_list (List): Steps at which the matrices were sampled; used in the panel titles.
        N (int): Number of matrices per factor. The figure gets 2*N panels: the first N hold
            H_bar, the following N hold S.
        layer_name (str): Name of the layer, used in the figure title.
        layer_index (int): Zero-based index of the layer; shown one-based in the figure title.
        verbose (bool, optional): If True, adds a figure-level title naming the layer.

    All H_bar panels share one colour range (the min/max over every H_bar matrix) and all S panels
    share another, so panels of the same factor are directly comparable. One colorbar is added per
    factor. The function returns nothing; it ends by calling ``plt.show()``.
    """
    def shared_limits(matrices):
        # Flatten everything into one vector so a single min/max covers all matrices.
        flat = np.concatenate([m.ravel() for m in matrices])
        return flat.min(), flat.max()

    def draw(matrix, ax, vmin, vmax):
        # Render one matrix with fixed colour limits, square cells and no grid lines.
        image = ax.imshow(matrix, cmap="gray", interpolation='nearest', vmin=vmin, vmax=vmax)
        ax.set_aspect('equal')
        ax.grid(False)
        return image

    h_low, h_high = shared_limits(H_bar)
    s_low, s_high = shared_limits(S)

    fig, axs = plt.subplots(1, 2*N, figsize=(15, 5*N), constrained_layout=True)

    # Left half of the row: H_bar matrices
    for col, h_matrix in enumerate(H_bar):
        panel = axs[col]
        panel.set_title(f'Matrix $\\mathcal{{H}}$ | Step {choices_list[col]}', fontsize=16, fontweight='bold')
        im_A = draw(h_matrix, panel, h_low, h_high)

    # Right half of the row: S matrices, offset by N panels
    for col, s_matrix in enumerate(S):
        panel = axs[col+N]
        panel.set_title(f'Matrix $\\mathcal{{S}}$ | Step {choices_list[col]}', fontsize=16, fontweight='bold')
        im_G = draw(s_matrix, panel, s_low, s_high)

    if verbose:
        plt.suptitle(f'Matrices $\\mathcal{{H}}$ and $\\mathcal{{S}}$ of {layer_name} layer at position {layer_index+1} of Resnet18', fontsize=20, fontweight='bold')

    # One colorbar per factor, taken from the last image drawn for that factor
    # (valid for all of them because the colour limits are shared).
    fig.colorbar(im_A, orientation='vertical', shrink=0.28)
    fig.colorbar(im_G, orientation='vertical', shrink=0.28)
    # NOTE(review): the figure was created with constrained_layout=True, which matplotlib
    # documents as incompatible with subplots_adjust -- depending on the matplotlib version
    # this call is presumably either skipped or switches constrained layout off; confirm
    # which spacing is actually intended.
    plt.subplots_adjust(wspace=0.1, hspace=0.3)
    plt.show()

In [None]:
# ---- Settings from Paper (begin) ----
N = 2  # Number of matrices shown per factor (N for H_bar and N for S); equals len(choice_list)
layer_index = 40 # Position of the layer inside the pickled dictionaries; 36 or 40
max_steps = 9800  # Upper bound for random step sampling; not used here because explicit steps are passed
choice_list = [5200, 9800]  # Steps to load, i.e. H_bar_<step>.pkl / S_<step>.pkl
# ---- Settings from Paper (end) ----
# Load the factors of the chosen layer at the chosen steps, then plot them in one row.
# verbose is left at its default (False), so no figure-level title is drawn.
H_bar, S, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list)
plot_Kronecker_Factors(H_bar, S, choices_list, N, layer_name, layer_index)