## C-H commutativity

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt 
from scipy import stats
import os
from matplotlib.colors import LinearSegmentedColormap, SymLogNorm 


def set_publication_style():
    """Install the publication look (serif fonts, 300 dpi, thick inward ticks).

    The settings are written into matplotlib's global ``plt.rcParams``, so
    they affect every figure created after this call.
    """
    font_settings = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 20,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 12,
    }
    resolution_settings = {
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    stroke_and_tick_settings = {
        'axes.linewidth': 2,
        'lines.linewidth': 2,
        'xtick.major.width': 2,
        'ytick.major.width': 2,
        'xtick.direction': 'in',
        'ytick.direction': 'in',
    }
    plt.rcParams.update(
        {**font_settings, **resolution_settings, **stroke_and_tick_settings}
    )
# Apply the publication style globally so every figure created below uses it.
set_publication_style()
def extract_diagonals(A, B, slice_range=None):
    """Build diagonal matrices from selected entries of the flipped diagonals.

    Each input is flipped along both axes before its diagonal is read, so for
    a square matrix the diagonal is taken from the bottom-right entry upwards.
    The entries picked by ``slice_range`` are then embedded into a new
    diagonal matrix.

    Args:
        A, B: 2-D tensors.
        slice_range: Index tensor applied to each flipped diagonal. Defaults
            to ``arange(min(A.shape[0], B.shape[0]))``.

    Returns:
        Tuple ``(A_dense, B_dense)`` of diagonal matrices.
    """
    def _flipped_diagonal_matrix(matrix, picks):
        mirrored = torch.flip(matrix, dims=[0, 1])
        return torch.diag_embed(torch.diag(mirrored)[picks])

    picks = slice_range
    if picks is None:
        picks = torch.arange(min(A.shape[0], B.shape[0]))
    return _flipped_diagonal_matrix(A, picks), _flipped_diagonal_matrix(B, picks)

def compute_loglog_fit(A_dense, B_dense):
    """Fit ``log10(diag(A))`` against ``log10(diag(B))`` by linear regression.

    Only positions where both diagonals are strictly positive are used, since
    the logarithm is undefined elsewhere.

    Args:
        A_dense, B_dense: 2-D tensors whose diagonals are compared.

    Returns:
        ``(log_b, log_a, slope, r_squared)`` where ``log_b`` / ``log_a`` are
        the NumPy arrays that were fitted (x and y respectively), or
        ``(None, None, None, None)`` when no line can be fitted.
    """
    diag_a = A_dense.diag().detach().cpu().numpy()
    diag_b = B_dense.diag().detach().cpu().numpy()
    mask = (diag_a > 0) & (diag_b > 0)
    # A regression line needs at least two points; a single point would make
    # linregress fail (or yield NaN) instead of signalling "no fit".
    if mask.sum() < 2:
        return None, None, None, None
    log_a, log_b = np.log10(diag_a[mask]), np.log10(diag_b[mask])
    # linregress cannot handle an x sample with zero spread.
    if log_b.min() == log_b.max():
        return None, None, None, None
    slope, _, r_value, _, _ = stats.linregress(log_b, log_a)
    return log_b, log_a, slope, r_value**2

def compute_spearman_rank(A_dense, B_dense):
    """Spearman rank correlation between the diagonals of two matrices.

    Only positions where both diagonal entries are strictly positive are
    compared. Returns ``0.0`` when fewer than two such positions exist.
    """
    values_a, values_b = (
        matrix.diag().detach().cpu().numpy() for matrix in (A_dense, B_dense)
    )
    both_positive = np.logical_and(values_a > 0, values_b > 0)
    if both_positive.sum() < 2:
        return 0.0
    rho, _pvalue = stats.spearmanr(values_a[both_positive], values_b[both_positive])
    return rho



def compute_commutativity_random_baseline(A, B):
    """Commutativity error of ``A`` against a randomly rotated copy of ``B``.

    ``B`` is conjugated by the orthogonal factor ``Q`` of a QR decomposition
    of a Gaussian random matrix (``Q @ B @ Q.T``), and the result is passed
    to ``compute_commutativity`` together with ``A``. 1-D inputs are treated
    as diagonals. The result is random: a fresh matrix is drawn on each call.
    """
    lhs = torch.diag(A) if A.ndim == 1 else A
    rhs = torch.diag(B) if B.ndim == 1 else B
    lhs = lhs.float().cpu()
    rhs = rhs.float().cpu()

    dim = rhs.shape[0]
    gaussian = torch.randn(dim, dim, device=rhs.device)
    rotation, _ = torch.linalg.qr(gaussian)
    rotated_rhs = rotation @ rhs @ rotation.T

    return compute_commutativity(lhs, rotated_rhs)

def compute_commutativity(A, B):
    """Relative commutator norm ``||AB - BA||_F / ||AB||_F``.

    1-D inputs are treated as diagonals. Both operands are cast to float32
    on the CPU.

    Returns:
        The ratio as a Python float, ``0.0`` when ``||AB||_F`` is zero, or
        ``None`` when the two matrices have different shapes.
    """
    lhs = (torch.diag(A) if A.ndim == 1 else A).float().cpu()
    rhs = (torch.diag(B) if B.ndim == 1 else B).float().cpu()
    if lhs.shape != rhs.shape:
        return None

    forward = lhs @ rhs
    backward = rhs @ lhs
    scale = torch.norm(forward, p='fro')
    if scale == 0:
        return 0.0
    mismatch = torch.norm(forward - backward, p='fro')
    return (mismatch / scale).item()

def compute_eigen_alignment(target_A, basis_B):
    """Fraction of ``A``'s squared Frobenius norm on the diagonal in ``B``'s eigenbasis.

    ``A`` is rotated into the eigenvector basis of ``B`` (``V.T @ A @ V``) and
    the squared norm of the diagonal is divided by the squared Frobenius norm
    of the whole rotated matrix. 1-D inputs are treated as diagonals.

    Returns:
        The ratio as a Python float, ``0.0`` when the rotated matrix is all
        zeros, or ``None`` when the eigendecomposition of ``B`` fails.
    """
    if target_A.ndim == 1: target_A = torch.diag(target_A)
    if basis_B.ndim == 1: basis_B = torch.diag(basis_B)
    A, B = target_A.float().cpu(), basis_B.float().cpu()
    try:
        _, V_B = torch.linalg.eigh(B)
    except RuntimeError:
        # torch's linear-algebra failures (non-square input, no convergence)
        # are RuntimeError subclasses; anything else should propagate rather
        # than be silently turned into "no result".
        return None
    A_rotated = V_B.T @ A @ V_B
    diag_energy = torch.norm(torch.diag(A_rotated), p=2)**2
    total_energy = torch.norm(A_rotated, p='fro')**2
    if total_energy == 0: return 0.0
    return (diag_energy / total_energy).item()

# def compute_split_commutativity(target_A, basis_B, top_k=100):
#     if target_A.ndim == 1: target_A = torch.diag(target_A)
#     if basis_B.ndim == 1: basis_B = torch.diag(basis_B)
#     A, B = target_A.float().cpu(), basis_B.float().cpu()
#     try:
#         L_B, V_B = torch.linalg.eigh(B)
#         V_top = V_B[:, -top_k:]  
#         P_top = V_top @ V_top.T
#         I = torch.eye(B.shape[0], device=B.device)
#         P_bulk = I - P_top
#         Comm = A @ B - B @ A
#         AB = A @ B 
#         num_top = torch.norm(P_top @ Comm @ P_top, p='fro')
#         den_top = torch.norm(P_top @ AB @ P_top, p='fro')
#         err_top = (num_top / (den_top )).item()
#         num_bulk = torch.norm(P_bulk @ Comm @ P_bulk, p='fro')
#         den_bulk = torch.norm(P_bulk @ AB @ P_bulk, p='fro')
#         err_bulk = (num_bulk / (den_bulk )).item()
#         return err_top, err_bulk
#     except Exception as e:
#         return None, None

def compute_split_commutativity(target_A, basis_B, k=100, m=1500):
    """Relative commutator norm restricted to two eigen-windows of ``B``.

    The eigenvectors of ``B`` come from ``torch.linalg.eigh``. The "top"
    window is the last ``k`` eigenvector columns and the "bulk" window is the
    columns from ``-m`` up to (excluding) ``-k``. For each window with
    projector ``P`` the value ``||P (AB - BA) P||_F / ||P AB P||_F`` is
    computed.

    Args:
        target_A, basis_B: Matrices (1-D inputs are treated as diagonals).
        k: Size of the top window. If it is not in ``[1, N)`` it is replaced
            by ``max(1, N // 10)``.
        m: Outer edge of the bulk window, clamped to ``N``; if it does not
            exceed ``k`` it becomes ``k + 100``.

    Returns:
        ``(err_top, err_bulk)`` as Python floats, or ``(None, None)`` if any
        step raises. A window whose denominator is zero (for instance an
        empty bulk window) yields ``nan`` or ``inf`` rather than an error.
    """
    if target_A.ndim == 1: target_A = torch.diag(target_A)
    if basis_B.ndim == 1: basis_B = torch.diag(basis_B)
    A, B = target_A.float().cpu(), basis_B.float().cpu()

    try:
        _, V_B = torch.linalg.eigh(B)
        N = V_B.shape[0]

        # k must stay >= 1: with k == 0 the slice ``-k:`` is ``0:`` and would
        # select every eigenvector, while the bulk slice would be empty.
        if k < 1 or k >= N: k = max(1, N // 10)
        if m > N: m = N
        if m <= k: m = k + 100

        V_top = V_B[:, -k:]
        V_bulk = V_B[:, -m:-k]

        P_top = V_top @ V_top.T
        P_bulk = V_bulk @ V_bulk.T

        Comm = A @ B - B @ A
        AB = A @ B

        num_top = torch.norm(P_top @ Comm @ P_top, p='fro')
        den_top = torch.norm(P_top @ AB @ P_top, p='fro')
        err_top = (num_top / den_top).item()

        num_bulk = torch.norm(P_bulk @ Comm @ P_bulk, p='fro')
        den_bulk = torch.norm(P_bulk @ AB @ P_bulk, p='fro')
        err_bulk = (num_bulk / den_bulk).item()

        return err_top, err_bulk

    except Exception as e:
        # Best effort: report the problem and let the caller skip this metric.
        print(f"Error in windowed commutativity: {e}")
        return None, None



# def compute_scale_invariant_stats(target_A, basis_B):
#     if target_A.ndim == 1: target_A = torch.diag(target_A)
#     if basis_B.ndim == 1: basis_B = torch.diag(basis_B)
#     A, B = target_A.float().cpu(), basis_B.float().cpu()
#     try:
#         L_B, V_B = torch.linalg.eigh(B)
#         # Sort in descending order
#         idx_flip = torch.arange(L_B.shape[0] - 1, -1, -1)
#         V_B = V_B[:, idx_flip]
#         L_B = L_B[idx_flip]
        
#         # K: Raw projection
#         K = V_B.T @ A @ V_B
        
#         # R: Normalized correlation (signed)
#         diag_val = torch.abs(torch.diagonal(K))
#         norm_factor = torch.sqrt(torch.outer(diag_val, diag_val))
#         R = K / norm_factor 
#         R_abs = torch.abs(R) 

#         # Stats based on absolute value off-diagonal
#         N = R.shape[0]
#         mask_off = ~torch.eye(N, dtype=torch.bool)
#         off_diag_elements = R_abs[mask_off]
#         mean_coupling = torch.mean(off_diag_elements).item()
#         var_coupling = torch.var(off_diag_elements).item()
        
#         return mean_coupling, var_coupling, R, K, L_B
#     except Exception as e:
#         return None, None, None, None, None
def compute_scale_invariant_stats(target_A, basis_B):
    """Normalized coupling statistics of ``A`` expressed in ``B``'s eigenbasis.

    The eigenpairs of ``B`` are reversed so the last pair returned by
    ``torch.linalg.eigh`` comes first. ``K = V.T @ A @ V`` is the projection
    of ``A`` and ``R[i, j] = K[i, j] / sqrt(|K[i, i]| * |K[j, j]|)`` its
    normalized form; entries whose normalizer is not above ``1e-38`` are left
    at zero. All arithmetic is done in float64 on the CPU.

    Returns:
        ``(mean, variance, R, K, eigenvalues)`` where mean and variance are
        taken over ``|R|`` off the diagonal and the three tensors are cast to
        float32; or five ``None`` values if the computation raises.
    """
    target = torch.diag(target_A) if target_A.ndim == 1 else target_A
    basis = torch.diag(basis_B) if basis_B.ndim == 1 else basis_B

    A = target.double().cpu()
    B = basis.double().cpu()

    try:
        eigvals, eigvecs = torch.linalg.eigh(B)

        # Reverse the order of the eigenpairs.
        reverse_order = torch.arange(eigvals.shape[0] - 1, -1, -1)
        eigvecs = eigvecs[:, reverse_order]
        eigvals = eigvals[reverse_order]

        K = eigvecs.T @ A @ eigvecs

        magnitudes = torch.abs(torch.diagonal(K))
        scale = torch.sqrt(torch.outer(magnitudes, magnitudes))

        # Divide only where the normalizer is usable; everything else stays 0.
        tiny = 1e-38
        usable = scale > tiny
        R = torch.zeros_like(K)
        R[usable] = K[usable] / scale[usable]

        off_diagonal = ~torch.eye(R.shape[0], dtype=torch.bool)
        couplings = torch.abs(R)[off_diagonal]

        coupling_mean = torch.mean(couplings).item()
        coupling_var = torch.var(couplings).item()

        return coupling_mean, coupling_var, R.float(), K.float(), eigvals.float()

    except Exception as e:
        print(f"Error in stats calculation: {e}")
        return None, None, None, None, None




def compute_rmt_baseline_alignment(target_A):
    """Diagonal-energy fraction of ``A`` in a random orthogonal basis.

    ``A`` is conjugated by the ``Q`` factor of a QR decomposition of a
    Gaussian random matrix (``Q.T @ A @ Q``); the squared norm of the
    resulting diagonal is divided by the squared Frobenius norm of the whole
    matrix. A 1-D input is treated as a diagonal. Returns ``0.0`` when the
    rotated matrix is all zeros. The result is random.
    """
    matrix = torch.diag(target_A) if target_A.ndim == 1 else target_A
    matrix = matrix.float().cpu()

    size = matrix.shape[0]
    rotation, _ = torch.linalg.qr(torch.randn(size, size))
    rotated = rotation.T @ matrix @ rotation

    on_diagonal = torch.norm(torch.diag(rotated), p=2) ** 2
    overall = torch.norm(rotated, p='fro') ** 2
    if overall == 0:
        return 0.0
    return (on_diagonal / overall).item()


def analyze_epochs_advanced(epoch_list, base_path, slice_range, var_pairs):
    """Compute every comparison metric for each matrix pair at each epoch.

    For each epoch the checkpoint ``f"{base_path}{epoch}.pt"`` is loaded
    (missing or unreadable files are skipped). Matrices whose key contains
    ``"covar"`` or ``"hessian"`` (case-insensitive) are rotated into the
    eigenbasis of ``loaded_data["Hessian"]`` when that decomposition is
    available.

    Args:
        epoch_list: Epoch identifiers; the last entry is the one whose
            matrices are kept for the snapshot plots.
        base_path: Path prefix of the checkpoint files.
        slice_range: Index tensor forwarded to ``extract_diagonals``.
        var_pairs: Mapping ``pair name -> (key_a, key_b)`` of checkpoint keys.

    Returns:
        ``(results, last_epoch_data)``. ``results`` maps
        ``metric -> pair name -> list of (epoch, value)``.
        ``last_epoch_data`` maps ``pair name -> dict`` with the projection
        matrices, coupling statistics and raw inputs of the last epoch; it
        stays empty if that epoch produced no statistics.
    """
    metric_names = (
        "slope", "spearman", "commutativity", "comm_random", "alignment",
        "rmt_alignment", "comm_top", "comm_bulk", "mean_coupling", "var_coupling",
    )
    results = {metric: {k: [] for k in var_pairs} for metric in metric_names}
    rotated_tags = ("covar", "hessian")

    last_epoch_data = {}

    for epoch in epoch_list:
        file_path = f"{base_path}{epoch}.pt"
        if not os.path.exists(file_path): continue
        try:
            loaded_data = torch.load(file_path, map_location='cpu')
        except Exception as e:
            # Best effort: one unreadable checkpoint must not abort the sweep,
            # but Ctrl-C / SystemExit are no longer swallowed.
            print(f"Skipping unreadable checkpoint {file_path}: {e}")
            continue

        try:
            _, V_H = torch.linalg.eigh(torch.tensor(loaded_data["Hessian"]))
            # The operands below are cast with .float(); keep the basis in the
            # same dtype so the matrix products cannot fail on a dtype clash.
            V_H = V_H.float()
        except Exception:
            # No usable Hessian in this checkpoint: continue without rotation.
            V_H = None

        for name, (key_a, key_b) in var_pairs.items():
            if key_a not in loaded_data or key_b not in loaded_data: continue

            A_raw = torch.tensor(loaded_data[key_a]).float()
            B_raw = torch.tensor(loaded_data[key_b]).float()

            A, B = A_raw, B_raw
            if V_H is not None:
                V_H_t = V_H.transpose(-2, -1)
                # Rotate at most once per matrix, even if a key matches both tags.
                if any(tag in key_a.lower() for tag in rotated_tags):
                    A = V_H_t @ A @ V_H
                if any(tag in key_b.lower() for tag in rotated_tags):
                    B = V_H_t @ B @ V_H

            A_dense, B_dense = extract_diagonals(A, B, slice_range)
            _, _, slope, r2 = compute_loglog_fit(A_dense, B_dense)
            spearman = compute_spearman_rank(A_dense, B_dense)

            # The rank correlation is only recorded when the fit succeeded.
            if slope is not None:
                results["slope"][name].append((epoch, slope))
                results["spearman"][name].append((epoch, spearman))

            comm_val = compute_commutativity(A, B)
            if comm_val is not None:
                results["commutativity"][name].append((epoch, comm_val))
                comm_rand = compute_commutativity_random_baseline(A, B)
                results["comm_random"][name].append((epoch, comm_rand))

            align_val = compute_eigen_alignment(target_A=A, basis_B=B)
            if align_val is not None: results["alignment"][name].append((epoch, align_val))

            rmt_val = compute_rmt_baseline_alignment(target_A=A)
            results["rmt_alignment"][name].append((epoch, rmt_val))

            err_top, err_bulk = compute_split_commutativity(target_A=A, basis_B=B, k=200, m=1000)
            if err_top is not None:
                results["comm_top"][name].append((epoch, err_top))
                results["comm_bulk"][name].append((epoch, err_bulk))

            mean_val, var_val, R_signed, K_matrix, L_B = compute_scale_invariant_stats(target_A=A, basis_B=B)

            if mean_val is not None:
                results["mean_coupling"][name].append((epoch, mean_val))
                results["var_coupling"][name].append((epoch, var_val))

                # Keep the matrices of the final epoch for the snapshot figures.
                if epoch == epoch_list[-1]:
                    if name not in last_epoch_data: last_epoch_data[name] = {}
                    last_epoch_data[name]['K_matrix'] = K_matrix
                    last_epoch_data[name]['R_matrix_signed'] = R_signed
                    last_epoch_data[name]['mean_off_diag_real'] = mean_val
                    last_epoch_data[name]['var_off_diag_real'] = var_val
                    last_epoch_data[name]['A_raw'] = A_raw
                    last_epoch_data[name]['B_raw'] = B_raw

        # Release the checkpoint before the next one is loaded.
        del loaded_data

    return results, last_epoch_data

def plot_all_metrics(results, config):
    """Plot the per-epoch metrics in a 2x2 grid and save the figure as a PDF.

    Panels: (0) log-log slope with reference lines at 1 and 2, (1)
    commutativity error divided by the random baseline for the full matrix,
    the top window and the bulk window, (2) Spearman rank correlation, (3)
    eigen alignment with the random-rotation baseline dashed.

    Args:
        results: Mapping ``metric -> pair name -> list of (epoch, value)``.
            The lists are not modified.
        config: Mapping providing ``'model'``, ``'dataset'`` and ``'lss_fn'``,
            which name the output directory under ``ICML_Figures/``.
    """
    from matplotlib.lines import Line2D

    def _display_name(name):
        # Substring -> LaTeX label; the first matching rule wins, so the
        # order of the rules matters.
        label_rules = (
            ("C_", "$\\mathbf{C}_{AWD,raw}$"),
            ("Covar_", "$\\mathbf{Covar}$"),
            ("C1_vs", "$\\mathbf{C}^{hh}$"),
            ("C1_dia_vs", "$\\mathbf{C}^{hh,SD}$"),
            ("C1_dia_w_dia_vs", "$\\mathbf{C}^{hh,SD,WD}$"),
            ("H2", "$2\\mathbf{C}/\\sigma_w^2$"),
        )
        for fragment, label in label_rules:
            if fragment in name:
                return label
        # Fallback so an unknown pair name never leaves the label undefined.
        return name.replace("_", " ")

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    axes = axes.flatten()

    pair_names = list(results["slope"].keys())
    colors = plt.cm.tab10(np.linspace(0, 1, len(pair_names)))
    color_map = {name: colors[i] for i, name in enumerate(pair_names)}

    # ---- Panel 0: fitted log-log slope ----
    ax = axes[0]
    for name, data in results["slope"].items():
        if not data: continue
        epochs, vals = zip(*sorted(data))
        ax.plot(epochs, vals, marker='o', label=_display_name(name), color=color_map[name])
    ax.axhline(y=1, color='r', linestyle='--', linewidth=1.5, label=f"Lower bound = {1}")
    ax.axhline(y=2, color='g', linestyle='--', linewidth=1.5, label=f"Upper bound = {2}")
    ax.set_title("Power vs Epochs")
    ax.set_ylabel(r"$\gamma$")
    ax.set_ylim(0.7, 2.3)
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend(fontsize='small')

    # ---- Panel 1: commutativity error relative to the random baseline ----
    ax = axes[1]

    ax.axhline(y=1.0, color='gray', linestyle='-.', linewidth=1, alpha=0.5, label='Random Baseline (y=1)')

    for name in pair_names:
        d_full = sorted(results["commutativity"].get(name, []), key=lambda x: x[0])
        d_top = sorted(results["comm_top"].get(name, []), key=lambda x: x[0])
        d_bulk = sorted(results["comm_bulk"].get(name, []), key=lambda x: x[0])
        d_rand = sorted(results["comm_random"].get(name, []), key=lambda x: x[0])

        if not d_rand: continue

        rand_map = {e: v for e, v in d_rand}

        c = color_map[name]

        def get_ratio(data_list):
            # Epochs without a baseline, or with a zero baseline, have no
            # defined ratio and are left out instead of raising.
            return [(e, v / rand_map[e]) for e, v in data_list if rand_map.get(e)]

        r_full = get_ratio(d_full)
        r_top = get_ratio(d_top)
        r_bulk = get_ratio(d_bulk)

        if r_full:
            ax.plot(*zip(*r_full), color=c, linestyle='-', linewidth=2, label=name)
        if r_top:
            ax.plot(*zip(*r_top), color=c, linestyle='--', linewidth=1.5, alpha=0.8)
        if r_bulk:
            ax.plot(*zip(*r_bulk), color=c, linestyle=':', linewidth=2, alpha=0.9)

    ax.set_title("Normalized Commutativity Error\n(Ratio = Error / Random_Baseline)")
    # The axis is linear, so the label says so.
    ax.set_ylabel("Ratio (Linear Scale)")
    ax.set_yscale('linear')

    ax.grid(True, linestyle='--', alpha=0.6, which='both')

    # The legend explains the line styles; colours identify the pairs.
    custom_lines = [
        Line2D([0], [0], color='black', lw=2, linestyle='-'),
        Line2D([0], [0], color='black', lw=1.5, linestyle='--'),
        Line2D([0], [0], color='black', lw=2, linestyle=':'),
        Line2D([0], [0], color='gray', lw=1, linestyle='-.')
    ]
    ax.legend(custom_lines, ['Full Ratio', 'Top-K Ratio', 'Bulk Ratio', 'Random (y=1)'], loc='upper right', fontsize='small')

    # ---- Panel 2: Spearman rank correlation ----
    ax = axes[2]
    for name, data in results["spearman"].items():
        if not data: continue
        epochs, vals = zip(*sorted(data))
        ax.plot(epochs, vals, marker='s', markersize=4, label=name, color=color_map[name])
    ax.set_title("Spearman Rank Correlation")
    ax.set_ylim(-0.1, 1.1)
    ax.grid(True, linestyle='--', alpha=0.6)

    # ---- Panel 3: eigen alignment against the random-rotation baseline ----
    ax = axes[3]
    for name, data in results["alignment"].items():
        if not data: continue
        epochs, vals = zip(*sorted(data))
        ax.plot(epochs, vals, marker='o', markersize=4, label=name, color=color_map[name])
        rmt_data = results["rmt_alignment"].get(name, [])
        if rmt_data:
            ax.plot(*zip(*sorted(rmt_data)), linestyle='--', alpha=0.5, color=color_map[name])
    ax.set_title("Eigen Alignment (Solid=Real, Dashed=Random)")
    ax.set_ylim(-0.05, 1.05)
    ax.grid(True, linestyle='--', alpha=0.6)

    for ax in axes: ax.set_xlabel("Epoch")
    plt.tight_layout()

    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)
    filename = "All_Metrics.pdf"
    save_path = os.path.join(save_dir, filename)
    plt.savefig(
            save_path,
            format='pdf',
            bbox_inches='tight',
            pad_inches=0.05,
            dpi=100
        )
    plt.show()
    # Release the figure, as the snapshot plotting functions do.
    plt.close(fig)






def plot_snapshot_deep_dive_combined(last_epoch_data, config):
    """Draw one row of three heatmaps per matrix pair and save a single PDF.

    Columns: (0) the top-left 300x300 block of the raw projection ``K`` on a
    symmetric-log colour scale, (1) the normalized matrix ``R``, (2) ``R``
    for a random baseline in which the eigenvalues of ``A_raw`` are placed in
    a random orthogonal basis.

    Args:
        last_epoch_data: Mapping ``pair name -> dict`` with the keys
            ``'K_matrix'``, ``'R_matrix_signed'``, ``'mean_off_diag_real'``,
            ``'var_off_diag_real'``, ``'A_raw'`` and ``'B_raw'``.
        config: Mapping providing ``'model'``, ``'dataset'`` and ``'lss_fn'``,
            which name the output directory under ``ICML_Figures/``.
    """
    cmap_choice = 'RdBu_r'

    names = list(last_epoch_data.keys())
    n_rows = len(names)
    n_cols = 3

    if n_rows == 0:
        print("No data to plot.")
        return

    figsize = (18, 5.0 * n_rows)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, constrained_layout=True)

    # With a single row plt.subplots returns a 1-D array; restore the 2-D
    # shape so that axes[idx, col] works uniformly.
    if n_rows == 1:
        axes = np.array([axes])

    print(f"Generating combined plot for {n_rows} items...")

    for idx, (name, data) in enumerate(last_epoch_data.items()):

        K_real = data['K_matrix']
        R_real = data['R_matrix_signed']
        mean_real = data['mean_off_diag_real']
        var_real = data['var_off_diag_real']
        A_raw = data['A_raw']
        B_raw = data['B_raw']

        # Random baseline: same eigenvalues as A_raw, random orthogonal basis.
        try:
            L_A, _ = torch.linalg.eigh(A_raw)
            N = A_raw.shape[0]
            H = torch.randn(N, N)
            Q, _ = torch.linalg.qr(H)
            A_rand = Q @ torch.diag(L_A) @ Q.T
            mean_rand, var_rand, R_rand, _, _ = compute_scale_invariant_stats(A_rand, B_raw)
        except Exception:
            # Best effort: without a baseline the third column is hidden.
            R_rand = None
            mean_rand, var_rand = 0, 0

        disp_dim = min(2560, R_real.shape[0])

        # =========================================================
        # Plot 1: Raw projection K (SymLogNorm) [Column 0]
        # =========================================================
        ax = axes[idx, 0]
        ax.grid(False)
        k_data = K_real[:300, :300].numpy()
        max_val_k = np.max(np.abs(k_data))

        linthresh = max_val_k * 1e-3 if max_val_k > 0 else 1e-5
        # The colour range is clipped to 5% of the largest magnitude.
        norm = SymLogNorm(linthresh=linthresh, linscale=0.5, vmin=-max_val_k*5e-2, vmax=max_val_k*5e-2, base=10)

        im0 = ax.imshow(k_data, cmap=cmap_choice, norm=norm, interpolation='nearest')

        cbar0 = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)
        cbar0.set_label('Amplitude (SymLog)', weight='bold')

        # LaTeX title: the first matching substring wins.
        if "C_" in name: dname = "$\\mathbf{C}_{AWD,raw}$"
        elif "Covar_" in name: dname = "$\\mathbf{Covar}$"
        elif "C1_vs" in name: dname = "$\\mathbf{C}^{hh}$"
        elif "C1_dia_vs" in name: dname = "$\\mathbf{C}^{hh,SD}$"
        elif "C1_dia_w_dia_vs" in name: dname = "$\\mathbf{C}^{hh,SD,WD}$"
        elif "H2" in name: dname = "$2\\mathbf{C}/\\sigma_w^2$"
        else: dname = name.replace("_", " ")  # Fallback

        ax.set_title(dname, pad=10)
        ax.set_ylabel("Basis Index")

        # Only the bottom row carries an x label.
        if idx == n_rows - 1:
            ax.set_xlabel("Basis Index")
        else:
            ax.set_xlabel("")

        # =========================================================
        # Plot 2: Real Correlation R [Column 1]
        # =========================================================
        ax = axes[idx, 1]
        ax.grid(False)
        r_data = R_real[:disp_dim, :disp_dim].numpy()
        im1 = ax.imshow(r_data, cmap=cmap_choice, vmin=-1, vmax=1, interpolation='nearest')

        cbar1 = fig.colorbar(im1, ax=ax, fraction=0.046, pad=0.04)
        cbar1.set_label('Correlation', weight='bold')

        # The math part is a raw f-string so \mu, \mid and \sigma are not
        # parsed as (invalid) string escapes.
        ax.set_title("Normalized $R$ (Real)\n" + rf"$\mu={mean_real:.4f} \mid \sigma^2={var_real:.2e}$", pad=10)
        ax.set_yticks([])

        if idx == n_rows - 1:
            ax.set_xlabel("Basis Index")

        # =========================================================
        # Plot 3: Random Baseline R [Column 2]
        # =========================================================
        ax = axes[idx, 2]
        ax.grid(False)
        if R_rand is not None:
            r_rand_data = R_rand[:disp_dim, :disp_dim].numpy()
            im2 = ax.imshow(r_rand_data, cmap=cmap_choice, vmin=-1, vmax=1, interpolation='nearest')

            cbar2 = fig.colorbar(im2, ax=ax, fraction=0.046, pad=0.04)
            cbar2.set_label('Correlation', weight='bold')

            ax.set_title("Normalized $R$ (Random)\n" + rf"$\mu={mean_rand:.4f} \mid \sigma^2={var_rand:.2e}$", pad=10)
            ax.set_yticks([])

            if idx == n_rows - 1:
                ax.set_xlabel("Basis Index")
        else:
            ax.axis('off')

    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)

    combined_name = "Combined_Deep_Dive_Snapshot.pdf"
    save_path = os.path.join(save_dir, combined_name)

    plt.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.1,
        dpi=100
    )

    print(f"Saved combined vector figure to: {save_path}")

    plt.show()
    plt.close()
