In [None]:
# import os
# import sys
# if "google.colab" in sys.modules:
#     os.chdir("/content/qk-spectral-analysis")
# else:
#     os.chdir("qk-spectral-analysis")

# Classification A plots

Computes the clusters, UMAP features, and gamma distribution parameters for a specified model family.
The generated plots are saved in `{model_family}/ESDs`.

To recreate the plots from the post, use n_bins=24 and n_clusters=6.

In [None]:
# Number of KMeans clusters; also drives how many per-class subplots are drawn.
n_clusters = 6
# Number of histogram bins used for every normalized eigenvalue spectrum.
n_bins = 24
# Name of the model-family directory; data is read from `{model_family}/data`
# and figures are written to `{model_family}/ESDs`.
model_family = "Qwen3"

In [None]:
import pickle 
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import umap
import seaborn as sns
from scipy.stats import gamma
from scipy.optimize import curve_fit
from pathlib import Path

# Global plot styling for every figure in this notebook.
# The first positional argument of `set_theme` is the plotting context, so this
# selects the "talk" context.
sns.set_theme("talk")
# Override the style chosen by `set_theme` with the plain "white" style.
sns.set_style("white")

In [None]:
# Directory scanned for the input `*.pkl` files.
data_location = Path(f"{model_family}/data")
# Directory that receives one PNG per input pickle.
save_loc = Path(f"{model_family}/ESDs")
# `parents` is left at its default (False), so the `{model_family}` directory
# itself must already exist; an existing output directory is reused as-is.
save_loc.mkdir(exist_ok=True)

