In [None]:
# import sys
# import os
# if "google.colab" in sys.modules:
#     os.chdir("/content/qk-spectral-analysis")
# else:
#     os.chdir("qk-spectral-analysis")

# Classification B plots

Generates one box plot per model and per eigenvalue statistic (min, max, mean, std, skew, kurtosis), with the layer number on the x-axis.
Results are saved as `{model_family}/plots/{statistic}/{model_name}_{statistic}_box.png`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import bds
import pickle
import seaborn as sns
from collections import defaultdict
from enum import IntEnum
from scipy.interpolate import make_splrep
from pathlib import Path

# Global plot appearance for every figure produced below.
# The first positional argument of `set_theme` is the plotting *context*:
# "talk" scales fonts and line widths up for presentation-sized figures.
sns.set_theme("talk")
# `set_theme` also applies its default style; replace it with the plain
# "white" axes style.
sns.set_style("white")

In [2]:
class stats(IntEnum):
    """Row index of each statistic inside an ``eigen_values_stats`` array.

    Members map to consecutive integers starting at zero:
    MIN=0, MAX=1, MEAN=2, STD=3, SKEW=4, KURTOSIS=5.
    """

    # Declaration order is the index order, so unpack straight from range().
    MIN, MAX, MEAN, STD, SKEW, KURTOSIS = range(6)

In [None]:
# First layer index to include in the plots, keyed by model name.
# Layers 0 and 1 distort some plots, so every model starts at layer 2 unless
# it is listed explicitly here with a later starting layer.
start = defaultdict(
    lambda: 2,
    {
        'Llama-3.1-70B-Instruct': 5,
        'Llama-3.3-70B-Instruct': 5,
    },
)

In [None]:
base = Path(".")
skip = []  # names of statistics (e.g. "MIN") to leave out of this run


def _layer_values(payload, value):
    """Return ``{layer: values}`` for one statistic of a loaded pickle payload.

    Names of ``stats`` members (case-insensitive) select entry
    ``stats[name]`` from every layer's ``eigen_values_stats`` item and squeeze
    it; any other name is looked up as a top-level key of ``payload``.
    """
    key = value.upper()
    if key in stats.__members__:
        # assumes each per-layer entry is an array whose first axis is indexed
        # by `stats` -- TODO confirm against the code that writes the pickles
        return {layer: rows[stats[key]].squeeze()
                for layer, rows in payload['eigen_values_stats'].items()}
    return payload[value]


def _save_layer_boxplot(data, model_name, value, wide, out_file):
    """Draw one per-layer box plot with a linear spline trend and save it.

    Parameters
    ----------
    data : dict
        Layer key -> array of values, already filtered and sorted by layer.
    model_name, value : str
        Used for the figure title and the y-axis label.
    wide : bool
        Use the wider canvas and hide the x tick labels.
    out_file : Path
        Destination image file.
    """
    n_layers = len(data)
    positions = np.arange(n_layers)

    # Do all numeric work before a figure exists, so a failure here cannot
    # leave an open figure behind.
    # BDS statistic and p-value of the sequence of per-layer medians.
    bds_stat, p_value = bds([np.median(x) for x in data.values()])

    # Degree-1 smoothing spline through the per-layer means.
    means = [x.mean() for x in data.values()]
    f_linear = make_splrep(positions, means, s=20, k=1)  # type: ignore
    x_new = np.linspace(0, n_layers - 1, 500)
    y_new = f_linear(x_new)

    fig, ax = plt.subplots(figsize=(35 if wide else 20, 8), dpi=100)
    try:
        ax.boxplot(list(data.values()), positions=positions,
                   widths=0.5, showmeans=True, showfliers=False)

        if wide:
            # Too many layers for readable tick labels.
            ax.set_xticks([])
        else:
            ax.set_xticks(positions)
            ax.set_xticklabels(list(data.keys()))

        ax.plot(x_new, y_new, linestyle="--")
        ax.set_title(f"{model_name}_{value} (BDS: {bds_stat:.2f} p-value: {p_value:.2e})")
        ax.set_xlabel("Layer")
        ax.set_ylabel(value)

        fig.tight_layout()
        fig.savefig(out_file)
    finally:
        # Always release the figure; many are created in this loop.
        plt.close(fig)


selected = [name for name in stats.__members__ if name not in skip]

# Files are the outer loop so each pickle is deserialised once and reused for
# every statistic; sorting makes the processing order deterministic.
for pkl_path in sorted(base.glob("*/data/*.pkl")):
    family_name = pkl_path.parent.parent.name
    model_name = pkl_path.stem

    # Larger width for Llama 3.1 405B
    wide = "405" in str(pkl_path)

    # NOTE(security): pickle can execute arbitrary code while loading; only
    # run this on data files you generated yourself or otherwise trust.
    with open(pkl_path, "rb") as file:
        payload = dict(pickle.load(file))

    first_layer = start[model_name]

    for value in selected:
        save_path = base / family_name / "plots" / value
        save_path.mkdir(exist_ok=True, parents=True)

        # Drop the leading layers, then order the remaining ones by layer key.
        kept = ((k, v) for k, v in _layer_values(payload, value).items()
                if k >= first_layer)
        data = dict(sorted(kept, key=lambda item: item[0]))

        if not data:
            # Nothing left to plot; report it instead of aborting the run.
            print(f"Skipping {model_name}_{value}: no layers >= {first_layer}")
            continue

        _save_layer_boxplot(data, model_name, value, wide,
                            save_path / f"{model_name}_{value}_box.png")