def plot_snapshot_deep_dive_v2(last_epoch_data, config):
    """Draw and save one three-panel figure per matrix pair.

    Panels: (0) the top-left 300x300 block of the raw projection ``K`` on a
    symmetric-log colour scale, (1) the normalized matrix ``R``, (2) ``R``
    for a random baseline in which the eigenvalues of ``A_raw`` are placed in
    a random orthogonal basis. Each figure is saved as ``<pair name>.pdf``.

    Args:
        last_epoch_data: Mapping ``pair name -> dict`` with the keys
            ``'K_matrix'``, ``'R_matrix_signed'``, ``'mean_off_diag_real'``,
            ``'var_off_diag_real'``, ``'A_raw'`` and ``'B_raw'``.
        config: Mapping providing ``'model'``, ``'dataset'`` and ``'lss_fn'``,
            which name the output directory under ``ICML_Figures/``.
    """
    cmap_choice = 'RdBu_r'

    for name, data in last_epoch_data.items():
        print(f"\n--- Deep Dive for {name} ---")

        K_real = data['K_matrix']
        R_real = data['R_matrix_signed']
        mean_real = data['mean_off_diag_real']
        var_real = data['var_off_diag_real']
        A_raw = data['A_raw']
        B_raw = data['B_raw']

        # Random baseline: same eigenvalues as A_raw, random orthogonal basis.
        try:
            L_A, _ = torch.linalg.eigh(A_raw)
            N = A_raw.shape[0]
            H = torch.randn(N, N)
            Q, _ = torch.linalg.qr(H)
            A_rand = Q @ torch.diag(L_A) @ Q.T
            mean_rand, var_rand, R_rand, _, _ = compute_scale_invariant_stats(A_rand, B_raw)
        except Exception:
            # Best effort: without a baseline the third panel is hidden.
            R_rand = None
            mean_rand, var_rand = 0, 0

        fig, axes = plt.subplots(1, 3, figsize=(18, 5.5), constrained_layout=True)

        disp_dim = min(2560, R_real.shape[0])

        # ------------------------------------------------
        # Plot 1: Raw Matrix K (SymLogNorm)
        # ------------------------------------------------
        ax = axes[0]
        ax.grid(False)
        k_data = K_real[:300, :300].numpy()
        max_val_k = np.max(np.abs(k_data))

        linthresh = max_val_k * 1e-3 if max_val_k > 0 else 1e-5
        # The colour range is clipped to 5% of the largest magnitude.
        norm = SymLogNorm(linthresh=linthresh, linscale=0.5, vmin=-max_val_k*5e-2, vmax=max_val_k*5e-2, base=10)

        im0 = ax.imshow(k_data, cmap=cmap_choice, norm=norm, interpolation='nearest')

        cbar0 = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)
        cbar0.set_label('Amplitude (SymLog)', weight='bold')

        # LaTeX title: the first matching substring wins.
        if "C_" in name:
            dname = "$\\mathbf{C}_{AWD,raw}$"
        elif "Covar_" in name:
            dname = "$\\mathbf{Covar}$"
        elif "C1_vs" in name:
            dname = "$\\mathbf{C}^{hh}$"
        elif "C1_dia_vs" in name:
            dname = "$\\mathbf{C}^{hh,SD}$"
        elif "C1_dia_w_dia_vs" in name:
            dname = "$\\mathbf{C}^{hh,SD,WD}$"
        elif "H2" in name:
            dname = "$2\\mathbf{C}/\\sigma_w^2$"
        else:
            # Fallback so an unknown pair name neither crashes nor reuses the
            # title of the previous figure.
            dname = name.replace("_", " ")
        ax.set_title(dname, pad=10)

        ax.set_xlabel("Basis Index")
        ax.set_ylabel("Basis Index ")

        # ------------------------------------------------
        # Plot 2: Real Correlation R
        # ------------------------------------------------
        ax = axes[1]
        ax.grid(False)
        r_data = R_real[:disp_dim, :disp_dim].numpy()
        im1 = ax.imshow(r_data, cmap=cmap_choice, vmin=-1, vmax=1, interpolation='nearest')

        cbar1 = fig.colorbar(im1, ax=ax, fraction=0.046, pad=0.04)
        cbar1.set_label('Correlation', weight='bold')

        # The math part is a raw f-string so \mu, \mid and \sigma are not
        # parsed as (invalid) string escapes.
        ax.set_title("Normalized $R$ (Real)\n" + rf"$\mu={mean_real:.4f} \mid \sigma^2={var_real:.2e}$", pad=10)
        ax.set_xlabel("Basis Index")
        ax.set_yticks([])

        # ------------------------------------------------
        # Plot 3: Random Baseline R
        # ------------------------------------------------
        ax = axes[2]
        ax.grid(False)
        if R_rand is not None:
            r_rand_data = R_rand[:disp_dim, :disp_dim].numpy()
            im2 = ax.imshow(r_rand_data, cmap=cmap_choice, vmin=-1, vmax=1, interpolation='nearest')

            cbar2 = fig.colorbar(im2, ax=ax, fraction=0.046, pad=0.04)
            cbar2.set_label('Correlation', weight='bold')

            ax.set_title("Normalized $R$ (Random)\n" + rf"$\mu={mean_rand:.4f} \mid \sigma^2={var_rand:.2e}$", pad=10)
            ax.set_xlabel("Basis Index")
            ax.set_yticks([])
        else:
            ax.axis('off')

        save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
        os.makedirs(save_dir, exist_ok=True)
        filename = f"{name}.pdf"
        save_path = os.path.join(save_dir, filename)

        plt.savefig(
            save_path,
            format='pdf',
            bbox_inches='tight',
            pad_inches=0.05,
            dpi=100
        )

        print(f"Saved vector figure to: {save_path}")

        plt.show()
        plt.close()