In [None]:
def jensen_shannon_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Return the Jensen-Shannon divergence between two non-negative weight vectors.

    Both inputs are converted to float64 arrays and rescaled so that each sums
    to (approximately) one before the divergence is evaluated, so raw
    histogram counts or density values can be passed directly.

    Args:
        p: Weights of the first distribution.
        q: Weights of the second distribution (combined element-wise with ``p``).
        eps: Small constant added to the normalising sums and inside the
            logarithm so that zero-valued entries do not produce division by
            zero or ``log(0)``.

    Returns:
        The divergence as a Python float, computed with the natural logarithm.
    """

    def _as_distribution(weights):
        # Cast to float64 and rescale to unit mass (eps guards an all-zero input).
        arr = np.asarray(weights, dtype=np.float64)
        return arr / (arr.sum() + eps)

    def _kl_to(dist, reference):
        # KL(dist || reference) with eps-smoothed ratio inside the log.
        return np.sum(dist * np.log((dist + eps) / (reference + eps)))

    p_dist = _as_distribution(p)
    q_dist = _as_distribution(q)

    # Equal-weight mixture of the two distributions.
    mixture = 0.5 * (p_dist + q_dist)

    # JSD is the average of the two KL divergences towards the mixture.
    return float(0.5 * (_kl_to(p_dist, mixture) + _kl_to(q_dist, mixture)))


In [None]:
def gamma_pdf(x, a, scale):
    """Gamma density with shape ``a`` and scale ``scale``; ``loc`` is fixed to 0."""
    return gamma.pdf(x, a, loc=0, scale=scale)


# Produce one summary figure per pickle file found in `data_location`.
# Sorting makes the processing order deterministic across runs and platforms.
for eignval_data_loc in sorted(data_location.glob("*.pkl")):

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only run this on pickles you generated yourself.
    with open(eignval_data_loc, "rb") as f:
        data = pickle.load(f)["eigen_values"]
    # Presumably `data` maps a key (e.g. a layer) to a 2-D array of spectra;
    # all values are stacked along the first axis. TODO confirm against the
    # script that writes these pickles.
    eigenvals = np.concatenate(list(data.values()), axis=0)

    # Normalize each row by its own mean (mean taken over the last axis).
    normalized_eigen_vals = eigenvals / eigenvals.mean(axis=-1)[..., np.newaxis]

    # Shared bin edges from 0 to the 0.96 quantile of all normalized values;
    # anything above that quantile falls outside the histogram range.
    max_bin_edge = np.quantile(normalized_eigen_vals, 0.96)
    bin_edges = np.linspace(0.0, max_bin_edge, n_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # One density-normalized histogram per row.
    hists = np.stack(
        [np.histogram(x, bin_edges, density=True)[0] for x in normalized_eigen_vals],
        axis=0,
    )

    # Cluster the histograms. `kmeans` holds the per-row cluster labels
    # returned by fit_predict, not the fitted estimator.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit_predict(hists)

    # 2-D UMAP embedding of the histograms. The seed is fixed so the scatter
    # panel is reproducible between runs, matching the seeded KMeans above.
    reducer = umap.UMAP(n_components=2, n_neighbors=10, min_dist=0.1, random_state=0)
    embedding = reducer.fit_transform(hists)

    # Per-cluster mean histogram and per-bin standard deviation.
    representative_hists = [hists[kmeans == x].mean(axis=0) for x in range(n_clusters)]
    hist_std = [hists[kmeans == x].std(axis=0) for x in range(n_clusters)]

    # Grid layout: 5 columns; the first 3 rows leave 6 free cells next to the
    # UMAP panel, and every extra row adds 5 more cells for further clusters.
    n_rows = int(3 + max(np.ceil((n_clusters - 6) / 5), 0))
    fig, axs = plt.subplots(
        n_rows, 5,
        figsize=(15, 2.8 * n_rows),
        gridspec_kw={'width_ratios': [0.65, 0.65, 0.65, 1, 1]},
    )

    # Remove the 3x3 axes in the top-left corner to make space for the UMAP panel.
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:3, :3].ravel():
        ax.remove()

    ax_big = fig.add_subplot(gs[:3, :3])
    ax_big.scatter(embedding[:, 0], embedding[:, 1], c=kmeans, s=10, cmap="Set2", alpha=0.75)
    ax_big.set_xticks([])
    ax_big.set_yticks([])
    ax_big.set_title("UMAP features")

    # Keep only the axes that were not removed above. The skipped flat indices
    # are columns 0-2 of rows 0-2 in the 5-column grid.
    axs = axs.ravel()
    skip = [0, 1, 2, 5, 6, 7, 10, 11, 12]
    axs = [x for i, x in enumerate(axs) if i not in skip]

    # Hide spare grid cells so they do not render as empty frames when
    # n_clusters does not exactly fill the grid.
    for unused_ax in axs[n_clusters:]:
        unused_ax.set_axis_off()

    # Plot each cluster in its own subplot.
    for i in range(n_clusters):

        axs[i].bar(range(n_bins), representative_hists[i])
        axs[i].set_xticks([])
        axs[i].set_yticks([])
        axs[i].errorbar(range(n_bins), representative_hists[i], yerr=hist_std[i],
                        fmt='none', capsize=2.0, color="black", alpha=0.7, elinewidth=2)

        # Fit a gamma PDF (loc fixed to 0) to the cluster's mean histogram.
        # The scalar sigma weights every bin equally, so it does not change
        # the fitted parameters; it only scales the covariance, which is
        # discarded here.
        popt, _ = curve_fit(gamma_pdf, bin_centers, representative_hists[i],
                            p0=(3.0, 0.5), sigma=representative_hists[i].std(), absolute_sigma=True)
        fit_alpha, fit_theta = popt
        pdf_vals = gamma_pdf(bin_centers, fit_alpha, fit_theta)

        # Plot the fitted distribution on the same bin-index x axis as the bars.
        axs[i].plot(range(n_bins), pdf_vals, color="red", linestyle="--", lw=0.8)

        # Anchor the y axis at 0 while keeping the automatically chosen top.
        top = axs[i].get_ylim()[1]
        axs[i].set_ylim((0, top))

        # Divergence between the mean histogram and the fitted PDF sampled at
        # the bin centers (both are renormalized inside the function).
        jsd = jensen_shannon_divergence(representative_hists[i], pdf_vals)

        axs[i].text(
            0.6, 0.9,
            f"α : {fit_alpha:.2f}\nθ : {fit_theta:.2f}",
            transform=axs[i].transAxes,
            ha='left', va='top',
            fontsize=16,
            bbox=dict(facecolor='white', alpha=0.6, edgecolor='none'),  # optional background
        )
        axs[i].set_title(f"Class_{i} (JSD:{jsd:.3f})")

    # Title and output file are named after the pickle file without its extension.
    name = eignval_data_loc.stem
    fig.suptitle(name)

    # Operate on the figure explicitly instead of relying on pyplot's
    # "current figure" state, and close it to free memory between files.
    fig.tight_layout()
    fig.savefig(save_loc / f"{name}.png")
    plt.close(fig)