if __name__ == "__main__":
    # Epochs whose checkpoints are analysed. The last entry also selects the
    # checkpoint used for the snapshot figure and the loss/accuracy curves.
    epoch_list =  [1, 10, 30, 50, 70, 80, 90, 100]#, 150, 200]
    slice_range = torch.arange(0, 1500)
    train_size = 2000
    sample_number = 20

    net_size = 50
    n_class = 10
    config = {}
    config['lss_fn'] = 'mse'
    config['dataset'] = 'mnist'
    config['model'] = 'FC'
    config['net_size'] = net_size
    config['sample_holder'] = [i for i in range(n_class)]
    config['B'] = 50
    config['alpha'] = 0.1
    save_dir = f"./AWCH_data/TrainSize{train_size}_SampleN{sample_number}_ClassN{len(config['sample_holder'])}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"
    # Alternative directory layout that also encodes the network size:
    # save_dir = f"./AWCH_data/NS{net_size}_TrainSize{train_size}_SampleN{sample_number}_ClassN{len(config['sample_holder'])}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"

    file_name1 = "C_epoch_"
    base_path1 = os.path.join(save_dir, file_name1)

    # Pair name -> (key of matrix A, key of matrix B) inside each checkpoint.
    var_pairs = {
        "C_vs_H1": ("C", "H_1_d"),
        "C1_vs_H1": ("C1", "H_1_d"),
        "C1_dia_vs_H1": ("C1_dia", "H_1_d"),
        "C1_dia_w_dia_vs_H1": ("C1_dia_w_dia", "H_1_d"),
        "Covar_vs_Hessian": ("Covar", "Hessian"),
        "H2_vs_H1": ("H_2_d", "H_1_d"),
    }

    full_results, last_epoch_data = analyze_epochs_advanced(epoch_list, base_path1, slice_range, var_pairs)
    plot_all_metrics(full_results, config)
    if last_epoch_data:
        plot_snapshot_deep_dive_combined(last_epoch_data, config)

    # Training curves are read from the final checkpoint of the sweep.
    file_path1 = f"{base_path1}{epoch_list[-1]}.pt"
    if not os.path.exists(file_path1):
        print(f"Checkpoint not found, skipping loss/accuracy curves: {file_path1}")
    else:
        # Load onto the CPU, like the analysis above, so no GPU is required.
        loaded_data1 = torch.load(file_path1, map_location='cpu')

        if 'train_loss_holder' in loaded_data1:
            train_loss_holder = loaded_data1['train_loss_holder']
            test_loss_holder = loaded_data1['test_loss_holder']
            train_accuracy_holder = loaded_data1['train_accuracy_holder']
            test_accuracy_holder = loaded_data1['test_accuracy_holder']
            plt.plot(np.log10(train_loss_holder))
            plt.plot(np.log10(test_loss_holder))
            plt.title('loss')
            plt.legend(['train','test'])
            plt.show()
            plt.plot(train_accuracy_holder)
            plt.plot(test_accuracy_holder)
            plt.title('accuracy')
            plt.legend(['train','test'])
            plt.show()
# ============================

## log-log Plot

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
import matplotlib.cm as cm


def set_publication_style():
    """Write the publication figure style into matplotlib's global rcParams.

    Covers fonts and font sizes, figure and save resolution, line and tick
    widths, and inward-pointing ticks. Figures created afterwards pick the
    settings up.
    """
    style_entries = [
        # Fonts
        ('font.family', 'serif'),
        ('font.serif', ['Times New Roman', 'Times', 'DejaVu Serif']),
        ('mathtext.fontset', 'stix'),
        ('font.size', 20),
        ('axes.labelsize', 20),
        ('axes.titlesize', 22),
        ('xtick.labelsize', 20),
        ('ytick.labelsize', 20),
        ('legend.fontsize', 12),
        # Resolution
        ('figure.dpi', 300),
        ('savefig.dpi', 300),
        # Line widths and ticks
        ('axes.linewidth', 2),
        ('lines.linewidth', 2),
        ('xtick.major.width', 2),
        ('ytick.major.width', 2),
        ('xtick.direction', 'in'),
        ('ytick.direction', 'in'),
    ]
    plt.rcParams.update(dict(style_entries))
# Apply the publication style globally so every figure created below uses it.
set_publication_style()

def extract_diagonals(A, B, slice_range=None):
    """Return diagonal matrices made from selected flipped-diagonal entries.

    ``A`` and ``B`` are each flipped along both axes, their diagonals are
    indexed with ``slice_range`` and the selected values are embedded in new
    diagonal matrices. When ``slice_range`` is ``None`` it defaults to
    ``arange(min(A.shape[0], B.shape[0]))``.
    """
    if slice_range is None:
        shortest = min(A.shape[0], B.shape[0])
        slice_range = torch.arange(shortest)

    dense = []
    for matrix in (A, B):
        flipped_diagonal = torch.diag(torch.flip(matrix, dims=[0, 1]))
        dense.append(torch.diag_embed(flipped_diagonal[slice_range]))
    A_dense, B_dense = dense
    return A_dense, B_dense

def compute_loglog_fit(A_dense, B_dense):
    """Regress ``log10(diag(A))`` on ``log10(diag(B))``.

    Only positions where both diagonal entries are strictly positive are
    used, and at least 10 such positions are required.

    Returns:
        ``(log_b, log_a, slope, r_squared)`` with the fitted x / y arrays, or
        four ``None`` values when there are too few usable points.
    """
    a_values, b_values = (
        matrix.diag().detach().cpu().numpy() for matrix in (A_dense, B_dense)
    )

    usable = np.logical_and(a_values > 0, b_values > 0)
    if usable.sum() < 10:
        return (None,) * 4

    y_log = np.log10(a_values[usable])
    x_log = np.log10(b_values[usable])

    fit = stats.linregress(x_log, y_log)
    return x_log, y_log, fit.slope, fit.rvalue ** 2

def load_and_process_data(config, epoch, slice_range, var_pairs):
    """Load one checkpoint and return centred log-log data for each matrix pair.

    The checkpoint directory is derived from ``config``. Matrices whose key
    contains ``"covar"`` or ``"hessian"`` (case-insensitive) are rotated into
    the eigenbasis of the checkpoint's ``"Hessian"`` entry when that entry
    exists and can be decomposed.

    Args:
        config: Mapping with ``'B'``, ``'alpha'``, ``'lss_fn'``, ``'model'``,
            ``'dataset'`` and ``'label'``; ``'train_size'``,
            ``'sample_number'``, ``'class_number'`` and ``'color'`` are read
            with ``.get``.
        epoch: Epoch identifier appended to the checkpoint file prefix.
        slice_range: Index tensor forwarded to ``extract_diagonals``.
        var_pairs: Mapping ``pair name -> (key_a, key_b)`` of checkpoint keys.

    Returns:
        Mapping ``pair name -> dict`` with mean-centred ``'x'`` / ``'y'`` log
        values, ``'slope'``, ``'r2'``, ``'label'`` and ``'color'``. Pairs
        without a fit are omitted; an empty dict is returned when the
        checkpoint is missing or cannot be loaded.
    """
    train_size = config.get('train_size')
    sample_number = config.get('sample_number')
    class_num = config.get('class_number')

    save_dir = (f"./AWCH_data/TrainSize{train_size}_SampleN{sample_number}_"
                f"ClassN{class_num}_B{config['B']}lr{config['alpha']}_"
                f"lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}")

    # Two prefixes are kept so a second file can be merged in; at present
    # both point at the same checkpoint.
    base_path1 = os.path.join(save_dir, "C_epoch_")
    base_path2 = os.path.join(save_dir, "C_epoch_")

    file_path1 = f"{base_path1}{epoch}.pt"
    file_path2 = f"{base_path2}{epoch}.pt"

    if not os.path.exists(file_path1):
        print(f"⚠️ [Missing] {config['label']} -> Path not found: {save_dir}")
        return {}

    try:
        loaded_data = torch.load(file_path1, map_location='cpu')
        # Only read the second file when it really is a different file;
        # loading the same checkpoint twice just doubles the I/O.
        if file_path2 != file_path1:
            loaded_data = {**loaded_data, **torch.load(file_path2, map_location='cpu')}
    except Exception as e:
        print(f"❌ [Error] Loading {config['label']}: {e}")
        return {}

    V = None
    if "Hessian" in loaded_data:
        try:
            H_tensor = torch.tensor(loaded_data["Hessian"])
            # Decompose in float64, then store the basis in float32.
            _, V = torch.linalg.eigh(H_tensor.double())
            V = V.float()
        except Exception as e:
            print(f"⚠️ Eigen decomposition failed: {e}")

    processed_results = {}

    for name, (key_a, key_b) in var_pairs.items():
        if key_a not in loaded_data or key_b not in loaded_data:
            continue

        A = torch.tensor(loaded_data[key_a])
        B = torch.tensor(loaded_data[key_b])

        if V is not None:
            need_rotate_a = any(sub in key_a.lower() for sub in ["covar", "hessian"])
            need_rotate_b = any(sub in key_b.lower() for sub in ["covar", "hessian"])

            if need_rotate_a: A = V.transpose(-2, -1) @ A @ V
            if need_rotate_b: B = V.transpose(-2, -1) @ B @ V

        A_dense, B_dense = extract_diagonals(A, B, slice_range)
        log_b, log_a, slope, r2 = compute_loglog_fit(A_dense, B_dense)

        if slope is not None:
            # Centre both axes on their means before storing them.
            log_a_centered = log_a - log_a.mean()
            log_b_centered = log_b - log_b.mean()

            processed_results[name] = {
                'x': log_b_centered,
                'y': log_a_centered,
                'slope': slope,
                'r2': r2,
                'label': config['label'],
                'color': config.get('color', 'black')
            }

    del loaded_data
    return processed_results

def plot_shifted_loglog(results_list, var_pairs_to_plot, y_shift_step=1.5):
    """Plot centered log-log scatters, one panel per variable pair, and save a PDF.

    Each configuration's points are offset downward by
    ``config_idx * y_shift_step`` so the point clouds do not overlap. The
    fitted line is drawn as ``slope * x`` minus the same offset, and slope-1 /
    slope-2 reference lines are drawn through the centre of the plotted extent.

    Args:
        results_list: One dict per configuration, mapping a pair name to a
            dict providing ``'x'``, ``'y'``, ``'slope'``, ``'color'`` and
            ``'label'``. A configuration lacking a pair is skipped for that
            panel (its shift slot is still reserved).
        var_pairs_to_plot: Mapping ``pair_name -> (key_a, key_b)``; ``key_a``
            labels the y-axis and ``key_b`` the x-axis.
        y_shift_step: Vertical offset between consecutive configurations.

    Side effects:
        Writes ``ICML_Figures/loglog.pdf`` (creating the directory if needed),
        prints the save path and shows the figure.
    """
    num_plots = len(var_pairs_to_plot)
    if num_plots == 0:
        # plt.subplots rejects a grid with zero columns.
        print("No variable pairs to plot.")
        return

    # squeeze=False always yields a 2-D array, so a single panel needs no special case.
    fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 6), squeeze=False)
    axes = axes[0]

    for idx, (pair_name, (key_a, key_b)) in enumerate(var_pairs_to_plot.items()):
        ax = axes[idx]

        # Extremes of everything drawn on this panel, used for limits and reference lines.
        all_x = []
        all_y = []

        for config_idx, res_dict in enumerate(results_list):
            if pair_name not in res_dict:
                continue

            data = res_dict[pair_name]
            x, y = data['x'], data['y']
            slope = data['slope']
            color = data['color']
            label = data['label']

            current_shift = config_idx * y_shift_step
            y_shifted = y - current_shift

            ax.scatter(x, y_shifted, s=6, alpha=0.25, color=color, edgecolors='none')

            # Fit line for the centered data, moved down with its point cloud.
            x_fit = np.array([x.min(), x.max()])
            y_fit = slope * x_fit - current_shift
            ax.plot(x_fit, y_fit, '-', lw=1, color=color, label=f"{label}")

            all_x.extend([x.min(), x.max()])
            all_y.extend([y_shifted.min(), y_shifted.max()])

        if all_x:
            min_x, max_x = min(all_x), max(all_x)
            min_y, max_y = min(all_y), max(all_y)

            cx = (min_x + max_x) / 2
            cy = (min_y + max_y) / 2

            ref_x = np.array([min_x, max_x])

            # y - cy = k * (x - cx)  =>  y = k(x-cx) + cy
            ax.plot(ref_x, 1.0 * (ref_x - cx) + cy, 'k--', lw=2, alpha=1, label='Slope=1 (Ref)')
            ax.plot(ref_x, 2.0 * (ref_x - cx) + cy, 'k:', lw=2, alpha=1, label='Slope=2 (Ref)')

            ax.set_xlim(min_x - 0.5, max_x + 0.5)
            ax.set_ylim(min_y - 0.5, max_y + 0.5)

        # Choose the panel title from key_a without overwriting the loop variable.
        if "covar" in key_a.lower():
            title = "Empirical Covariance"
        elif "h" in key_a.lower():
            title = "AWD-derived Covariance"
        else:
            title = pair_name
        ax.set_title(title, fontweight='bold', pad=10)
        # Raw f-strings keep the TeX backslash without an invalid-escape warning.
        ax.set_xlabel(rf"Centered $\log_{{10}}$({key_b})")

        if idx == 0:
            ax.set_ylabel(rf"Shifted $\log_{{10}}$({key_a})")

        ax.legend(loc='lower right', frameon=True, framealpha=0.9, fontsize=14)
        ax.grid(True, linestyle='--', alpha=1)

        ax.set_aspect('equal', adjustable='datalim')

    # Lay out before saving so the saved file and the displayed figure agree.
    fig.tight_layout()

    save_dir = "ICML_Figures"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "loglog.pdf")

    fig.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.05,
        dpi=300,
    )

    print(f"Saved vector figure to: {save_path}")

    plt.show()

# ============================
# 🔧 Main Execution
# ============================
if __name__ == "__main__":
    # Snapshot epoch to load and the diagonal indices to keep.
    epoch_to_plot = 100
    slice_range = torch.arange(0, 1500)

    # Panel name -> (key_a, key_b), forwarded to the loader and the plotter.
    var_pairs_to_plot = {
        "Covar vs Hessian": ("Covar", "Hessian"),
        "H_2 vs Hessian": ("H_2_d", "Hessian"),
        # "Covar vs H_2": ("Covar", "H_2_d"),
        # "H_1 vs Hessian": ("H_1_d", "Hessian"),
    }

    # Settings common to every run; each entry below adds what differs.
    shared_settings = {
        'alpha': 0.1,
        'train_size': 2000,
        'sample_number': 20,
        'class_number': 10,
    }
    configs = [
        dict(shared_settings, label='CNN/CIFAR10/CSE', lss_fn='cse',
             dataset='cifar10', model='CNN', B=128, color='#1f77b4'),
        dict(shared_settings, label='MLP/MNIST/CSE', lss_fn='cse',
             dataset='mnist', model='FC', B=50, color='#d62728'),
        dict(shared_settings, label='MLP/MNIST/MSE', lss_fn='mse',
             dataset='mnist', model='FC', B=50, color='#2ca02c'),
    ]

    print(f"--- Starting Analysis for Epoch {epoch_to_plot} ---")

    # One result dict per configuration; an empty dict marks a failed load.
    results_list = []
    for cfg in configs:
        print(f"Processing: {cfg['label']}...")
        data = load_and_process_data(cfg, epoch_to_plot, slice_range, var_pairs_to_plot)
        results_list.append(data)

    if not any(results_list):
        print("‚ùå No data loaded. Please check your folder paths and naming convention.")
    else:
        plot_shifted_loglog(results_list, var_pairs_to_plot, y_shift_step=1.5)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os


def set_publication_style():
    """Install the serif, large-font, thick-line matplotlib rcParams for this cell."""
    typography = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 20,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 12,
    }
    resolution = {
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    strokes_and_ticks = {
        'axes.linewidth': 2,
        'lines.linewidth': 2,
        'xtick.major.width': 2,
        'ytick.major.width': 2,
        'xtick.direction': 'in',
        'ytick.direction': 'in',
    }
    for group in (typography, resolution, strokes_and_ticks):
        plt.rcParams.update(group)


# Applied once when the cell runs so every later figure picks the style up.
set_publication_style()

def get_rotated_diagonal_decay(target_tensor, basis_tensor=None, top_k=1500,
                               exclude_first=20, fit_end=800):
    """Compute a centered log-log rank/value curve and a linear fit to it.

    Without a basis the values are the eigenvalues of ``target_tensor``; with
    a basis they are the diagonal of ``V^T @ target @ V`` where ``V`` holds the
    eigenvectors of ``basis_tensor``. Values are sorted in descending order,
    truncated to ``top_k``, and entries not exceeding ``1e-20`` are dropped
    before taking ``log10``. The log-values are centered on their mean.

    Args:
        target_tensor: Square tensor whose spectrum or rotated diagonal is analysed.
        basis_tensor: Optional symmetric tensor providing the eigenbasis.
        top_k: Keep only this many of the largest values (``None`` keeps all).
        exclude_first: Number of leading surviving points left out of the fit.
        fit_end: Exclusive end position of the fit window among the surviving
            points (default 800, the value previously hard-coded).

    Returns:
        ``None`` when the decomposition fails or too few positive values
        remain; otherwise a dict with ``'all_x'``/``'all_y'`` (full curve),
        ``'fit_x'``/``'fit_y_pred'`` (fit window and fitted line), ``'slope'``,
        ``'r2'`` and ``'exclude_idx'``.
    """
    if basis_tensor is None:
        try:
            vals = torch.linalg.eigvalsh(target_tensor.double())
        except Exception as e:
            print(f"‚ö†Ô∏è Eigen decomposition failed: {e}")
            return None
    else:
        try:
            # Hessian = V * L * V.T; only the eigenvectors are needed here.
            _, V_h = torch.linalg.eigh(basis_tensor.double())
            # Diag(V^T C V)
            vals = torch.diag(V_h.T @ target_tensor.double() @ V_h)
        except Exception as e:
            print(f"‚ö†Ô∏è Basis decomposition failed: {e}")
            return None

    vals = np.sort(vals.detach().cpu().numpy())[::-1]

    if top_k is not None and top_k < len(vals):
        vals = vals[:top_k]

    # Ranks are assigned before masking so each kept value retains its rank.
    ranks = np.arange(1, len(vals) + 1)

    mask = vals > 1e-20
    if mask.sum() < exclude_first + 10:
        return None

    log_ranks = np.log10(ranks[mask])
    log_vals = np.log10(vals[mask])
    log_vals_centered = log_vals - np.mean(log_vals)

    fit_x = log_ranks[exclude_first:fit_end]
    fit_y = log_vals_centered[exclude_first:fit_end]
    if len(fit_x) < 2:
        # A regression needs at least two points; an empty window would raise.
        return None

    slope, intercept, r_value, _, _ = stats.linregress(fit_x, fit_y)

    return {
        'all_x': log_ranks,
        'all_y': log_vals_centered,
        'fit_x': fit_x,
        'fit_y_pred': slope * fit_x + intercept,
        'slope': slope,
        'r2': r_value**2,
        'exclude_idx': exclude_first
    }

def load_and_analyze(config, epoch, analysis_pairs, top_k=1500, exclude_first=20):
    """Load one epoch snapshot and compute the decay curve for each analysis pair.

    Args:
        config: Run settings; ``train_size``, ``sample_number``,
            ``class_number``, ``B``, ``alpha``, ``lss_fn``, ``model`` and
            ``dataset`` build the data directory name, while ``label`` and the
            optional ``color`` are attached to each result.
        epoch: Epoch number appended to the ``C_epoch_`` file prefix.
        analysis_pairs: Mapping ``plot_name -> (target_key, basis_key)``; a
            ``basis_key`` of ``None`` requests analysis without a basis.
        top_k: Forwarded to ``get_rotated_diagonal_decay``.
        exclude_first: Forwarded to ``get_rotated_diagonal_decay``.

    Returns:
        Dict mapping each successfully analysed plot name to its result dict
        (with ``'label'`` and ``'color'`` added). Empty when the file is
        missing or cannot be loaded.
    """
    train_size = config.get('train_size')
    sample_number = config.get('sample_number')
    class_num = config.get('class_number')

    save_dir = (f"./AWCH_data/TrainSize{train_size}_SampleN{sample_number}_"
                f"ClassN{class_num}_B{config['B']}lr{config['alpha']}_"
                f"lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}")

    base_path = os.path.join(save_dir, "C_epoch_")
    file_path = f"{base_path}{epoch}.pt"

    if not os.path.exists(file_path):
        print(f"‚ö†Ô∏è Path not found: {save_dir}")
        return {}

    try:
        loaded_data = torch.load(file_path, map_location='cpu')
    except Exception as e:
        print(f"‚ùå Error Loading: {e}")
        return {}

    processed_results = {}

    for plot_name, (target_key, basis_key) in analysis_pairs.items():
        if target_key not in loaded_data:
            continue
        # A requested basis that is absent makes the pair impossible to analyse.
        if basis_key is not None and basis_key not in loaded_data:
            continue

        # as_tensor avoids the copy (and the copy-construct warning) for tensor inputs.
        target_tensor = torch.as_tensor(loaded_data[target_key])
        basis_tensor = None
        if basis_key is not None:
            basis_tensor = torch.as_tensor(loaded_data[basis_key])

        res = get_rotated_diagonal_decay(
            target_tensor, basis_tensor, top_k=top_k, exclude_first=exclude_first
        )

        # Only a dict is a usable result; any other return value marks a failed analysis.
        if isinstance(res, dict):
            res['label'] = config['label']
            res['color'] = config.get('color', 'black')
            processed_results[plot_name] = res

    return processed_results


def plot_centered_comparison(results_list, analysis_pairs, y_shift_step=2.0, top_k=1500):
    """Scatter the centered log-log decay curves, one panel per analysis pair.

    Only the curve itself is drawn; fit-related entries of a result dict are
    not read, so results without them plot fine.

    Args:
        results_list: One dict per configuration, mapping a plot name to a
            dict providing ``'all_x'``, ``'all_y'``, ``'color'`` and
            ``'label'``. Missing plot names are skipped.
        analysis_pairs: Mapping whose keys are the panel titles, in panel order.
        y_shift_step: Vertical offset between consecutive configurations.
        top_k: Used only in the output file name.

    Side effects:
        Writes ``ICML_Figures/centered_decay_comparison_top{top_k}.pdf``,
        prints the path and shows the figure.
    """
    num_plots = len(analysis_pairs)

    fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 7),
                             sharey=True, sharex=True)

    if num_plots == 1:
        axes = [axes]

    plt.subplots_adjust(wspace=0.05)

    for idx, plot_name in enumerate(analysis_pairs):
        ax = axes[idx]

        for config_idx, res_dict in enumerate(results_list):
            if plot_name not in res_dict:
                continue

            data = res_dict[plot_name]

            # Each configuration is moved down by its own multiple of the step.
            shift = config_idx * y_shift_step
            ax.scatter(data['all_x'], data['all_y'] - shift, s=10,
                       color=data['color'], alpha=0.5, edgecolors='none',
                       label=f"{data['label']} ")

        ax.set_title(plot_name, fontweight='bold', pad=12)
        ax.set_xlabel(r"$\log_{10}(\text{Rank})$")

        if idx == 0:
            ax.set_ylabel(r"Centered $\log_{10}(\text{Value})$ (Shifted)")

        ax.legend(loc='lower left', frameon=True, fontsize=18, framealpha=1)
        ax.grid(True, linestyle='--', alpha=1)

    save_path = f"ICML_Figures/centered_decay_comparison_top{top_k}.pdf"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0.05, dpi=300)
    print(f"Saved: {save_path}")
    plt.show()


if __name__ == "__main__":
    # Snapshot epoch, number of largest values kept, and how many leading
    # points are excluded from the fit.
    epoch_to_plot = 100
    top_k = 1000
    exclude_outliers = 20

    # Panel title -> (target key, basis key); None requests no basis.
    analysis_pairs = {
        "Hessian Spectra": ("Hessian", None),
        "Covar Diagonals (in Hessian Basis)": ("Covar", "Hessian"),
    }

    # Settings common to every run; each entry below adds what differs.
    shared_settings = dict(alpha=0.1, train_size=2000, sample_number=20, class_number=10)
    configs = [
        dict(shared_settings, label='CNN/CIFAR10/CSE', lss_fn='cse',
             dataset='cifar10', model='CNN', B=128, color='#1f77b4'),  # Blue
        dict(shared_settings, label='MLP/MNIST/CSE', lss_fn='cse',
             dataset='mnist', model='FC', B=50, color='#d62728'),  # Red
        dict(shared_settings, label='MLP/MNIST/MSE', lss_fn='mse',
             dataset='mnist', model='FC', B=50, color='#2ca02c'),  # Green
    ]

    print(f"--- Processing Epoch {epoch_to_plot} ---")

    results = []
    for cfg in configs:
        print(f"Loading {cfg['label']}...")
        outcome = load_and_analyze(cfg, epoch_to_plot, analysis_pairs, top_k, exclude_outliers)
        results.append(outcome)

    # Plot only when at least one configuration produced a non-empty result.
    if any(results):
        plot_centered_comparison(results, analysis_pairs, y_shift_step=2.5, top_k=top_k)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os


def set_publication_style():
    """Install the serif, large-font, thick-line matplotlib rcParams for this cell."""
    style = dict([
        ('font.family', 'serif'),
        ('font.serif', ['Times New Roman', 'Times', 'DejaVu Serif']),
        ('mathtext.fontset', 'stix'),
        # Font sizes.
        ('font.size', 20),
        ('axes.labelsize', 20),
        ('axes.titlesize', 22),
        ('xtick.labelsize', 20),
        ('ytick.labelsize', 20),
        ('legend.fontsize', 12),
        # Resolution for on-screen and saved figures.
        ('figure.dpi', 300),
        ('savefig.dpi', 300),
        # Line widths and tick appearance.
        ('axes.linewidth', 2),
        ('lines.linewidth', 2),
        ('xtick.major.width', 2),
        ('ytick.major.width', 2),
        ('xtick.direction', 'in'),
        ('ytick.direction', 'in'),
    ])
    plt.rcParams.update(style)


# Applied once when the cell runs so every later figure picks the style up.
set_publication_style()

def compute_frobenius_norm(matrix, slice_range=None):
    """Return the Frobenius norm of ``matrix`` as a Python float.

    Args:
        matrix: Tensor or array-like of floating-point values.
        slice_range: Accepted for signature compatibility with callers; the
            norm is always taken over the whole input.

    Returns:
        Square root of the sum of squared entries.
    """
    # as_tensor avoids copy-constructing (and the related warning) for tensor
    # inputs; vector_norm over all entries equals the Frobenius norm and
    # replaces the deprecated torch.norm.
    return torch.linalg.vector_norm(torch.as_tensor(matrix)).item()

def get_diag_elements(matrix, slice_range=None):
    """Return the reversed diagonal of ``matrix`` and the indices that were kept.

    A 1-D input is treated as an already-extracted diagonal. The diagonal is
    reversed before any selection, so index 0 refers to the last entry.

    Args:
        matrix: 1-D or 2-D tensor or array-like.
        slice_range: Optional index tensor; indices not below the diagonal
            length are discarded.

    Returns:
        Tuple ``(values, indices)`` of NumPy arrays.
    """
    tensor = torch.tensor(matrix)
    diagonal = tensor if tensor.ndim == 1 else torch.diag(tensor)
    reversed_diag = torch.flip(diagonal, dims=[0])

    if slice_range is None:
        return reversed_diag.cpu().numpy(), np.arange(len(reversed_diag))

    requested = slice_range.long()
    kept = requested[requested < len(reversed_diag)]
    return reversed_diag[kept].cpu().numpy(), kept.cpu().numpy()

def analyze_data(epoch_list, base_path, matrix_keys, slice_range=None):
    """Collect per-epoch Frobenius norms plus final-epoch curves and diagonals.

    For each epoch whose file ``{base_path}{epoch}.pt`` exists, the norm of
    every available matrix in ``matrix_keys`` is recorded. Loss, accuracy and
    diagonal entries are taken only from the snapshot whose epoch equals the
    last element of ``epoch_list``.

    Args:
        epoch_list: Epoch numbers to try.
        base_path: File prefix placed before ``{epoch}.pt``.
        matrix_keys: Keys whose norms are tracked.
        slice_range: Forwarded to the norm and diagonal helpers.

    Returns:
        ``(norm_results, loss_data, acc_data, diag_data)`` where
        ``norm_results`` maps each key to epoch-sorted ``(epoch, norm)`` pairs
        and ``diag_data`` maps each key to ``(values, indices)``.
    """
    def _first_available(store, names):
        # Value under the first of `names` present in `store`, else an empty list.
        for candidate in names:
            if candidate in store:
                return store[candidate]
        return []

    def _to_plain_list(raw):
        # Tensors become Python lists, lists pass through, anything else is dropped.
        if isinstance(raw, torch.Tensor):
            return raw.cpu().numpy().tolist()
        if isinstance(raw, list):
            return raw
        return []

    norm_results = {key: [] for key in matrix_keys}
    last_epoch_data = None

    for epoch in epoch_list:
        file_path = f"{base_path}{epoch}.pt"
        if not os.path.exists(file_path):
            continue

        try:
            loaded_data = torch.load(file_path, map_location='cpu')

            # Keep only the final snapshot for the later loss/diagonal extraction.
            if epoch == epoch_list[-1]:
                last_epoch_data = loaded_data

            for key in matrix_keys:
                if key not in loaded_data:
                    continue
                fnorm = compute_frobenius_norm(loaded_data[key], slice_range=slice_range)
                norm_results[key].append((epoch, fnorm))
        except Exception as e:
            print(f"‚ùå Error Loading Epoch {epoch}: {e}")

    for entries in norm_results.values():
        entries.sort(key=lambda pair: pair[0])

    loss_data = []
    acc_data = []
    diag_data = {}

    if last_epoch_data:
        # Loss & Acc
        loss_data = _to_plain_list(
            _first_available(last_epoch_data, ('train_loss_holder', 'loss'))
        )
        acc_data = _to_plain_list(
            _first_available(last_epoch_data, ('train_accuracy_holder', 'acc'))
        )

        # Diagonals
        for key in ("C", "C1", "C2", "C3", "C1_dia", "C1_dia_w_dia"):
            if key not in last_epoch_data:
                continue
            values, indices = get_diag_elements(last_epoch_data[key], slice_range)
            diag_data[key] = (values, indices)

    return norm_results, loss_data, acc_data, diag_data


def plot_results(norm_results, loss_data, acc_data, diag_data, config):
    """Draw the 2x2 summary figure and save it as ``C_approx_com.pdf``.

    Panels: Frobenius-norm evolution (top-left), diagonal component ratios
    against ``C`` (top-right), training loss and accuracy (bottom-left), and
    approximation-check ratios against ``C1`` (bottom-right).

    Args:
        norm_results: Mapping ``key -> [(epoch, norm), ...]``; empty lists are skipped.
        loss_data: Sequence of loss values, plotted against 1-based positions.
        acc_data: Sequence of accuracy values, plotted against 1-based positions.
        diag_data: Mapping ``key -> (values, indices)`` of diagonal entries.
        config: Dict providing ``model``, ``dataset`` and ``lss_fn``, used to
            name the output directory.

    Side effects:
        Writes ``ICML_Figures/{model}_{dataset}_{lss_fn}/C_approx_com.pdf``
        and shows the figure.
    """
    colors = plt.cm.tab10.colors

    # Legend text per matrix key; keys not listed fall back to the raw key so
    # a label is always defined.
    norm_labels = {
        "C": "$\\mathbf{C}_{AWD,raw}$",
        "Covar": "$\\mathbf{Covar}$",
        "C1": "$\\mathbf{C}^{hh}$",
        "C1_dia": "$\\mathbf{C}^{hh,SD}$",
        "C1_dia_w_dia": "$\\mathbf{C}^{hh,SD,WD}$",
        "H2": "$2\\mathbf{C}/\\sigma_w^2$",
        "C2": "$\\mathbf{C}^{hg}$",
        "C3": "$\\mathbf{C}^{gg}$",
    }
    ratio_labels = {
        "C1": "$C_{ii}^{hh}/C_{AWD,raw,ii}$",
        "C2": "$C_{ii}^{hg}/C_{AWD,raw,ii}$",
        "C3": "$C_{ii}^{gg}/C_{AWD,raw,ii}$",
    }

    fig, axes = plt.subplots(2, 2, figsize=(12, 9))

    ax_norm = axes[0, 0]      # 1. Top-Left
    ax_scatter1 = axes[0, 1]  # 2. Top-Right
    ax_loss = axes[1, 0]      # 3. Bottom-Left
    ax_scatter2 = axes[1, 1]  # 4. Bottom-Right

    # ==========================================
    # 1. Frobenius Norm Evolution (Top-Left)
    # ==========================================
    curve_idx = 0
    for name, data in norm_results.items():
        if not data:
            continue
        epochs = [entry[0] for entry in data]
        norms = [entry[1] for entry in data]
        ax_norm.semilogy(epochs, norms, marker='o', markersize=4,
                         label=norm_labels.get(name, name),
                         color=colors[curve_idx % len(colors)])
        curve_idx += 1

    ax_norm.set_title("Frobenius Norm Evolution")
    ax_norm.set_xlabel("Epoch")
    ax_norm.set_ylabel(r"$\|\mathbf{M}\|_F$")
    ax_norm.legend(loc='lower center', frameon=False, fancybox=False, edgecolor='black', fontsize=16)
    ax_norm.grid(True, which="minor", linestyle=':', alpha=0.2)

    # ==========================================
    # 3. Loss & Accuracy (Bottom-Left)
    # ==========================================
    if loss_data or acc_data:
        color_loss = '#D62728'
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Training Loss', color=color_loss, fontweight='bold')
        l1 = None
        if loss_data:
            l1, = ax_loss.semilogy(range(1, len(loss_data) + 1), loss_data,
                                   color=color_loss, linewidth=1, alpha=0.9, label='Loss')
        ax_loss.tick_params(axis='y', labelcolor=color_loss)

        # Accuracy shares the x-axis but gets its own y-axis on the right.
        ax_acc = ax_loss.twinx()
        color_acc = '#1F77B4'
        ax_acc.set_ylabel('Training Accuracy', color=color_acc, fontweight='bold')
        l2 = None
        if acc_data:
            l2, = ax_acc.plot(range(1, len(acc_data) + 1), acc_data,
                              color=color_acc, linewidth=1, alpha=0.9, label='Acc')
        ax_acc.tick_params(axis='y', labelcolor=color_acc)

        # One combined legend for the curves drawn on the two y-axes.
        lines, labels = [], []
        if l1 is not None:
            lines.append(l1)
            labels.append("Loss")
        if l2 is not None:
            lines.append(l2)
            labels.append("Accuracy")
        if lines:
            ax_loss.legend(lines, labels, loc='center right', frameon=False, fancybox=False, edgecolor='black', fontsize=16)
        ax_loss.set_title("Training Loss & Accuracy")
        ax_loss.grid(True, which="minor", linestyle=':', alpha=0.2)

    def plot_components_vs_c(ax):
        # Ratio of each component's diagonal to the diagonal of "C".
        if "C" not in diag_data:
            return
        c_vals, indices = diag_data["C"]

        local_idx = 0
        for key in ("C1", "C2", "C3"):
            if key not in diag_data:
                continue
            vals, _ = diag_data[key]
            ax.scatter(indices, vals / c_vals, s=10, alpha=0.8,
                       label=ratio_labels[key], color=colors[local_idx % 10])
            local_idx += 1

        ax.set_title(r"Components Magnitude at epoch=100")
        ax.set_xlabel("Ascending Index")
        ax.set_ylabel("Ratio value")
        ax.legend(loc='center left', frameon=False, fancybox=False, edgecolor='black', fontsize=16)
        ax.grid(True, linestyle='--', alpha=0.3)

    def plot_approx_check(ax):
        # Ratios comparing C1_dia with C1, and C1_dia with C1_dia_w_dia.
        if "C1" not in diag_data:
            return
        c1_vals, indices = diag_data["C1"]

        if "C1_dia" in diag_data:
            c1_dia_vals, _ = diag_data["C1_dia"]
            ax.scatter(indices, c1_dia_vals / c1_vals, s=10, alpha=0.8,
                       label="$C_{ii}^{hh,SD}/C_{ii}^{hh}$", color=colors[0])

            if "C1_dia_w_dia" in diag_data:
                c1_dia_w_dia_vals, _ = diag_data["C1_dia_w_dia"]
                ax.scatter(indices, c1_dia_vals / c1_dia_w_dia_vals, s=10, alpha=0.8, marker='^',
                           label="$C_{ii}^{hh,SD}/C_{ii}^{hh,SD,WD}$", color=colors[2])

        ax.set_title(r"Approximations Check at epoch=100")
        ax.set_xlabel("Ascending Index")
        ax.set_ylabel("Ratio value")
        ax.legend(loc='lower left', frameon=False, fancybox=False, edgecolor='black', fontsize=16)
        ax.grid(True, linestyle='--', alpha=0.3)
        ax.set_ylim(0, 1.5)

    plot_components_vs_c(ax_scatter1)
    plot_approx_check(ax_scatter2)

    fig.tight_layout()
    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "C_approx_com.pdf")

    fig.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.05,
        dpi=100,
    )

    plt.show()



if __name__ == "__main__":
    # Epoch snapshots to scan and the diagonal indices to keep.
    epoch_list = [1, 10, 30, 50, 70, 80, 90, 100]
    slice_range = torch.arange(0, 1500)

    train_size = 2000
    sample_number = 20
    config = {
        'lss_fn': 'mse',
        'dataset': 'mnist',
        'model': 'FC',
        'sample_holder': list(range(10)),
        'class_number': 10,
        'B': 50,
        'alpha': 0.1,
    }

    # The directory name encodes the run settings; ClassN is the length of sample_holder.
    save_dir = (
        f"./AWCH_data/TrainSize{train_size}_SampleN{sample_number}_"
        f"ClassN{len(config['sample_holder'])}_B{config['B']}lr{config['alpha']}_"
        f"lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"
    )
    file_name = "C_epoch_"
    base_path = os.path.join(save_dir, file_name)

    matrix_keys = ["C", "C1", "C2", "C3", "C1_dia", "C1_dia_w_dia"]

    norm_results, loss_data, acc_data, diag_data = analyze_data(
        epoch_list, base_path, matrix_keys, slice_range=slice_range
    )

    plot_results(norm_results, loss_data, acc_data, diag_data, config)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
import matplotlib.cm as cm
import math

from zmq import NORM_SEGMENT_SIZE


def apply_icml_style():
    """Install the compact serif matplotlib rcParams used by the grid figures."""
    fonts = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 13,
        'axes.labelsize': 15,
        'axes.titlesize': 13,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 9,
    }
    lines_and_grid = {
        'lines.linewidth': 0.5,
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
    }
    output = {
        'figure.dpi': 300,
        'savefig.bbox': 'tight',
    }
    for group in (fonts, lines_and_grid, output):
        plt.rcParams.update(group)


def extract_diagonals(A, B, slice_range=None):
    """Return selected diagonal entries of ``A`` and ``B`` after flipping both axes.

    Args:
        A: 2-D tensor.
        B: 2-D tensor.
        slice_range: Index tensor applied to both flipped diagonals; defaults
            to every index below the smaller leading dimension.

    Returns:
        Tuple ``(diag_a, diag_b)`` of 1-D tensors.
    """
    if slice_range is None:
        slice_range = torch.arange(min(A.shape[0], B.shape[0]))

    def _flipped_diag(mat):
        # Flip rows and columns first, then read the main diagonal.
        return torch.diag(torch.flip(mat, dims=[0, 1]))

    return _flipped_diag(A)[slice_range], _flipped_diag(B)[slice_range]

def load_and_merge_data(epoch, base_path1, base_path2):
    """Load up to two snapshot files for ``epoch`` and merge them into one dict.

    The files are ``{base_path1}{epoch}.pt`` and ``{base_path2}{epoch}.pt``.
    A missing file contributes nothing; on duplicate keys the second file wins.
    When both paths are the same file it is loaded only once.

    Args:
        epoch: Epoch number inserted into both file names.
        base_path1: Prefix of the first file.
        base_path2: Prefix of the second file.

    Returns:
        The merged dict (empty when neither file exists), or ``None`` when
        loading or merging raises.
    """
    file_path1 = f"{base_path1}{epoch}.pt"
    file_path2 = f"{base_path2}{epoch}.pt"

    try:
        d1 = torch.load(file_path1, map_location='cpu') if os.path.exists(file_path1) else {}

        if file_path2 == file_path1:
            # Same file: reuse the loaded contents instead of reading it again.
            d2 = d1
        elif os.path.exists(file_path2):
            d2 = torch.load(file_path2, map_location='cpu')
        else:
            d2 = {}

        return {**d1, **d2}
    except Exception as e:
        print(f"‚ùå Error Loading Data: {e}")
        return None

def compute_loglog_fit_simple(x, y):
    """Fit ``y`` against ``x`` by least squares and return ``(slope, intercept, r_squared)``."""
    fit = stats.linregress(x, y)
    return fit.slope, fit.intercept, fit.rvalue**2


def clean_tex_label(text):
    """Return ``text`` with every underscore replaced by a backslash-escaped underscore."""
    pieces = text.split("_")
    return r"\_".join(pieces)
    

def _plot_scatter_on_ax(ax, data_dict, hessian_eig, var_pair, slice_range):
    """Draw one centered log-log scatter of two diagonals with fit and reference lines.

    Matrices whose key contains "covar" or "hessian" (case-insensitive) are
    first rotated as ``V^T @ M @ V`` with the supplied eigenvectors. Only index
    positions where both diagonals are positive are used.

    Args:
        ax: Matplotlib axes to draw on.
        data_dict: Mapping from key to matrix data.
        hessian_eig: Pair ``(eigenvalues, eigenvectors)``; only the eigenvectors are used.
        var_pair: ``(key_a, key_b)``; ``key_a`` goes on y and ``key_b`` on x.
        slice_range: Index selection forwarded to ``extract_diagonals``.
    """
    key_a, key_b = var_pair

    if not (key_a in data_dict and key_b in data_dict):
        ax.text(0.5, 0.5, "Missing Data", ha='center', va='center')
        return

    A = torch.tensor(data_dict[key_a])
    B = torch.tensor(data_dict[key_b])

    _, eigvecs = hessian_eig

    def _wants_rotation(key):
        lowered = key.lower()
        return "covar" in lowered or "hessian" in lowered

    if _wants_rotation(key_a):
        A = eigvecs.transpose(-2, -1) @ A @ eigvecs
    if _wants_rotation(key_b):
        B = eigvecs.transpose(-2, -1) @ B @ eigvecs

    diag_a, diag_b = extract_diagonals(A, B, slice_range)
    da = diag_a.detach().cpu().numpy()
    db = diag_b.detach().cpu().numpy()

    # log10 is only defined where both diagonals are strictly positive.
    positive = (da > 0) & (db > 0)
    if positive.sum() < 5:
        ax.text(0.5, 0.5, "Invalid Data", ha='center', va='center')
        return

    raw_log_a = np.log10(da[positive])
    raw_log_b = np.log10(db[positive])

    # Centre both axes on their means.
    x_data = raw_log_b - raw_log_b.mean()
    y_data = raw_log_a - raw_log_a.mean()

    slope, intercept, r2 = compute_loglog_fit_simple(x_data, y_data)

    ax.scatter(x_data, y_data, alpha=0.3, s=6, c='#1f77b4', edgecolors='none', rasterized=True)

    # Symmetric square window with 20% headroom around the largest excursion.
    max_val = max(np.max(np.abs(x_data)), np.max(np.abs(y_data)))
    limit = max_val * 1.2
    x_ref = np.array([-limit, limit])

    reference_lines = (
        (1.0, '#d62728', '--', 'Slope=1'),
        (2.0, '#2ca02c', '-.', 'Slope=2'),
    )
    for ref_slope, ref_color, ref_style, ref_label in reference_lines:
        ax.plot(x_ref, ref_slope * x_ref, color=ref_color, linestyle=ref_style,
                lw=1.2, alpha=0.8, label=ref_label)

    ax.plot(x_ref, slope * x_ref + intercept, 'k-', lw=1.0, alpha=0.7, label='Fit')

    ax.set_aspect('equal')
    ax.set_xlim([-limit, limit])
    ax.set_ylim([-limit, limit])

    clean_a = clean_tex_label(key_a)
    clean_b = clean_tex_label(key_b)
    ax.set_title(fr"$\mathbf{{{clean_a}}}$ vs $\mathbf{{{clean_b}}}$", fontsize=10)

    # Fitted exponent and goodness of fit, boxed in the top-left corner.
    info_text = f"$\\alpha = {slope:.2f}$\n$R^2 = {r2:.2f}$"
    ax.text(0.05, 0.95, info_text, transform=ax.transAxes,
            fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", alpha=0.9))

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.axhline(0, color='gray', linewidth=0.5, alpha=0.3)
    ax.axvline(0, color='gray', linewidth=0.5, alpha=0.3)



def plot_all_scatters_in_grid(epoch, base_path1, base_path2, slice_range, pair_list, config):
    """Plot a 4-column grid of centered log-log scatters for one epoch and save it.

    Args:
        epoch: Epoch whose snapshot files are loaded.
        base_path1: Prefix of the first snapshot file.
        base_path2: Prefix of the second snapshot file.
        slice_range: Index selection forwarded to each panel.
        pair_list: Sequence of ``(key_a, key_b)`` pairs, one per panel.
        config: Dict providing ``model``, ``dataset`` and ``lss_fn``, used to
            name the output directory.

    Side effects:
        Writes ``ICML_Figures/{model}_{dataset}_{lss_fn}/Centered_LogLog.pdf``
        and shows the figure. Returns without plotting when the data cannot be
        loaded, holds no ``"Hessian"`` entry, or ``pair_list`` is empty.
    """
    print(f"Processing Epoch {epoch} for Centered Grid Plot...")

    data = load_and_merge_data(epoch, base_path1, base_path2)
    if data is None:
        return

    # Missing files yield a dict without "Hessian"; the rotation basis is then unavailable.
    if "Hessian" not in data:
        print(f"No 'Hessian' entry found for epoch {epoch}; skipping grid plot.")
        return

    n_plots = len(pair_list)
    if n_plots == 0:
        # A zero-row grid cannot be created.
        print("No variable pairs to plot.")
        return

    H = torch.as_tensor(data["Hessian"]).float()
    if not torch.isfinite(H).all():
        H = torch.nan_to_num(H)
    hessian_eig = torch.linalg.eigh(H)

    cols = 4
    rows = math.ceil(n_plots / cols)

    # squeeze=False keeps a 2-D array for any grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3.5 * rows), squeeze=False)
    axes = axes.flatten()

    for i, pair in enumerate(pair_list):
        _plot_scatter_on_ax(axes[i], data, hessian_eig, pair, slice_range)

        # Axis labels only on the bottom row and the left column.
        row_idx = i // cols
        col_idx = i % cols
        if row_idx == rows - 1:
            axes[i].set_xlabel(r"Centered $\log_{10}(X_{ii})$")
        if col_idx == 0:
            axes[i].set_ylabel(r"Centered $\log_{10}(Y_{ii})$")

    # Hide the grid cells that received no panel.
    for unused_ax in axes[n_plots:]:
        unused_ax.axis('off')

    # One shared legend built from the first panel's entries.
    handles, labels = axes[0].get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    fig.legend(by_label.values(), by_label.keys(), loc='upper center',
               bbox_to_anchor=(0.5, 1.01), ncol=4, frameon=False, fontsize=11)

    fig.suptitle(f"Centered Log-Log Analysis @ Epoch {epoch}", fontsize=16, y=1.03)
    fig.tight_layout()

    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "Centered_LogLog.pdf")

    fig.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.05,
        dpi=300,
    )
    print(f"Saved Grid Plot to {save_path}")

    plt.show()


def analyze_epochs(epoch_list, base_path1, base_path2, slice_range, var_pairs):
    """Fit a log-log slope per variable pair for each epoch.

    For every epoch whose merged snapshot contains ``"Hessian"``, matrices
    whose key contains "covar" or "hessian" (case-insensitive) are rotated
    into the Hessian eigenbasis; the positive entries of the selected
    diagonals are then regressed in log10 space (``key_a`` against ``key_b``).

    Args:
        epoch_list: Epoch numbers to process.
        base_path1: Prefix of the first snapshot file.
        base_path2: Prefix of the second snapshot file.
        slice_range: Index selection forwarded to ``extract_diagonals``.
        var_pairs: Mapping ``name -> (key_a, key_b)``.

    Returns:
        Dict mapping each name to a list of ``(epoch, slope, r_squared)``
        tuples; epochs that cannot be processed contribute nothing.
    """
    results = {name: [] for name in var_pairs.keys()}
    print("Starting Slope vs Epoch Analysis...")

    for epoch in epoch_list:
        data = load_and_merge_data(epoch, base_path1, base_path2)
        if data is None or "Hessian" not in data:
            continue

        try:
            _, V = torch.linalg.eigh(torch.as_tensor(data["Hessian"]).float())
        except Exception as e:
            # Report the skipped epoch instead of dropping it silently.
            print(f"Skipping epoch {epoch}: Hessian eigendecomposition failed: {e}")
            continue

        for name, (key_a, key_b) in var_pairs.items():
            if key_a not in data or key_b not in data:
                continue

            A = torch.as_tensor(data[key_a])
            B = torch.as_tensor(data[key_b])

            if "covar" in key_a.lower() or "hessian" in key_a.lower():
                A = V.transpose(-2, -1) @ A @ V
            if "covar" in key_b.lower() or "hessian" in key_b.lower():
                B = V.transpose(-2, -1) @ B @ V

            diag_a, diag_b = extract_diagonals(A, B, slice_range)

            da = diag_a.detach().cpu().numpy()
            db = diag_b.detach().cpu().numpy()
            # log10 is only defined where both diagonals are strictly positive.
            mask = (da > 0) & (db > 0)
            if mask.sum() > 5:
                log_a, log_b = np.log10(da[mask]), np.log10(db[mask])
                slope, _, r_val, _, _ = stats.linregress(log_b, log_a)
                results[name].append((epoch, slope, r_val**2))

        # No explicit `del`: names that were never bound in an epoch where
        # every pair was skipped would make it raise.
    return results

def plot_slope_vs_epoch_multi(results, config):
    """Plot the fitted log-log slope of every variable pair against epoch.

    Args:
        results: Mapping ``name -> list of (epoch, slope, r_squared)`` tuples.
            Only the epoch and slope of each tuple are drawn; series with an
            empty list are skipped.
        config: Not used by this function; kept so existing calls still work.

    Returns:
        None.  The figure is displayed with ``plt.show()``.  When every series
        is empty a message is printed and nothing is drawn.
    """
    if not any(results.values()):
        print("No slope data to plot.")
        return

    plt.figure(figsize=(10, 6))
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    colors = plt.get_cmap('tab10')

    for i, (name, data) in enumerate(results.items()):
        if not data:
            continue
        data = sorted(data, key=lambda x: x[0])
        epochs = [x[0] for x in data]
        slopes = [x[1] for x in data]
        # Wrap the index so more series than palette entries cycle the
        # colours instead of all collapsing onto the last one.
        plt.plot(epochs, slopes, marker='o', markersize=5, linewidth=2,
                 label=name, color=colors(i % colors.N), alpha=0.8)

    # Reference exponents drawn as horizontal guides.
    plt.axhline(y=1, color='gray', linestyle='--', linewidth=1.5, alpha=0.6, label="Slope=1")
    plt.axhline(y=2, color='gray', linestyle=':', linewidth=1.5, alpha=0.6, label="Slope=2")

    plt.title("Evolution of Power-Law Exponent (Slope) over Training")
    plt.xlabel("Training Epoch")
    plt.ylabel(r"Slope $\alpha$")
    plt.ylim([0.5, 2.5])
    plt.legend(loc='center right', bbox_to_anchor=(1.15, 0.5), frameon=True)
    plt.grid(True, linestyle='--', alpha=0.4)
    plt.tight_layout()

    plt.show()



# Script entry for this cell: slope-vs-epoch curves followed by the scatter
# grid for one target epoch (CNN / cifar10 / 'cse' configuration).
if __name__ == "__main__":
    # NOTE(review): apply_icml_style is not defined in the visible part of this
    # file -- presumably it comes from an earlier cell; confirm before running.
    apply_icml_style()

    # Epochs whose saved data feed the slope-vs-epoch curves.
    epoch_list = [1, 10, 30, 50, 70, 80, 90, 100]
    # Indices 0..1499, forwarded as slice_range to the analysis and plot calls.
    slice_range = torch.arange(0, 1500)
    # Only referenced by the commented-out save_dir variant below.
    Net_size =50
    train_size = 2000
    # 'B' and 'alpha' appear in the data directory name as B{..}lr{..};
    # 'lss_fn', 'model' and 'dataset' appear there too.
    config = {'B': 128, 'alpha': 0.1, 'lss_fn': 'cse', 'model': 'CNN', 'dataset': 'cifar10'}
    n_class = 10
    # save_dir = f"./AWCH_data/NS{Net_size}_TrainSize{train_size}_SampleN{20}_ClassN{n_class}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"
    save_dir = f"./AWCH_data/TrainSize{train_size}_SampleN{20}_ClassN{n_class}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"

    # Both prefixes are identical here, so the loader is asked for the same
    # per-epoch file twice.
    base_path1 = os.path.join(save_dir, "C_epoch_")
    base_path2 = os.path.join(save_dir, "C_epoch_") 


    # Curve label -> (key_a, key_b) pair of entries in the loaded data dict.
    var_pairs_for_evolution = {
        "C vs Hessian": ("C", "Hessian"),
        "C1 vs Hessian": ("C1", "Hessian"),
        "Covar vs Hessian": ("Covar", "Hessian"),
        "H_2_d vs Hessian": ("H_2_d", "Hessian"),
    }
    
    results = analyze_epochs(epoch_list, base_path1, base_path2, slice_range, var_pairs_for_evolution)
    plot_slope_vs_epoch_multi(results, config)

    # (key_a, key_b) pairs, one scatter panel each, in grid order.
    pairs_to_plot_grid = [
        ("C", "Hessian"),
        ("C1", "Hessian"),
        ("C1_dia", "Hessian"),
        ("C1_dia_w_dia", "Hessian"),
        ("H_2_d", "Hessian"),
        ("Covar", "Hessian"),
        ("H_2_d", "Covar"),
        ("C", "Covar"),
        ("C", "C1"),
        ("C1", "C1_dia"),
        ("C1_dia", "C1_dia_w_dia"),
        ("H_2_d", "C1_dia_w_dia")
    ]
    
    # Epoch whose data is drawn in the scatter grid.
    target_epoch = 100
    plot_all_scatters_in_grid(target_epoch, base_path1, base_path2, slice_range, pairs_to_plot_grid, config)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
import matplotlib.cm as cm
import math


def set_publication_style():
    """Configure Matplotlib to match ICML/NeurIPS paper standards."""
    # Serif text with STIX math.
    typeface = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
    }
    # Font sizes for body text, axis labels, titles, ticks and legend.
    text_sizes = {
        'font.size': 20,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 12,
    }
    # Resolution for on-screen and saved figures.
    resolution = {
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    # Line widths for axes frame, plotted lines and major ticks.
    stroke_widths = {
        'axes.linewidth': 2,
        'lines.linewidth': 2,
        'xtick.major.width': 2,
        'ytick.major.width': 2,
    }
    # Ticks point into the plotting area.
    tick_directions = {
        'xtick.direction': 'in',
        'ytick.direction': 'in',
    }
    plt.rcParams.update({
        **typeface,
        **text_sizes,
        **resolution,
        **stroke_widths,
        **tick_directions,
    })
# Apply the publication rcParams as soon as this cell runs, so every figure
# created afterwards picks them up.
set_publication_style()


def clean_tex_label(text):
    r"""Return ``text`` with every underscore replaced by ``\_``."""
    return r"\_".join(text.split("_"))

def compute_loglog_fit_simple(x, y):
    """Fit a least-squares line ``y = slope * x + intercept``.

    The inputs are regressed as given; no logarithm is taken here.

    Returns:
        Tuple ``(slope, intercept, r_squared)``.
    """
    fit = stats.linregress(x, y)
    return fit.slope, fit.intercept, fit.rvalue ** 2

def extract_diagonals(A, B, slice_range=None):
    """Return selected diagonal entries of ``A`` and ``B`` in reversed order.

    Each matrix is flipped along both axes before its diagonal is taken, so
    index 0 of the result is the bottom-right entry of the input.

    Args:
        A: First 2-D tensor.
        B: Second 2-D tensor.
        slice_range: Indices applied to both flipped diagonals.  Defaults to
            ``0 .. min(A.shape[0], B.shape[0]) - 1``.

    Returns:
        Tuple ``(diag_a, diag_b)`` of 1-D tensors.
    """
    def flipped_diagonal(matrix):
        return torch.diag(torch.flip(matrix, dims=[0, 1]))

    if slice_range is not None:
        picked = slice_range
    else:
        picked = torch.arange(min(A.shape[0], B.shape[0]))
    return flipped_diagonal(A)[picked], flipped_diagonal(B)[picked]

def load_and_merge_data(epoch, base_path1, base_path2):
    """Load the two per-epoch ``.pt`` files and merge them into one dict.

    The files are ``f"{base_path1}{epoch}.pt"`` and
    ``f"{base_path2}{epoch}.pt"``.  On duplicate keys the entry from the
    second file wins.

    Args:
        epoch: Epoch identifier interpolated into both file names.
        base_path1: Path prefix of the first file.
        base_path2: Path prefix of the second file.

    Returns:
        The merged dict, or ``None`` when either file is missing or cannot be
        loaded and merged (the failure reason is printed in the latter case).
    """
    file_path1 = f"{base_path1}{epoch}.pt"
    file_path2 = f"{base_path2}{epoch}.pt"
    if not (os.path.exists(file_path1) and os.path.exists(file_path2)):
        return None
    try:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted checkpoints.
        d1 = torch.load(file_path1, map_location='cpu')
        # Both prefixes may be identical; do not read the same file twice.
        if file_path2 == file_path1:
            d2 = d1
        else:
            d2 = torch.load(file_path2, map_location='cpu')
        return {**d1, **d2}
    except Exception as exc:
        # Best effort: callers treat None as "no data", but the cause is
        # reported instead of being swallowed.
        print(f"Failed to load epoch {epoch} data from {file_path1} / {file_path2}: {exc}")
        return None



# Mathtext title fragment for each known data key.  One table serves both
# slots of a pair, so a key is labelled the same way whichever side it is on.
_MATRIX_TEX_LABELS = {
    "C": r"$\mathbf{C}_{AWD,raw}$",
    "Covar": r"$\mathbf{Covar}$",
    "H_2_d": r"$2\mathbf{C}/\sigma_w^2$",
    "C1": r"$\mathbf{C}^{hh}$",
    "C1_dia": r"$\mathbf{C}^{hh,SD}$",
    "C1_dia_w_dia": r"$\mathbf{C}^{hh,SD,WD}$",
    "H2": r"$2\mathbf{C}/\sigma_w^2$",
    "C2": r"$\mathbf{C}^{hg}$",
    "C3": r"$\mathbf{C}^{gg}$",
    "Hessian": r"$\mathbf{H}$",
}


def _matrix_tex_label(key):
    """Return the title fragment for ``key``; unknown keys become a bold, escaped key."""
    label = _MATRIX_TEX_LABELS.get(key)
    if label is None:
        label = rf"$\mathbf{{{clean_tex_label(key)}}}$"
    return label


def _plot_scatter_on_ax(ax, data_dict, hessian_eig, var_pair, slice_range):
    """Draw one centered log-log scatter panel for a pair of matrices.

    The two matrices are read from ``data_dict``.  Those whose key contains
    "covar" or "hessian" (case-insensitive) are rotated into the Hessian
    eigenbasis (``V^T M V``).  The diagonals returned by ``extract_diagonals``
    are restricted to entries that are strictly positive in both, converted to
    log10, mean-centered, scattered, and fitted with a straight line.
    Reference lines of slope 1 and 2 are drawn through the origin.

    Args:
        ax: Matplotlib axes to draw on.
        data_dict: Mapping from data key to matrix.
        hessian_eig: ``(eigenvalues, eigenvectors)`` pair; only the
            eigenvectors are used.
        var_pair: ``(key_a, key_b)``; ``key_a`` is plotted on the y axis and
            ``key_b`` on the x axis.
        slice_range: Index selection forwarded to ``extract_diagonals``.

    Returns:
        None.  The panel shows "Missing Data" when a key is absent and
        "Invalid Data" when fewer than five points are usable.
    """
    key_a, key_b = var_pair

    if key_a not in data_dict or key_b not in data_dict:
        ax.text(0.5, 0.5, "Missing Data", ha='center', va='center')
        return

    A = torch.tensor(data_dict[key_a]).float()
    B = torch.tensor(data_dict[key_b]).float()

    _, V = hessian_eig
    if "covar" in key_a.lower() or "hessian" in key_a.lower():
        A = V.transpose(-2, -1) @ A @ V
    if "covar" in key_b.lower() or "hessian" in key_b.lower():
        B = V.transpose(-2, -1) @ B @ V

    diag_a, diag_b = extract_diagonals(A, B, slice_range)

    da = diag_a.detach().cpu().numpy()
    db = diag_b.detach().cpu().numpy()
    # log10 is only defined for strictly positive entries.
    mask = (da > 0) & (db > 0)

    if mask.sum() < 5:
        ax.text(0.5, 0.5, "Invalid Data", ha='center', va='center')
        return

    raw_log_a = np.log10(da[mask])
    raw_log_b = np.log10(db[mask])

    # Center both clouds so every panel is comparable around the origin.
    x_data = raw_log_b - raw_log_b.mean()
    y_data = raw_log_a - raw_log_a.mean()

    slope, intercept, r2 = compute_loglog_fit_simple(x_data, y_data)

    ax.scatter(x_data, y_data, alpha=0.5, s=15, c='#1f77b4', edgecolors='none', rasterized=True)

    # Symmetric limits with 30% headroom around the largest excursion.
    max_val = max(np.max(np.abs(x_data)), np.max(np.abs(y_data)))
    limit = max_val * 1.3
    x_ref = np.array([-limit, limit])

    # Reference exponents.
    ax.plot(x_ref, 1.0 * x_ref, color='#d62728', linestyle='--', lw=2.5, alpha=0.9, label='Slope=1')
    ax.plot(x_ref, 2.0 * x_ref, color='#2ca02c', linestyle='-.', lw=2.5, alpha=0.9, label='Slope=2')

    # Fitted line.
    y_fit = slope * x_ref + intercept
    ax.plot(x_ref, y_fit, 'k-', lw=2.0, alpha=0.8, label='Fit')

    ax.set_aspect('equal')
    ax.set_xlim([-limit, limit])
    ax.set_ylim([-limit, limit])

    # A single lookup with a fallback replaces two hand-written if/elif
    # chains that left the label unbound for keys they did not list.
    aname = _matrix_tex_label(key_a)
    bname = _matrix_tex_label(key_b)
    title_str = f"{aname} vs {bname}"
    ax.set_title(title_str, fontsize=30, pad=10)

    info_text = f"$\\alpha = {slope:.2f}$\n$R^2 = {r2:.2f}$"
    ax.text(0.05, 0.95, info_text, transform=ax.transAxes,
            fontsize=30, verticalalignment='top', fontweight='bold',
            bbox=dict(boxstyle="square,pad=0.4", fc="white", ec="black", lw=1.5, alpha=0.9))

    # Full black frame around the panel.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2.0)
        spine.set_color('black')

    # Faint cross-hair through the origin.
    ax.axhline(0, color='gray', linewidth=1.0, alpha=0.4)
    ax.axvline(0, color='gray', linewidth=1.0, alpha=0.4)

    ax.tick_params(axis='both', which='major', labelsize=30, width=2, length=6, direction='in')


def plot_all_scatters_in_grid(epoch, base_path1, base_path2, slice_range, pair_list, config):
    """Draw every pair in ``pair_list`` as one panel of a 4-column grid and save it.

    The epoch's data is loaded with ``load_and_merge_data``, its ``"Hessian"``
    entry is eigendecomposed once (non-finite entries are first replaced via
    ``torch.nan_to_num``), and each pair is drawn by ``_plot_scatter_on_ax``.
    The figure is written to
    ``ICML_Figures/{model}_{dataset}_{lss_fn}/Appendix_Grid_Boxed.pdf`` and
    then shown.

    Args:
        epoch: Epoch identifier forwarded to ``load_and_merge_data``.
        base_path1: First path prefix forwarded to ``load_and_merge_data``.
        base_path2: Second path prefix forwarded to ``load_and_merge_data``.
        slice_range: Index selection forwarded to ``_plot_scatter_on_ax``.
        pair_list: Sequence of ``(key_a, key_b)`` pairs, one panel each.
        config: Dict providing ``'model'``, ``'dataset'`` and ``'lss_fn'`` for
            the output directory name.

    Returns:
        None.  A message is printed and nothing is drawn when ``pair_list`` is
        empty, the data cannot be loaded, or it has no ``"Hessian"`` entry.
    """
    print(f"Processing Epoch {epoch} for Centered Grid Plot...")

    if not pair_list:
        # Zero pairs would mean zero subplot rows, which cannot be created.
        print("No variable pairs to plot.")
        return

    data = load_and_merge_data(epoch, base_path1, base_path2)
    if data is None:
        print("Data not found.")
        return
    if "Hessian" not in data:
        print("Hessian not found in data.")
        return

    H = torch.tensor(data["Hessian"]).float()
    if not torch.isfinite(H).all():
        H = torch.nan_to_num(H)
    hessian_eig = torch.linalg.eigh(H)

    n_plots = len(pair_list)
    cols = 4
    rows = math.ceil(n_plots / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(24, 6 * rows), constrained_layout=True)
    axes = axes.flatten()

    for i, pair in enumerate(pair_list):
        _plot_scatter_on_ax(axes[i], data, hessian_eig, pair, slice_range)

        row_idx, col_idx = divmod(i, cols)

        # Axis labels only on the bottom row and the left column.
        if row_idx == rows - 1:
            axes[i].set_xlabel(r"Centered $\log_{10}(X_{ii})$", fontsize=30)

        if col_idx == 0:
            axes[i].set_ylabel(r"Centered $\log_{10}(Y_{ii})$", fontsize=30)

    # Hide the unused cells of the last row.
    for spare_ax in axes[n_plots:]:
        spare_ax.axis('off')

    # Take the legend entries from the first panel that actually drew lines;
    # a "Missing Data"/"Invalid Data" panel has none.
    handles, labels = [], []
    for used_ax in axes[:n_plots]:
        handles, labels = used_ax.get_legend_handles_labels()
        if handles:
            break
    by_label = dict(zip(labels, handles))

    fig.legend(by_label.values(), by_label.keys(),
               loc='upper center',
               bbox_to_anchor=(0.5, 1.05),
               ncol=4,
               frameon=True,
               edgecolor='black',
               fancybox=False,
               fontsize=18,
               borderpad=0.8)

    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)
    filename = "Appendix_Grid_Boxed.pdf"
    save_path = os.path.join(save_dir, filename)

    print(f"Saving high-res boxed figure to {save_path}...")
    fig.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.1,
        dpi=300
    )

    plt.show()



# Script entry for this cell: scatter grid at epoch 100 for the
# FC / mnist / 'mse' configuration.
if __name__ == "__main__":
    # Not used below; the grid call passes the epoch (100) directly.
    epoch_list = [100] 
    # Indices 0..1499, forwarded as slice_range to the grid plot.
    slice_range = torch.arange(0, 1500)
    

    train_size = 2000
    # 'B' and 'alpha' appear in the data directory name as B{..}lr{..};
    # 'lss_fn', 'model' and 'dataset' also name the figure output directory.
    config = {'B': 50, 'alpha': 0.1, 'lss_fn': 'mse', 'model': 'FC', 'dataset': 'mnist'}
    n_class = 10

    save_dir = f"./AWCH_data/TrainSize{train_size}_SampleN{20}_ClassN{n_class}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"
    # Both prefixes are identical here, so the loader is asked for the same
    # per-epoch file twice.
    base_path1 = os.path.join(save_dir, "C_epoch_")
    base_path2 = os.path.join(save_dir, "C_epoch_") 

    # (key_a, key_b) pairs, one scatter panel each, in grid order.
    pairs_to_plot_grid = [
        ("C", "Hessian"),
        ("C1", "Hessian"),
        ("C1_dia", "Hessian"),
        ("C1_dia_w_dia", "Hessian"),
        ("H_2_d", "Hessian"),
        ("Covar", "Hessian"),
        ("H_2_d", "Covar"),
        ("C", "Covar"),
        ("C", "C1"),
        ("C1", "C1_dia"),
        ("C1_dia", "C1_dia_w_dia"),
        ("H_2_d", "C1_dia_w_dia")
    ]
    
    plot_all_scatters_in_grid(100, base_path1, base_path2, slice_range, pairs_to_plot_grid, config)

In [None]:
from pdb import run
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
import matplotlib.cm as cm
import math

def apply_icml_appendix_style():
    """Apply the rcParams used for the appendix grid figures."""
    overrides = {}

    # Serif text with STIX math.
    overrides['font.family'] = 'serif'
    overrides['font.serif'] = ['Times New Roman', 'Times', 'DejaVu Serif']
    overrides['mathtext.fontset'] = 'stix'

    # Font sizes for body text, axis labels, titles, ticks and legend.
    overrides['font.size'] = 16
    overrides['axes.labelsize'] = 18
    overrides['axes.titlesize'] = 18
    overrides['xtick.labelsize'] = 14
    overrides['ytick.labelsize'] = 14
    overrides['legend.fontsize'] = 14

    # Stroke widths for plotted lines and the axes frame.
    overrides['lines.linewidth'] = 2.0
    overrides['axes.linewidth'] = 1.5

    # Figure resolution and bounding box used when saving.
    overrides['figure.dpi'] = 300
    overrides['savefig.bbox'] = 'tight'

    plt.rcParams.update(overrides)


def clean_tex_label(text):
    r"""Escape each underscore in ``text`` as ``\_`` and return the result."""
    underscore_escape = {ord("_"): r"\_"}
    return text.translate(underscore_escape)

def compute_loglog_fit_simple(x, y):
    """Return ``(slope, intercept, r_squared)`` of a least-squares line through (x, y).

    The inputs are regressed as given; no logarithm is taken here.
    """
    # The first three fields of the regression result are slope, intercept
    # and the correlation coefficient.
    slope, intercept, r_value = stats.linregress(x, y)[:3]
    r_squared = r_value**2
    return slope, intercept, r_squared

def extract_diagonals(A, B, slice_range=None):
    """Pick entries from the diagonals of ``A`` and ``B`` after flipping both axes.

    Because each matrix is flipped along dims 0 and 1 first, position 0 of a
    returned diagonal is the bottom-right entry of the corresponding input.

    Args:
        A: First 2-D tensor.
        B: Second 2-D tensor.
        slice_range: Indices applied to both flipped diagonals; when ``None``
            the range ``0 .. min(A.shape[0], B.shape[0]) - 1`` is used.

    Returns:
        Tuple ``(diag_a, diag_b)`` of 1-D tensors.
    """
    if slice_range is None:
        n_common = min(A.shape[0], B.shape[0])
        slice_range = torch.arange(n_common)
    flipped_diags = [torch.diag(torch.flip(mat, dims=[0, 1])) for mat in (A, B)]
    diag_a_val, diag_b_val = (diag[slice_range] for diag in flipped_diags)
    return diag_a_val, diag_b_val

def load_and_merge_data(epoch, base_path1, base_path2, run_idx=0):
    """Load the two per-epoch, per-run ``.pt`` files and merge them into one dict.

    The files are ``f"{base_path1}{epoch}_run_{run_idx}.pt"`` and
    ``f"{base_path2}{epoch}_run_{run_idx}.pt"``.  On duplicate keys the entry
    from the second file wins.

    Args:
        epoch: Epoch identifier interpolated into both file names.
        base_path1: Path prefix of the first file.
        base_path2: Path prefix of the second file.
        run_idx: Run index interpolated into both file names.

    Returns:
        The merged dict, or ``None`` when either file is missing or cannot be
        loaded and merged (the failure reason is printed in the latter case).
    """
    file_path1 = f"{base_path1}{epoch}_run_{run_idx}.pt"
    file_path2 = f"{base_path2}{epoch}_run_{run_idx}.pt"
    if not (os.path.exists(file_path1) and os.path.exists(file_path2)):
        return None
    try:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted checkpoints.
        d1 = torch.load(file_path1, map_location='cpu')
        # Both prefixes may be identical; do not read the same file twice.
        if file_path2 == file_path1:
            d2 = d1
        else:
            d2 = torch.load(file_path2, map_location='cpu')
        return {**d1, **d2}
    except Exception as exc:
        # Best effort: callers treat None as "no data", but the cause is
        # reported instead of being swallowed.
        print(f"Failed to load epoch {epoch} run {run_idx} data from {file_path1} / {file_path2}: {exc}")
        return None



def _plot_scatter_on_ax(ax, data_dict, hessian_eig, var_pair, slice_range):
    """Render one centered log-log scatter panel for a pair of matrices.

    Matrices whose key contains "covar" or "hessian" (case-insensitive) are
    rotated into the Hessian eigenbasis (``V^T M V``).  The diagonals returned
    by ``extract_diagonals`` are limited to entries strictly positive in both,
    converted to log10, mean-centered, scattered and fitted with a line.

    Args:
        ax: Matplotlib axes to draw on.
        data_dict: Mapping from data key to matrix.
        hessian_eig: ``(eigenvalues, eigenvectors)`` pair; only the
            eigenvectors are used.
        var_pair: ``(key_a, key_b)``; ``key_a`` goes on the y axis and
            ``key_b`` on the x axis.
        slice_range: Index selection forwarded to ``extract_diagonals``.

    Returns:
        None.  The panel shows "Missing Data" when a key is absent and
        "Invalid Data" when fewer than five points are usable.
    """
    key_a, key_b = var_pair

    if not (key_a in data_dict and key_b in data_dict):
        ax.text(0.5, 0.5, "Missing Data", ha='center', va='center')
        return

    mat_a = torch.tensor(data_dict[key_a]).float()
    mat_b = torch.tensor(data_dict[key_b]).float()

    _eigvals, eigvecs = hessian_eig

    def in_hessian_basis(key, matrix):
        # Only covariance- and Hessian-like matrices are rotated.
        lowered = key.lower()
        if "covar" in lowered or "hessian" in lowered:
            return eigvecs.transpose(-2, -1) @ matrix @ eigvecs
        return matrix

    mat_a = in_hessian_basis(key_a, mat_a)
    mat_b = in_hessian_basis(key_b, mat_b)

    diag_a, diag_b = extract_diagonals(mat_a, mat_b, slice_range)
    vals_a = diag_a.detach().cpu().numpy()
    vals_b = diag_b.detach().cpu().numpy()

    # log10 needs strictly positive entries on both sides.
    usable = (vals_a > 0) & (vals_b > 0)
    if np.count_nonzero(usable) < 5:
        ax.text(0.5, 0.5, "Invalid Data", ha='center', va='center')
        return

    log_a = np.log10(vals_a[usable])
    log_b = np.log10(vals_b[usable])

    # Center both clouds on the origin.
    x_data = log_b - log_b.mean()
    y_data = log_a - log_a.mean()

    slope, intercept, r2 = compute_loglog_fit_simple(x_data, y_data)

    ax.scatter(x_data, y_data, alpha=0.5, s=15, c='#1f77b4', edgecolors='none', rasterized=True)

    # Symmetric limits with 30% headroom around the largest excursion.
    limit = max(np.max(np.abs(x_data)), np.max(np.abs(y_data))) * 1.3
    x_ref = np.array([-limit, limit])

    # Reference exponents through the origin.
    reference_lines = (
        (1.0, '#d62728', '--', 'Slope=1'),
        (2.0, '#2ca02c', '-.', 'Slope=2'),
    )
    for ref_slope, ref_color, ref_style, ref_label in reference_lines:
        ax.plot(x_ref, ref_slope * x_ref, color=ref_color, linestyle=ref_style,
                lw=2.5, alpha=0.9, label=ref_label)

    # Fitted line.
    ax.plot(x_ref, slope * x_ref + intercept, 'k-', lw=2.0, alpha=0.8, label='Fit')

    ax.set_aspect('equal')
    ax.set_xlim([-limit, limit])
    ax.set_ylim([-limit, limit])

    # Title: both keys in bold math with underscores escaped.
    title_str = " vs ".join(
        rf"$\mathbf{{{clean_tex_label(key)}}}$" for key in (key_a, key_b)
    )
    ax.set_title(title_str, fontsize=16, pad=10)

    info_text = "\n".join((f"$\\alpha = {slope:.2f}$", f"$R^2 = {r2:.2f}$"))
    ax.text(0.05, 0.95, info_text, transform=ax.transAxes,
            fontsize=14, verticalalignment='top', fontweight='bold',
            bbox=dict(boxstyle="square,pad=0.4", fc="white", ec="black", lw=1.5, alpha=0.9))

    # Full black frame around the panel.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2.0)
        spine.set_color('black')

    # Faint cross-hair through the origin.
    ax.axhline(0, color='gray', linewidth=1.0, alpha=0.4)
    ax.axvline(0, color='gray', linewidth=1.0, alpha=0.4)

    ax.tick_params(axis='both', which='major', labelsize=12, width=1.5, length=5, direction='in')


def plot_all_scatters_in_grid(epoch, base_path1, base_path2, slice_range, pair_list, config, run_idx=0):
    """Draw every pair in ``pair_list`` as one panel of a 4-column grid and show it.

    The appendix rcParams are applied first.  The data for ``epoch`` and
    ``run_idx`` is loaded with ``load_and_merge_data``, its ``"Hessian"``
    entry is eigendecomposed once (non-finite entries are first replaced via
    ``torch.nan_to_num``), and each pair is drawn by ``_plot_scatter_on_ax``.
    The figure is shown but not written to disk.

    Args:
        epoch: Epoch identifier forwarded to ``load_and_merge_data``.
        base_path1: First path prefix forwarded to ``load_and_merge_data``.
        base_path2: Second path prefix forwarded to ``load_and_merge_data``.
        slice_range: Index selection forwarded to ``_plot_scatter_on_ax``.
        pair_list: Sequence of ``(key_a, key_b)`` pairs, one panel each.
        config: Not used by this function; kept so existing calls still work.
        run_idx: Run index forwarded to ``load_and_merge_data``.

    Returns:
        None.  A message is printed and nothing is drawn when ``pair_list`` is
        empty, the data cannot be loaded, or it has no ``"Hessian"`` entry.
    """
    apply_icml_appendix_style()

    print(f"Processing Epoch {epoch} for Centered Grid Plot...")

    if not pair_list:
        # Zero pairs would mean zero subplot rows, which cannot be created.
        print("No variable pairs to plot.")
        return

    data = load_and_merge_data(epoch, base_path1, base_path2, run_idx=run_idx)
    if data is None:
        print("Data not found.")
        return
    if "Hessian" not in data:
        print("Hessian not found in data.")
        return

    H = torch.tensor(data["Hessian"]).float()
    if not torch.isfinite(H).all():
        H = torch.nan_to_num(H)
    hessian_eig = torch.linalg.eigh(H)

    n_plots = len(pair_list)
    cols = 4
    rows = math.ceil(n_plots / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(24, 6 * rows), constrained_layout=True)
    axes = axes.flatten()

    for i, pair in enumerate(pair_list):
        _plot_scatter_on_ax(axes[i], data, hessian_eig, pair, slice_range)

        row_idx, col_idx = divmod(i, cols)

        # Axis labels only on the bottom row and the left column.
        if row_idx == rows - 1:
            axes[i].set_xlabel(r"Centered $\log_{10}(X_{ii})$", fontsize=16)

        if col_idx == 0:
            axes[i].set_ylabel(r"Centered $\log_{10}(Y_{ii})$", fontsize=16)

    # Hide the unused cells of the last row.
    for spare_ax in axes[n_plots:]:
        spare_ax.axis('off')

    # Take the legend entries from the first panel that actually drew lines;
    # a "Missing Data"/"Invalid Data" panel has none.
    handles, labels = [], []
    for used_ax in axes[:n_plots]:
        handles, labels = used_ax.get_legend_handles_labels()
        if handles:
            break
    by_label = dict(zip(labels, handles))

    fig.legend(by_label.values(), by_label.keys(),
               loc='upper center',
               bbox_to_anchor=(0.5, 1.02),
               ncol=4,
               frameon=True,
               edgecolor='black',
               fancybox=False,
               fontsize=18,
               borderpad=0.8)

    fig.suptitle(f"Appendix: Centered Log-Log Analysis (Epoch {epoch})", fontsize=24, fontweight='bold', y=1.05)

    plt.show()


# Script entry for this cell: scatter grid for epoch 150 of run 4
# (MLP / cifar10 / 'cse' configuration); the figure is only displayed.
if __name__ == "__main__":

    # Indices 0..999, forwarded as slice_range to the grid plot.
    slice_range = torch.arange(0, 1000)
    net_size = 1000
    n_class = 10
    train_size = 2000
    # 'B' and 'alpha' appear in the data directory name as B{..}lr{..};
    # 'lss_fn', 'model' and 'dataset' appear there too.
    config = {'B':100, 'alpha': 0.1, 'lss_fn': 'cse', 'model': 'MLP', 'dataset': 'cifar10'}
    # Selects the ..._run_{run_idx}.pt files in the loader.
    run_idx = 4
    sample_number = 20

    # Data directory; unlike the other cells its name includes the NS{net_size} prefix.
    save_dir = f"./AWCH_data/NS{net_size}_TrainSize{train_size}_SampleN{sample_number}_ClassN{n_class}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}"

    # Both prefixes are identical here, so the loader is asked for the same
    # per-epoch file twice.
    base_path1 = os.path.join(save_dir, "C_epoch_")
    base_path2 = os.path.join(save_dir, "C_epoch_") 

    # (key_a, key_b) pairs, one scatter panel each, in grid order.
    pairs_to_plot_grid = [
        ("C", "Hessian"),
        ("C1", "Hessian"),
        ("C1_dia", "Hessian"),
        ("C1_dia_w_dia", "Hessian"),
        ("H_2_d", "Hessian"),
        ("Covar", "Hessian"),
        ("H_2_d", "Covar"),
        ("C", "Covar"),
        ("C", "C1"),
        ("C1", "C1_dia"),
        ("C1_dia", "C1_dia_w_dia"),
        ("H_2_d", "C1_dia_w_dia")
    ]
    
    plot_all_scatters_in_grid(150, base_path1, base_path2, slice_range, pairs_to_plot_grid, config, run_idx)