In [None]:
import model_config
import torch
import os

# --- Experiment configuration ------------------------------------------------
max_e = 150            # maximum training epoch passed to set_config (not a task index)
sample_number = 20     # number of samples used to compute AWCH
train_size = 2000      # number of training samples per class

config = model_config.set_config('none', test_size=1000, train_size=train_size, max_epoch=max_e)
config['add_regulization'] = False
config['beta'] = 10
config['lss_fn'] = 'cse'
config['dataset'] = 'cifar10'   # alternative used previously: 'mnist'
config['model'] = 'MLP'         # alternatives used previously: 'CNN', 'FC'
config['sample_holder'] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # the samples used to compute AWCH
config['class_number'] = 10
config['B'] = 100      # previously 30
config['alpha'] = 0.1  # previously 0.5
net_size = 1000

# Convenience aliases of config entries (kept at module level for later cells).
loss_fn = config['lss_fn']
sample_num = config['train_size']
sample_holder = config['sample_holder']
mis_label_prob = config['rho']  # presumably populated by set_config — TODO confirm

batch_size = config['B']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# --- Load the saved tensors ----------------------------------------------------
# Only the "NS{net_size}_" directory layout is used.  An earlier assignment built
# the same path without the "NS{net_size}_" prefix but was overwritten before
# use, so that dead store has been removed.
load_path = f"./AWCH_data/NS{net_size}_TrainSize{train_size}_SampleN{sample_number}_ClassN{len(config['sample_holder'])}_B{config['B']}lr{config['alpha']}_lossfn_{config['lss_fn']}_model_{config['model']}_dataset_{config['dataset']}/h_g_tensors_holder_epoch_{config['max_epoch']}.pt"
print(f"Loading data from: {load_path}")
# NOTE(security): weights_only=False unpickles arbitrary Python objects; only
# load files that were produced by this project.
loaded_data = torch.load(load_path, weights_only=False)
h_holder = loaded_data['h_holder']
g_holder = loaded_data['g_holder']
components = loaded_data['components']
Covar = loaded_data['Covar']
Hessian = loaded_data['matrix']
del loaded_data  # drop the container; the extracted entries stay referenced

# Project every stored matrix with `components`: h -> components @ h @ components.T
components = torch.tensor(components).to(device)
h_holder = torch.tensor(h_holder).to(device)
h_holder = components @ h_holder @ components.T


In [None]:
import numpy as np
import matplotlib.pyplot as plt



# Put CJK-capable fonts first in the sans-serif fallback list (global Matplotlib setting).
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei'] 

# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False 


def analyze_eigen_structure(h_data, metric_mode='ipr', smoothing_window=50, num_classes=10):
    """Plot eigenvalue spectra and an eigenvector-structure metric, one sample per class.

    For each class the first matrix of its group is eigendecomposed with
    ``np.linalg.eigh``.  The left panel shows ``|eigenvalue|`` against rank
    (log-log); the right panel shows a per-eigenvector metric against rank,
    as faint raw points plus a moving average.

    Args:
        h_data: Indexable collection of square symmetric matrices (torch
            tensors or array-likes).  Assumes the samples are ordered class
            by class in equal-sized groups — TODO confirm against the caller.
        metric_mode: ``'max_weight'`` (largest squared component),
            ``'pr_normalized'`` (participation ratio divided by the matrix
            dimension) or ``'ipr'`` (sum of fourth powers of the components).
        smoothing_window: Length of the moving-average window; the smoothed
            curve is only drawn when there are more eigenvectors than this.
        num_classes: Number of class groups to visit (default 10).

    Raises:
        ValueError: If ``metric_mode`` is unknown, ``h_data`` is empty or
            ``num_classes`` is not positive.
    """
    # Resolve the metric before touching Matplotlib: an invalid mode then fails
    # fast, and the labels exist even if every sample ends up being skipped.
    metric_labels = {
        'max_weight': (r'Max Squared Component ($\max v_{ki}^2$)', 'High Value = Localized'),
        'pr_normalized': (r'Normalized Participation Ratio ($PR / N$)', 'High Value (~1) = Delocalized'),
        'ipr': (r'Inverse Participation Ratio ($\sum v_{ki}^4$)', 'High Value = Localized'),
    }
    if metric_mode not in metric_labels:
        raise ValueError("metric_mode must be 'max_weight', 'pr_normalized', or 'ipr'")
    ylabel, title_suffix = metric_labels[metric_mode]

    total_samples = len(h_data)
    if total_samples == 0:
        raise ValueError("h_data must contain at least one matrix")
    if num_classes <= 0:
        raise ValueError("num_classes must be a positive integer")

    palette = plt.cm.tab10(np.linspace(0, 1, 10))
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    ax_spec, ax_struct = axes

    # At least 1, so that fewer samples than classes plots each sample once
    # instead of plotting sample 0 repeatedly.
    samples_per_class = max(1, total_samples // num_classes)

    for class_idx in range(num_classes):
        global_idx = class_idx * samples_per_class
        if global_idx >= total_samples:
            break
        color = palette[class_idx % len(palette)]

        H = h_data[global_idx]
        if isinstance(H, torch.Tensor):
            H = H.detach().cpu().numpy()
        H = np.asarray(H)

        if not np.isfinite(H).all():
            # NaNs become 0; infinities become the largest finite values (numpy default).
            H = np.nan_to_num(H, nan=0.0)

        try:
            w, v = np.linalg.eigh(H)
        except np.linalg.LinAlgError:
            print(f"Skipping class {class_idx}: eigendecomposition did not converge")
            continue

        # eigh returns ascending eigenvalues; reorder to descending.
        order = np.argsort(w)[::-1]
        w_desc = w[order]
        v_desc = v[:, order]  # columns are eigenvectors, shape (dim, rank)
        ranks = np.arange(1, len(w_desc) + 1)

        V_squared = v_desc ** 2
        if metric_mode == 'max_weight':
            metric_vals = np.max(V_squared, axis=0)
        elif metric_mode == 'pr_normalized':
            sum_sq = np.sum(V_squared, axis=0)
            sum_quad = np.sum(V_squared ** 2, axis=0)
            metric_vals = ((sum_sq ** 2) / sum_quad) / H.shape[0]
        else:  # 'ipr'
            metric_vals = np.sum(v_desc ** 4, axis=0)

        # Task 1: Spectrum
        ax_spec.loglog(ranks, np.abs(w_desc), color=color, alpha=0.5, linewidth=1.5)

        # Task 2: Structure (raw points, then the smoothed curve carrying the legend label)
        ax_struct.loglog(ranks, metric_vals, color=color,
                         alpha=0.15, marker='.', markersize=1, linestyle='None')

        if 0 < smoothing_window < len(metric_vals):
            kernel = np.ones(smoothing_window) / smoothing_window
            ma = np.convolve(metric_vals, kernel, mode='valid')
            ma_ranks = ranks[smoothing_window - 1:]
            ax_struct.loglog(ma_ranks, ma, color=color, alpha=0.8, linewidth=1.5,
                             label=f'Class {class_idx}')

    ax_spec.set_ylim(1e-20, 1e2)
    ax_spec.set_title('Task 1: Eigenvalue Spectrum', fontsize=14)
    ax_spec.set_xlabel('Rank Index $i$', fontsize=12)
    ax_spec.set_ylabel(r'Eigenvalue $|\lambda_i|$', fontsize=12)
    ax_spec.grid(True, which="both", alpha=0.3)

    ax_struct.set_title(f'Task 2: Eigenvector Structure ({metric_mode.upper()})\n{title_suffix}', fontsize=14)
    ax_struct.set_xlabel('Rank Index $i$', fontsize=12)
    ax_struct.set_ylabel(ylabel, fontsize=12)
    ax_struct.grid(True, which="both", alpha=0.3)

    # Only build a legend when a smoothed (labelled) curve was actually drawn.
    handles, _ = ax_struct.get_legend_handles_labels()
    if handles:
        ax_struct.legend(loc='best', fontsize=8, ncol=2)

    plt.tight_layout()
    plt.show()


# Run the analysis on `h_holder` (built in the loading cell) with the IPR metric.
analyze_eigen_structure(h_holder, metric_mode='ipr')

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress

# Set device
# Use CUDA when available, otherwise the CPU.  Helpers defined later in this cell
# (e.g. get_top_k_decomposition) read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def set_icml_style():
    """Apply this cell's global Matplotlib defaults: serif fonts, grid on, 300 dpi saves."""
    text_settings = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'DejaVu Serif', 'serif'],
        'font.size': 14,
        'axes.labelsize': 15,
        'axes.titlesize': 15,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
    }
    line_settings = {
        'lines.linewidth': 1.5,
        'lines.markersize': 2,
    }
    figure_settings = {
        'grid.alpha': 0.3,
        'axes.grid': True,
        'savefig.dpi': 300,
        'figure.autolayout': True,
    }
    # Applied group by group, in the same key order as one combined update.
    for group in (text_settings, line_settings, figure_settings):
        plt.rcParams.update(group)
def get_top_k_decomposition(h_holder, k=800):
    """Return the ``k`` largest eigenpairs of every symmetric matrix in a batch.

    Args:
        h_holder: Tensor of shape (B, N, N); moved to the module-level ``device``.
        k: Number of leading modes to keep; clipped to N.

    Returns:
        ``(vals_k, vecs_k)`` with shapes (B, k') and (B, N, k'), where
        ``k' = min(k, N)``.  Eigenvalues are in descending order and column
        ``j`` of ``vecs_k`` is the eigenvector of ``vals_k[:, j]``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    print(f"1. Performing Eigen Decomposition (Full Rank)...")
    h_holder = h_holder.to(device)
    # eigh returns eigenvalues in ascending order, with eigenvectors as columns.
    vals, vecs = torch.linalg.eigh(h_holder)
    N = vals.shape[1]
    actual_k = min(k, N)
    print(f"2. Filtering Top {actual_k} modes...")
    # Slice the top-k (last) columns first and flip only those.  torch.flip
    # copies its input, so the results no longer alias the full (B, N, N)
    # tensor and `del` below really releases it.
    start = N - actual_k
    vals_k = torch.flip(vals[:, start:], dims=[1])
    vecs_k = torch.flip(vecs[:, :, start:], dims=[2])
    del vals, vecs
    torch.cuda.empty_cache()
    return vals_k, vecs_k

import torch
import numpy as np

def compute_H_C_only(eigenvalues, eigenvectors, chunk_size=16):
    """Sample-averaged row sums ``sum_k v_ik^2 * lam_k`` and ``sum_k v_ik^2 * lam_k^2``.

    Args:
        eigenvalues: Tensor of shape (S, K).
        eigenvectors: Tensor of shape (S, N, K), eigenvectors as columns.  The
            row axis is reversed before use, so entry ``i`` of each result
            comes from original row ``N - 1 - i``.
        chunk_size: Number of samples processed per step (bounds peak memory).

    Returns:
        ``(H_mean, C_mean)``: two float32 NumPy arrays of length N, averaged
        over the S samples.
    """
    n_samples = eigenvalues.shape[0]
    n_rows = eigenvectors.shape[1]
    h_total = np.zeros(n_rows, dtype=np.float32)
    c_total = np.zeros(n_rows, dtype=np.float32)

    with torch.no_grad():
        for start in range(0, n_samples, chunk_size):
            stop = min(start + chunk_size, n_samples)

            lam = eigenvalues[start:stop]                                     # (chunk, K)
            weights = torch.flip(eigenvectors[start:stop], dims=[1]) ** 2     # (chunk, N, K)

            # Batched matrix-vector products give one value per row and sample.
            h_rows = torch.matmul(weights, lam.unsqueeze(-1)).squeeze(-1)
            c_rows = torch.matmul(weights, (lam ** 2).unsqueeze(-1)).squeeze(-1)

            h_total += torch.sum(h_rows, dim=0).cpu().numpy()
            c_total += torch.sum(c_rows, dim=0).cpu().numpy()

    return h_total / n_samples, c_total / n_samples


def compute_full_gap_matrix_efficient(eigenvalues, eigenvectors, batch_size=32):
    """Sample-averaged cumulative share of ``lam_k^2 * v_ik^2`` held by the first modes.

    For every sample and row ``i`` the function computes
    ``cumsum_k(lam_k^2 * v_ik^2) / sum_k(lam_k^2 * v_ik^2)`` and then averages
    over samples.  The row axis of ``eigenvectors`` is reversed before use.

    Args:
        eigenvalues: Tensor of shape (B, K).
        eigenvectors: Tensor of shape (B, N, K), on the same device.
        batch_size: Number of samples processed per step.

    Returns:
        NumPy array of shape (N, K).  A row whose total is zero produces NaN
        (0/0), as before.
    """
    print(f"Computing Gap Matrix (Batch Size={batch_size})...")
    B, N, num_modes = eigenvectors.shape
    # Accumulate on the inputs' own device rather than a module-level global,
    # so CPU and GPU inputs both work without a device mismatch.
    work_device = eigenvectors.device
    total_gap_sum = torch.zeros((N, num_modes), device=work_device)
    with torch.no_grad():
        for b_start in range(0, B, batch_size):
            b_end = min(b_start + batch_size, B)
            # Reverse the row axis per chunk; flipping the whole tensor up
            # front would allocate a second full-size copy.
            vec_chunk = torch.flip(eigenvectors[b_start:b_end], dims=[1])
            val_chunk = eigenvalues[b_start:b_end]
            lam_sq = (val_chunk ** 2).unsqueeze(1)           # (chunk, 1, K)
            V_sq = vec_chunk ** 2                            # (chunk, N, K)
            numerator = torch.cumsum(lam_sq * V_sq, dim=2)
            denominator = numerator[:, :, -1].unsqueeze(2)   # total over all modes
            gap_chunk = numerator / denominator

            total_gap_sum += torch.sum(gap_chunk, dim=0)
            del vec_chunk, val_chunk, lam_sq, V_sq, numerator, denominator, gap_chunk
            if work_device.type == 'cuda':
                torch.cuda.empty_cache()

    avg_gap_matrix = total_gap_sum / B
    return avg_gap_matrix.cpu().numpy()





def generate_artificial_spectra(N, mode='elbow', k=10 , alpha=1.5, k_target=800):
    """Build a synthetic spectrum of length ``k_target`` on the module-level ``device``.

    Args:
        N: Unused; kept so existing calls keep working.
        mode: ``'elbow'`` gives ``10 * i**-2`` with the first ``k`` entries
            overwritten by the constant 0.1; ``'smooth'`` gives ``i**-alpha``.
        k: Number of leading entries set to 0.1 in ``'elbow'`` mode.
        alpha: Power-law exponent used in ``'smooth'`` mode.
        k_target: Length of the returned spectrum (ranks 1..k_target).

    Returns:
        1-D float tensor of length ``k_target``.

    Raises:
        ValueError: If ``mode`` is neither ``'elbow'`` nor ``'smooth'``
            (this used to surface as an UnboundLocalError at ``return``).
    """
    if mode not in ('elbow', 'smooth'):
        raise ValueError(f"mode must be 'elbow' or 'smooth', got {mode!r}")
    indices = torch.arange(1, k_target + 1, device=device).float()
    if mode == 'elbow':
        spectrum = 10 * indices ** (-2)
        spectrum[:k] = 0.1
    else:  # 'smooth'
        spectrum = indices ** (-alpha) * 1.0
    return spectrum

def create_gap_from_real_spectrum(eigenvalues, idx = 10):
    """Return a copy of ``eigenvalues`` with columns ``idx`` onward scaled by 1e-5.

    The first ``idx`` columns are copied unchanged and the input is not modified.
    """
    tail = (slice(None), slice(idx, None))
    damped = eigenvalues.clone()
    damped[tail] = eigenvalues[tail] * 1e-5
    return damped


import torch
import numpy as np

def reshape_the_real_spectrum(eigenvalues, n_fit=10, slope_factor=1.0, noise_std=1.0):
    """Replace the leading ``n_fit`` eigenvalues of each row by a rescaled power law.

    A least-squares line is fitted per row to ``log(eigenvalue)`` against
    ``log(rank)`` over the first ``n_fit`` ranks.  Its slope is multiplied by
    ``slope_factor`` and the new line is anchored at the row's ``n_fit``-th
    eigenvalue.  When ``noise_std > 0`` one Gaussian offset per row (scaled
    by ``noise_std``) is added in log space.

    Args:
        eigenvalues: Tensor of shape (B, K); the fitted entries go through
            ``torch.log``, so non-positive values there yield NaN/-inf.
        n_fit: Number of leading ranks that are fitted and replaced.
        slope_factor: Multiplier applied to the fitted log-log slope.
        noise_std: Standard deviation of the per-row log-space offset.

    Returns:
        A new tensor of the same shape; columns from ``n_fit`` on are unchanged.
    """
    dev = eigenvalues.device
    n_rows = eigenvalues.shape[0]

    log_rank = torch.log(torch.arange(1, n_fit + 1, device=dev).float())   # [n_fit]
    log_lam = torch.log(eigenvalues[:, :n_fit])                             # [B, n_fit]

    # Closed-form least-squares slope from the centred coordinates.
    dx = log_rank - log_rank.mean()                          # [n_fit]
    dy = log_lam - log_lam.mean(dim=1, keepdim=True)         # [B, n_fit]
    x_var = (dx ** 2).sum()
    xy_cov = (dy * dx.unsqueeze(0)).sum(dim=1, keepdim=True)
    new_slope = (xy_cov / x_var) * slope_factor              # [B, 1]

    if noise_std > 0.0:
        offset = torch.randn(n_rows, 1, device=dev) * noise_std
    else:
        offset = 0.0

    # Anchor at the last fitted rank.
    last = n_fit - 1
    anchor_log_lam = log_lam[:, last].unsqueeze(1)           # [B, 1]
    log_new = anchor_log_lam + new_slope * (log_rank.unsqueeze(0) - log_rank[last]) + offset

    reshaped = eigenvalues.clone()
    reshaped[:, :n_fit] = torch.exp(log_new)
    return reshaped
    

import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.stats import linregress



def plot_comparison_loglog(ax, data_orig, data_new, label_name, CH_slice):
    """Scatter a modified metric against its baseline on log-log axes, with a fitted slope.

    Entries are selected with ``CH_slice`` and only pairs where both values
    are strictly positive are kept.  With at least two such pairs the panel
    shows the points, a least-squares line fitted in log10 space, and a dashed
    ``y = x`` reference.  Axis scales, labels, title, legend and grid are
    always set.

    Args:
        ax: Matplotlib axes to draw on.
        data_orig: Baseline values (x-axis), indexable by ``CH_slice``.
        data_new: Current values (y-axis), indexable by ``CH_slice``.
        label_name: Quantity name used in the axis labels and title.
        CH_slice: Index or slice applied to both arrays.
    """
    baseline = data_orig[CH_slice]
    current = data_new[CH_slice]

    both_positive = (baseline > 0) & (current > 0)
    baseline_pos = baseline[both_positive]
    current_pos = current[both_positive]

    if len(baseline_pos) > 1:
        log_base = np.log10(baseline_pos)
        log_cur = np.log10(current_pos)
        regression = linregress(log_base, log_cur)
        slope = regression.slope
        intercept = regression.intercept

        ax.scatter(baseline_pos, current_pos, s=30, alpha=0.6, color='royalblue',
                   edgecolors='white', linewidth=0.5, label='Data Points')

        ax.plot(baseline_pos, 10**(intercept + slope * log_base), color='firebrick',
                linestyle='-', linewidth=3.0, label=f'Fit: slope={slope:.2f}')

        # Identity line spanning the joint range of both variables.
        lower = min(baseline_pos.min(), current_pos.min())
        upper = max(baseline_pos.max(), current_pos.max())
        diagonal = np.linspace(lower, upper, 100)
        ax.plot(diagonal, diagonal, 'k--', linewidth=2.0, alpha=0.8, label='Ref: slope=1')

    ax.set_xscale('log')
    ax.set_yscale('log')

    ax.set_xlabel(f'Original {label_name}', fontsize=14, fontweight='bold')
    ax.set_ylabel(f'Current {label_name}', fontsize=14, fontweight='bold')
    ax.set_title(f'{label_name} Comparison', fontsize=16, fontweight='bold')

    ax.legend(loc='best', fontsize=12, frameon=True, framealpha=0.9)
    ax.grid(True, which="both", ls="-", alpha=0.4, color='gray')

def analyze_slope_and_plot(CH_slice,H, C, gap_matrix, spectrum_data, title, 
                           H_ref, C_ref, 
                           ax_scatter, ax_gap, ax_spectrum, ax_comp_c, ax_comp_h,
                           end_idx=800):
    """Fill one five-panel row of the comparison figure for a single spectrum variant.

    Panels, in order of the axes arguments:
      1. ``ax_scatter``: C against H (entries selected by ``CH_slice``, both
         strictly positive) on log-log axes with a log10 least-squares fit and
         slope-1 / slope-2 reference lines through the mean of the log data.
      2. ``ax_gap``: selected rows of ``gap_matrix`` plus the mean over its
         first 1500 rows, against the cutoff index on a log x-axis.
      3. ``ax_spectrum``: spectra of a few samples of ``spectrum_data`` with a
         vertical marker at rank 10.
      4. ``ax_comp_c``: C against ``C_ref`` via ``plot_comparison_loglog``.
      5. ``ax_comp_h``: H against ``H_ref`` via ``plot_comparison_loglog``.

    Args:
        CH_slice: Index applied to ``H``, ``C``, ``H_ref`` and ``C_ref``.
        H, C: 1-D arrays to analyse.
        gap_matrix: 2-D array; rows are plotted against ``1..gap_matrix.shape[1]``.
        spectrum_data: 2-D tensor or array with one spectrum per row.
        title: Title of the scatter panel.
        H_ref, C_ref: Baseline arrays for panels 4 and 5.
        end_idx: Not referenced in the body.
    """
    
    # --- 1. C_ii vs H_ii Scatter Plot ---
    x_data = H[CH_slice]
    y_data = C[CH_slice]
    # log10 is only defined for strictly positive values
    mask = (x_data > 0) & (y_data > 0)
    x_data, y_data = x_data[mask], y_data[mask]
    
    if len(x_data) > 1:
        log_x, log_y = np.log10(x_data), np.log10(y_data)
        slope, intercept, _, _, _ = linregress(log_x, log_y)
        
        ax_scatter.scatter(x_data, y_data, s=6, alpha=0.5, color='navy', label='Data Points')
        
        fit_line_y = 10**(intercept + slope * log_x)
        ax_scatter.plot(x_data, fit_line_y, color='red', linestyle='-', linewidth=1, label=f'Slope = {slope:.2f}')
        
        # Reference lines
        # Both pass through the centroid (mid_x, mid_y) of the log-transformed data.
        mid_x, mid_y = np.mean(log_x), np.mean(log_y)
        ref_line_x = np.logspace(min(log_x), max(log_x), 100)
        
        ref_line_y1 = 10**(1.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y1, 'k--', linewidth=2.0, label='Slope=1')
        
        ref_line_y2 = 10**(2.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y2, color='darkgreen', linestyle=':', linewidth=2.5, label='Slope=2')
    
    ax_scatter.set_xscale('log')
    ax_scatter.set_yscale('log')
    ax_scatter.set_xlabel(r'$H_{ii}$', fontsize=14, fontweight='bold')
    ax_scatter.set_ylabel(r'$C_{ii}$', fontsize=14, fontweight='bold')
    ax_scatter.set_title(title, fontsize=16, fontweight='bold')
    ax_scatter.legend(loc='upper left', fontsize=12)
    ax_scatter.grid(True, which="both", ls="-", alpha=0.3)
    
    # --- 2. Gap Curve Plot ---
    real_k = gap_matrix.shape[1]
    k_values = np.arange(1, real_k + 1)
    # NOTE(review): 1500 is hard-coded here and looks like it mirrors the caller's CH_slice — confirm.
    avg_gap_curve = np.mean(gap_matrix[:1500], axis=0)
    # print(f"Gap Matrix Shape: {gap_matrix.shape}")
    # print(avg_gap_curve)
    # Zero-based row indices; the legend shows them one-based (idx+1).
    indices_to_plot = [1, 10, 50, 100, 200, 400, 700, 1000, 1500]
    colors_gap = plt.cm.viridis(np.linspace(0, 0.9, len(indices_to_plot))) 
    
    for idx, color in zip(indices_to_plot, colors_gap):
        # skip rows that do not exist in this gap matrix
        if idx < gap_matrix.shape[0]:
            ax_gap.semilogx(k_values, gap_matrix[idx, :], color=color, linewidth=2.0, alpha=0.8, label=f'i={idx+1}')
            
    ax_gap.semilogx(k_values, avg_gap_curve, 'k--', linewidth=3.5, alpha=1.0, label='Avg')
    
    ax_gap.legend(loc='upper right', fontsize=10, ncol=2) 
    ax_gap.set_title(f"Gap Metric", fontsize=16, fontweight='bold')
    ax_gap.grid(True, which="both", ls="-", alpha=0.3)
    ax_gap.set_xlabel('k (Cutoff)', fontsize=14)
    ax_gap.set_ylabel('Cumulative contribution', fontsize=14)
    ax_gap.set_ylim(1e-3, 1.05)
    
    # --- 3. Spectrum Plot (Sampled) ---
    if isinstance(spectrum_data, torch.Tensor):
        spec_np = spectrum_data.detach().cpu().numpy()
    else:
        spec_np = spectrum_data
        
    B_samples = spec_np.shape[0]
    num_classes = 10
    colors = plt.cm.tab10(np.linspace(0, 1, num_classes))
    ranks = np.arange(1, spec_np.shape[1] + 1)
    
    if B_samples >= num_classes:
        samples_per_class = B_samples // num_classes
        # NOTE(review): the +1 picks the second sample of each group; with a single
        # sample per group this lands in the next group — confirm this is intended.
        selected_indices = [
            i * samples_per_class+1
            for i in range(num_classes)
        ]
    else:
        selected_indices = range(B_samples)
    
    for i, idx in enumerate(selected_indices):
        if idx < B_samples:
            lbl = f"class {i}" 
            ax_spectrum.loglog(ranks, spec_np[idx], color=colors[i % num_classes], 
                               alpha=0.8, linewidth=2.5, label=lbl)
            
    ax_spectrum.axvline(x=10, color='dimgray', linestyle=':', linewidth=2.0, alpha=0.8, label='k=10')
    
    ax_spectrum.set_title('Eigenvalue Spectrum (Sampled)', fontsize=16, fontweight='bold')
    ax_spectrum.set_xlabel('Rank', fontsize=14)
    ax_spectrum.set_ylabel(r'$|\lambda_i|$', fontsize=14)
    ax_spectrum.grid(True, which="both", alpha=0.3)
    ax_spectrum.legend(loc='best', fontsize=10, ncol=2)

    # --- 4. NEW: C_ii Comparison ---
    plot_comparison_loglog(ax_comp_c, C_ref, C, r'$C_{ii}$', CH_slice)

    # --- 5. NEW: H_ii Comparison ---
    plot_comparison_loglog(ax_comp_h, H_ref, H, r'$H_{ii}$', CH_slice)
###########################################################################################
def main(CH_slice, h_holder_input=None, a1=1.0, a2=1.6, a3=3.0, k_target=800):
    """Compare H/C statistics of the real spectrum with two truncated variants.

    The batch is eigendecomposed once.  For the original eigenvalues, and for
    copies whose eigenvalues from index 3 and from index 1 onward are scaled
    down by ``create_gap_from_real_spectrum``, the function computes the H/C
    row statistics and the gap matrix, then draws one five-panel row each.

    Args:
        CH_slice: Index applied to the H/C arrays when plotting.
        h_holder_input: Tensor of shape (B, N, N).
        a1, a2, a3: Unused; kept so existing calls keep working.
        k_target: Number of leading eigenmodes to keep (clipped to N).

    Raises:
        ValueError: If ``h_holder_input`` is None.
    """
    if h_holder_input is None:
        raise ValueError("h_holder_input must be a tensor of shape (B, N, N)")

    B, N, _ = h_holder_input.shape
    print(f"Processing input data: B={B}, N={N}")

    # 1. Real Eigen Decomposition
    torch.backends.cuda.preferred_linalg_library('magma')
    real_eigenvalues_orig, eigenvectors = get_top_k_decomposition(h_holder_input, k=k_target)
    torch.cuda.empty_cache()

    def summarize(spectrum, title):
        """Compute the H/C diagonals and gap matrix for one eigenvalue variant."""
        H_diag, C_diag = compute_H_C_only(spectrum, eigenvectors)
        gap = compute_full_gap_matrix_efficient(spectrum, eigenvectors)
        return {'H': H_diag, 'C': C_diag, 'Gap': gap, 'Spec': spectrum, 'Title': title}

    # ==========================================
    # A1. Real Data (Original) - Baseline
    # ==========================================
    print("\n--- Processing Real Data (Original) ---")
    results = [summarize(real_eigenvalues_orig, "Real Data (Original)")]
    H_r_orig, C_r_orig = results[0]['H'], results[0]['C']

    # ==========================================
    # A2. Real Data with the spectrum cut after the leading modes
    # ==========================================
    # Each cutoff gets its own title and log line; both rows used to be
    # labelled "Real Data (Modified)" and could not be told apart.
    for cutoff in (3, 1):
        print(f"\n--- Processing Real Data (Modified, cut @ k={cutoff}) ---")
        modified = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=cutoff)
        results.append(summarize(modified, f"Real Data (Modified, cut @ k={cutoff})"))

    total_rows = len(results)
    print(f"\nGenerating Plots ({total_rows} rows x 5 columns)...")

    # Col 1: C vs H
    # Col 2: Gap
    # Col 3: Spectrum
    # Col 4: C vs C_orig (LogLog + Fit)
    # Col 5: H vs H_orig (LogLog + Fit)
    fig, axes = plt.subplots(total_rows, 5, figsize=(30, 4 * total_rows))
    if total_rows == 1:
        axes = np.expand_dims(axes, 0)

    for i, res in enumerate(results):
        analyze_slope_and_plot(CH_slice,
            res['H'], res['C'], res['Gap'], res['Spec'], res['Title'],
            H_r_orig, C_r_orig,
            axes[i, 0], axes[i, 1], axes[i, 2], axes[i, 3], axes[i, 4],
            end_idx=k_target)

    plt.tight_layout()
    plt.show()
    print("Done.")

if __name__ == "__main__":
    if 'h_holder' in locals():
        # The H/C arrays have one entry per matrix row (h_holder.shape[1]), so
        # clamp the plotted range to avoid an IndexError on smaller matrices.
        CH_slice = torch.arange(0, min(1500, h_holder.shape[1]))
        main(CH_slice, h_holder, a1=1.0, a2=1.6, a3=3.0, k_target=2560)
    else:
        print("Please define 'h_holder' (Tensor shape B, N, N) before running.")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import matplotlib.ticker as ticker

# Set device
# Use CUDA when available, otherwise the CPU.  Helpers defined later in this cell
# (e.g. get_top_k_decomposition) read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def set_icml_style():
    """Apply this cell's global Matplotlib defaults: serif/STIX fonts, dashed grid, TrueType PDF fonts."""
    font_settings = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 14,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 10,
    }
    line_settings = {
        'lines.linewidth': 1.5,
        'lines.markersize': 4,
    }
    grid_settings = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'grid.color': '#b0b0b0',
    }
    output_settings = {
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'figure.autolayout': False,
        # fonttype 42 embeds TrueType fonts in PDF/PS output
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    }
    # Applied group by group, in the same key order as one combined update.
    for group in (font_settings, line_settings, grid_settings, output_settings):
        plt.rcParams.update(group)

def format_ax(ax):
    """Hide the top and right spines of ``ax`` and use short outward-pointing ticks."""
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    tick_style = {'direction': 'out', 'length': 4, 'width': 1}
    ax.tick_params(**tick_style)



def get_top_k_decomposition(h_holder, k=800):
    """Return the ``k`` largest eigenpairs of every symmetric matrix in a batch.

    Args:
        h_holder: Tensor of shape (B, N, N); moved to the module-level ``device``.
        k: Number of leading modes to keep; clipped to N.

    Returns:
        ``(vals_k, vecs_k)`` with shapes (B, k') and (B, N, k'), where
        ``k' = min(k, N)``.  Eigenvalues are in descending order and column
        ``j`` of ``vecs_k`` is the eigenvector of ``vals_k[:, j]``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    print(f"1. Performing Eigen Decomposition (Full Rank)...")
    h_holder = h_holder.to(device)
    # eigh returns eigenvalues in ascending order, with eigenvectors as columns.
    vals, vecs = torch.linalg.eigh(h_holder)
    N = vals.shape[1]
    actual_k = min(k, N)
    print(f"2. Filtering Top {actual_k} modes...")
    # Slice the top-k (last) columns first and flip only those.  torch.flip
    # copies its input, so the results no longer alias the full (B, N, N)
    # tensor and `del` below really releases it.
    start = N - actual_k
    vals_k = torch.flip(vals[:, start:], dims=[1])
    vecs_k = torch.flip(vecs[:, :, start:], dims=[2])
    del vals, vecs
    torch.cuda.empty_cache()
    return vals_k, vecs_k

def compute_H_C_only(eigenvalues, eigenvectors, chunk_size=16):
    """Sample-averaged row sums ``sum_k v_ik^2 * lam_k`` and ``sum_k v_ik^2 * lam_k^2``.

    Args:
        eigenvalues: Tensor of shape (S, K).
        eigenvectors: Tensor of shape (S, N, K), eigenvectors as columns.  The
            row axis is reversed before use, so entry ``i`` of each result
            comes from original row ``N - 1 - i``.
        chunk_size: Number of samples processed per step (bounds peak memory).

    Returns:
        ``(H_mean, C_mean)``: two float32 NumPy arrays of length N, averaged
        over the S samples.
    """
    sample_count = eigenvalues.shape[0]
    row_count = eigenvectors.shape[1]
    sums = {
        'H': np.zeros(row_count, dtype=np.float32),
        'C': np.zeros(row_count, dtype=np.float32),
    }

    with torch.no_grad():
        for lo in range(0, sample_count, chunk_size):
            hi = min(lo + chunk_size, sample_count)
            lam = eigenvalues[lo:hi]
            sq_components = torch.flip(eigenvectors[lo:hi], dims=[1]) ** 2

            # First power of the eigenvalues feeds H, second power feeds C.
            per_sample_H = torch.matmul(sq_components, lam.unsqueeze(-1)).squeeze(-1)
            per_sample_C = torch.matmul(sq_components, (lam ** 2).unsqueeze(-1)).squeeze(-1)

            sums['H'] += torch.sum(per_sample_H, dim=0).cpu().numpy()
            sums['C'] += torch.sum(per_sample_C, dim=0).cpu().numpy()

    return sums['H'] / sample_count, sums['C'] / sample_count

def compute_full_gap_matrix_efficient(eigenvalues, eigenvectors, batch_size=32):
    """Sample-averaged cumulative share of ``lam_k^2 * v_ik^2`` held by the first modes.

    For every sample and row ``i`` the function computes
    ``cumsum_k(lam_k^2 * v_ik^2) / sum_k(lam_k^2 * v_ik^2)`` and then averages
    over samples.  The row axis of ``eigenvectors`` is reversed before use.

    Args:
        eigenvalues: Tensor of shape (B, K).
        eigenvectors: Tensor of shape (B, N, K), on the same device.
        batch_size: Number of samples processed per step.

    Returns:
        NumPy array of shape (N, K).  A row whose total is zero produces NaN
        (0/0), as before.
    """
    print(f"Computing Gap Matrix (Batch Size={batch_size})...")
    B, N, num_modes = eigenvectors.shape
    # Accumulate on the inputs' own device rather than a module-level global,
    # so CPU and GPU inputs both work without a device mismatch.
    work_device = eigenvectors.device
    total_gap_sum = torch.zeros((N, num_modes), device=work_device)
    with torch.no_grad():
        for b_start in range(0, B, batch_size):
            b_end = min(b_start + batch_size, B)
            # Reverse the row axis per chunk; flipping the whole tensor up
            # front would allocate a second full-size copy.
            vec_chunk = torch.flip(eigenvectors[b_start:b_end], dims=[1])
            val_chunk = eigenvalues[b_start:b_end]
            lam_sq = (val_chunk ** 2).unsqueeze(1)           # (chunk, 1, K)
            V_sq = vec_chunk ** 2                            # (chunk, N, K)
            numerator = torch.cumsum(lam_sq * V_sq, dim=2)
            denominator = numerator[:, :, -1].unsqueeze(2)   # total over all modes
            gap_chunk = numerator / denominator
            total_gap_sum += torch.sum(gap_chunk, dim=0)
            del vec_chunk, val_chunk, lam_sq, V_sq, numerator, denominator, gap_chunk
            if work_device.type == 'cuda':
                torch.cuda.empty_cache()

    avg_gap_matrix = total_gap_sum / B
    return avg_gap_matrix.cpu().numpy()

def create_gap_from_real_spectrum(eigenvalues, idx = 10):
    """Return a copy of ``eigenvalues`` with columns ``idx`` onward scaled by 1e-5.

    The first ``idx`` columns are copied unchanged and the input is not modified.
    """
    suppressed = eigenvalues.clone()
    trailing = (slice(None), slice(idx, None))
    suppressed[trailing] = eigenvalues[trailing] * 1e-5
    return suppressed



def plot_comparison_loglog(ax, data_orig, data_new, label_name, CH_slice):
    """Scatter a modified metric against its baseline on log-log axes, with a fitted slope.

    The axes are first styled with ``format_ax``.  Entries are selected with
    ``CH_slice`` and only pairs where both values are strictly positive are
    kept.  With at least two such pairs the panel shows the points, a
    least-squares line fitted in log10 space, and a dashed ``y = x`` reference.
    Axis scales, labels, title and legend are always set.

    Args:
        ax: Matplotlib axes to draw on.
        data_orig: Baseline values (x-axis), indexable by ``CH_slice``.
        data_new: Current values (y-axis), indexable by ``CH_slice``.
        label_name: Quantity name used in the axis labels and title.
        CH_slice: Index or slice applied to both arrays.
    """
    format_ax(ax)

    baseline = data_orig[CH_slice]
    current = data_new[CH_slice]
    both_positive = (baseline > 0) & (current > 0)
    baseline_pos = baseline[both_positive]
    current_pos = current[both_positive]

    if len(baseline_pos) > 1:
        log_base = np.log10(baseline_pos)
        log_cur = np.log10(current_pos)
        regression = linregress(log_base, log_cur)
        slope = regression.slope
        intercept = regression.intercept

        ax.scatter(baseline_pos, current_pos, s=2, facecolors='none', edgecolors='#1f77b4',
                   alpha=0.6, linewidth=0.8, label='Data Points')
        ax.plot(baseline_pos, 10**(intercept + slope * log_base), color='#d62728',
                linestyle='-', linewidth=2.0, label=f'Fit: slope={slope:.2f}')

        # Identity line spanning the joint range of both variables.
        lower = min(baseline_pos.min(), current_pos.min())
        upper = max(baseline_pos.max(), current_pos.max())
        diagonal = np.linspace(lower, upper, 100)
        ax.plot(diagonal, diagonal, 'k--', linewidth=1.0, alpha=0.6, label='y=x')

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(f'Original {label_name}')
    ax.set_ylabel(f'Current {label_name}')
    ax.set_title(f'{label_name} Comparison')

    ax.legend(loc='lower right', frameon=True, framealpha=0.9, edgecolor='gray', fontsize=9)


def analyze_slope_and_plot(CH_slice, H, C, gap_matrix, spectrum_data, row_title, 
                           H_ref, C_ref, 
                           ax_scatter, ax_gap, ax_spectrum, ax_comp_c, ax_comp_h,
                           end_idx=800):
    """Fill one five-panel row of the comparison figure for a single spectrum variant.

    Panels, in order of the axes arguments:
      1. ``ax_scatter``: C against H (entries selected by ``CH_slice``, both
         strictly positive) on log-log axes with a log10 least-squares fit and
         slope-1 / slope-2 reference lines through the mean of the log data.
      2. ``ax_gap``: selected rows of ``gap_matrix`` plus the mean over its
         first 1500 rows, against the cutoff index on a log x-axis.
      3. ``ax_spectrum``: ten evenly spaced rows of ``spectrum_data`` with a
         vertical marker at rank 10.
      4. ``ax_comp_c``: C against ``C_ref`` via ``plot_comparison_loglog``.
      5. ``ax_comp_h``: H against ``H_ref`` via ``plot_comparison_loglog``.

    Args:
        CH_slice: Index applied to ``H``, ``C``, ``H_ref`` and ``C_ref``.
        H, C: 1-D arrays to analyse.
        gap_matrix: 2-D array; rows are plotted against ``1..gap_matrix.shape[1]``.
        spectrum_data: 2-D tensor or array with one spectrum per row.
        row_title: Text inserted into the scatter panel's title.
        H_ref, C_ref: Baseline arrays for panels 4 and 5.
        end_idx: Not referenced in the body.
    """
    
    # --- 1. C_ii vs H_ii Scatter Plot ---
    format_ax(ax_scatter)
    x_data = H[CH_slice]
    y_data = C[CH_slice]
    # log10 is only defined for strictly positive values
    mask = (x_data > 0) & (y_data > 0)
    x_data, y_data = x_data[mask], y_data[mask]
    
    if len(x_data) > 1:
        log_x, log_y = np.log10(x_data), np.log10(y_data)
        slope, intercept, _, _, _ = linregress(log_x, log_y)
        
        ax_scatter.scatter(x_data, y_data, s=2, alpha=0.5, color='#000080', label='Data', edgecolors='none')
        
        fit_line_y = 10**(intercept + slope * log_x)
        ax_scatter.plot(x_data, fit_line_y, color='#d62728', linestyle='-', linewidth=1.5, label=f'Fit: {slope:.2f}')
        
        # Reference lines of slope 1 and 2 through the centroid of the log-transformed data.
        mid_x, mid_y = np.mean(log_x), np.mean(log_y)
        ref_line_x = np.logspace(min(log_x), max(log_x), 100)
        
        ref_line_y1 = 10**(1.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y1, 'k--', linewidth=1.0, label='Slope=1')
        
        ref_line_y2 = 10**(2.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y2, color='green', linestyle=':', linewidth=1.5, label='Slope=2')
    
    ax_scatter.set_xscale('log')
    ax_scatter.set_yscale('log')
    ax_scatter.set_xlabel(r'$H_{ii}$')
    ax_scatter.set_ylabel(r'$C_{ii}$')
    ax_scatter.set_title(f"Scalability ({row_title})") 
    ax_scatter.legend(loc='upper left', fontsize=8, frameon=False)
    
    # --- 2. Gap Curve Plot ---
    format_ax(ax_gap)
    real_k = gap_matrix.shape[1]
    k_values = np.arange(1, real_k + 1)
    # NOTE(review): 1500 is hard-coded here and looks like it mirrors the caller's CH_slice — confirm.
    avg_gap_curve = np.mean(gap_matrix[:1500], axis=0)
    
    # Zero-based row indices; the legend shows them one-based (idx+1).
    indices_to_plot = [1, 10, 50, 100, 400, 1000]
    colors_gap = plt.cm.viridis(np.linspace(0, 0.9, len(indices_to_plot))) 
    
    for idx, color in zip(indices_to_plot, colors_gap):
        # skip rows that do not exist in this gap matrix
        if idx < gap_matrix.shape[0]:
            ax_gap.semilogx(k_values, gap_matrix[idx, :], color=color, linewidth=1.2, alpha=0.7, label=f'i={idx+1}')
            
    ax_gap.semilogx(k_values, avg_gap_curve, 'k--', linewidth=2.0, alpha=1.0, label='Avg')
    
    ax_gap.legend(loc='lower right', fontsize=8, ncol=2, frameon=False) 
    ax_gap.set_title("Gap Metric")
    ax_gap.set_xlabel('k (Cutoff)')
    ax_gap.set_ylabel('Cum. Distribution')
    ax_gap.set_ylim(0, 1.1)
    
    # --- 3. Spectrum Plot (Sampled) ---
    format_ax(ax_spectrum)
    if isinstance(spectrum_data, torch.Tensor):
        spec_np = spectrum_data.detach().cpu().numpy()
    else:
        spec_np = spectrum_data
        
    B_samples = spec_np.shape[0]
    num_classes = 10 
    colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
    ranks = np.arange(1, spec_np.shape[1] + 1)
    
    # Ten row indices spread evenly from the first to the last sample
    # (indices repeat when there are fewer than ten samples).
    selected_indices = np.linspace(0, B_samples-1, num_classes, dtype=int)
    
    for i, idx in enumerate(selected_indices):
        ax_spectrum.loglog(ranks, spec_np[idx], color=colors[i], alpha=0.7, linewidth=1.5)
            
    ax_spectrum.axvline(x=10, color='gray', linestyle=':', linewidth=1.0)
    ax_spectrum.set_title('Spectrum (Sampled)')
    ax_spectrum.set_xlabel('Rank')
    ax_spectrum.set_ylabel(r'$|\lambda_i|$')

    # --- 4. NEW: C_ii Comparison ---
    plot_comparison_loglog(ax_comp_c, C_ref, C, r'$C_{ii}$', CH_slice)

    # --- 5. NEW: H_ii Comparison ---
    plot_comparison_loglog(ax_comp_h, H_ref, H, r'$H_{ii}$', CH_slice)



def main(CH_slice, h_holder_input=None, fig_title="Neural Collapse Spectral Analysis"):
    """Build, show and save the three-row spectrum-cutoff comparison figure.

    Steps:
      1. Apply ``set_icml_style`` and eigendecompose the batch, keeping at
         most 2000 leading modes.
      2. Compute the H/C statistics and gap matrix for three eigenvalue sets:
         the original; a copy cut at index 1 by
         ``create_gap_from_real_spectrum``; and such a copy whose first
         eigenvalue is additionally replaced by its mean over the batch.
      3. Draw one five-panel row per set with ``analyze_slope_and_plot``.
      4. Save the figure as ``Spectrum_cutoff.pdf`` under
         ``ICML_Figures/{model}_{dataset}_{lss_fn}`` and show it.

    Args:
        CH_slice: Index applied to the H/C arrays when plotting.
        h_holder_input: Tensor of shape (B, N, N); the default of None is not
            usable because ``.shape`` is read immediately.
        fig_title: Figure-level title.

    NOTE(review): reads the module-level ``config`` and uses ``os``, which is
    not imported in this cell — both come from an earlier cell.
    """

    set_icml_style()
    
    h_holder = h_holder_input
    B, N, _ = h_holder.shape
    print(f"Processing input data: B={B}, N={N}")

    torch.backends.cuda.preferred_linalg_library('magma')
    # keep at most 2000 leading modes
    k_target_actual = min(N, 2000) 
    real_eigenvalues_orig, eigenvectors = get_top_k_decomposition(h_holder, k=k_target_actual)
    
    # only removes the local name; the caller still holds its own reference
    del h_holder 
    torch.cuda.empty_cache()

    results = []
    
    # --- Baseline ---
    print("\n--- Processing Real Data (Original) ---")
    H_r_orig, C_r_orig = compute_H_C_only(real_eigenvalues_orig, eigenvectors)
    Gap_r_orig = compute_full_gap_matrix_efficient(real_eigenvalues_orig, eigenvectors)
    
    results.append({
        'H': H_r_orig, 'C': C_r_orig, 'Gap': Gap_r_orig, 
        'Spec': real_eigenvalues_orig, 'Title': "Original"
    })

    # --- Variant 1: every eigenvalue from index 1 onward is scaled down ---
    print("\n--- Processing Real Data (Cut 1) ---")
    real_eigenvalues_smooth2 = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=1)
    H_r_smooth2, C_r_smooth2 = compute_H_C_only(real_eigenvalues_smooth2, eigenvectors)
    Gap_r_smooth2 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth2, eigenvectors)
    
    results.append({
        'H': H_r_smooth2, 'C': C_r_smooth2, 'Gap': Gap_r_smooth2, 
        'Spec': real_eigenvalues_smooth2, 'Title': f"Cut @ k=1"
    })


    # --- Variant 2: same cut, and the first eigenvalue is set to its batch mean ---
    print("\n--- Processing Real Data (Reshaped) ---")

    real_eigenvalues_smooth3 = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=1)
    real_eigenvalues_smooth3[:,0] = torch.mean(real_eigenvalues_smooth3[:,0], dim=0)

    
    H_r_smooth3, C_r_smooth3 = compute_H_C_only(real_eigenvalues_smooth3, eigenvectors)
    Gap_r_smooth3 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth3, eigenvectors)
    
    results.append({
        'H': H_r_smooth3, 'C': C_r_smooth3, 'Gap': Gap_r_smooth3, 
        'Spec': real_eigenvalues_smooth3, 'Title': f"Cut @ k=1 and Collapse to the mean"
    })



    total_rows = len(results)
    print(f"\nGenerating Plots...")    
    fig, axes = plt.subplots(total_rows, 5, figsize=(20, 3.5 * total_rows))
    # plt.subplots returns a 1-D axes array for a single row; restore the row axis
    if total_rows == 1: 
        axes = np.expand_dims(axes, 0)


    # every row is compared against the baseline H/C in its last two panels
    for i, res in enumerate(results):
        analyze_slope_and_plot(CH_slice,
            res['H'], res['C'], res['Gap'], res['Spec'], res['Title'],
            H_r_orig, C_r_orig, 
            axes[i, 0], axes[i, 1], axes[i, 2], axes[i, 3], axes[i, 4],
            end_idx=k_target_actual)

    fig.suptitle(fig_title, fontsize=20, fontweight='bold', fontfamily='serif', y=1.02)
    

    plt.tight_layout()
    # output directory is derived from the module-level experiment config
    save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    os.makedirs(save_dir, exist_ok=True)
    filename = f"Spectrum_cutoff.pdf"
    save_path = os.path.join(save_dir, filename)
    
    plt.savefig(
        save_path, 
        format='pdf',          
        bbox_inches='tight',   
        pad_inches=0.2,       
        dpi=300               
    ) 
    plt.show()
    print("Done.")

if __name__ == "__main__":

    if 'h_holder' in locals():
        MY_TITLE = "Effect of Spectrum Cutoff on C-H Scaling Law"
        # The H/C arrays have one entry per matrix row (h_holder.shape[1]), so
        # clamp the plotted range to avoid an IndexError on smaller matrices.
        CH_slice = torch.arange(0, min(1500, h_holder.shape[1]))
        main(CH_slice, h_holder, fig_title=MY_TITLE)
    else:
        print("Wait: 'h_holder' is not defined. Please define input tensor first.")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import matplotlib.ticker as ticker
import os

# Choose the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_icml_appendix_style():
    """Install the large-font serif Matplotlib defaults used by this cell's figures.

    The settings are written into the global ``plt.rcParams``, so every figure
    created afterwards in the same kernel picks them up.
    """
    # Typography: serif text, STIX math, and enlarged label/title/tick fonts.
    fonts = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 16,
        'ytick.labelsize': 16,
        'legend.fontsize': 16,
    }
    # Stroke weights for plotted lines, markers and the axes frame.
    strokes = {
        'lines.linewidth': 3.0,
        'lines.markersize': 8,
        'axes.linewidth': 2.0,
    }
    # Major tick geometry on both axes.
    ticks = {
        'xtick.major.size': 8,
        'xtick.major.width': 1.5,
        'ytick.major.size': 8,
        'ytick.major.width': 1.5,
    }
    # Dashed, semi-transparent background grid.
    grid = {
        'axes.grid': True,
        'grid.alpha': 0.4,
        'grid.linestyle': '--',
        'grid.linewidth': 1.0,
    }
    # Export settings: 300 dpi, font type 42 for PDF/PS output.
    export = {
        'savefig.dpi': 300,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    }
    plt.rcParams.update({**fonts, **strokes, **ticks, **grid, **export})

def format_boxed_ax(ax):
    """Give ``ax`` a closed black frame and inward-pointing ticks.

    Every spine of the axes is made visible, black and 2.0 wide; tick marks
    are drawn inward in black, and the grid alpha is set to 0.4.
    """
    tick_style = dict(direction='in', width=1.5, length=6, colors='black', grid_alpha=0.4)
    for side in ax.spines.values():
        side.set_color('black')
        side.set_linewidth(2.0)
        side.set_visible(True)
    ax.tick_params(**tick_style)


def get_top_k_decomposition(h_holder, k=800):
    """Eigendecompose a batch of symmetric matrices and keep the ``k`` largest modes.

    Args:
        h_holder: tensor indexed as (batch, N, N); it is moved to the
            module-level ``device`` before decomposition.
        k: number of leading modes to keep (capped at N).

    Returns:
        ``(vals_k, vecs_k)``: eigenvalues of shape (batch, k) in descending
        order and the matching eigenvector columns of shape (batch, N, k).
    """
    print(f"1. Performing Eigen Decomposition (Full Rank)...")
    eigvals_asc, eigvecs_asc = torch.linalg.eigh(h_holder.to(device))
    # eigh yields ascending eigenvalues; reverse the mode axis of both outputs
    # so that index 0 is the largest eigenvalue.
    eigvals_desc = eigvals_asc.flip(1)
    eigvecs_desc = eigvecs_asc.flip(2)
    del eigvals_asc, eigvecs_asc

    actual_k = min(k, eigvals_desc.shape[1])
    print(f"2. Filtering Top {actual_k} modes...")
    vals_k = eigvals_desc[:, :actual_k]
    vecs_k = eigvecs_desc[:, :, :actual_k]
    del eigvals_desc, eigvecs_desc
    torch.cuda.empty_cache()
    return vals_k, vecs_k

def compute_H_C_only(eigenvalues, eigenvectors, chunk_size=16):
    """Batch-average the spectral sums ``sum_k V[i,k]^2 * lam_k`` and ``sum_k V[i,k]^2 * lam_k^2``.

    Args:
        eigenvalues: tensor indexed as (batch, K).
        eigenvectors: tensor indexed as (batch, N, K); its row axis is
            reversed before use, so output entry 0 corresponds to input row N-1.
        chunk_size: number of batch items processed per step.

    Returns:
        ``(H, C)``: two float32 NumPy vectors of length N, the batch means of
        the lambda-weighted and lambda-squared-weighted sums respectively.
    """
    n_batch = eigenvalues.shape[0]
    n_rows = eigenvectors.shape[1]
    h_total = np.zeros(n_rows, dtype=np.float32)
    c_total = np.zeros(n_rows, dtype=np.float32)

    with torch.no_grad():
        for start in range(0, n_batch, chunk_size):
            stop = min(start + chunk_size, n_batch)
            lam = eigenvalues[start:stop]
            # Squared components of the row-reversed eigenvectors act as weights.
            weights = torch.flip(eigenvectors[start:stop], dims=[1]) ** 2

            h_rows = (weights @ lam.unsqueeze(-1)).squeeze(-1)
            c_rows = (weights @ (lam ** 2).unsqueeze(-1)).squeeze(-1)

            # Sum over the chunk on-device, then accumulate on the host.
            h_total += h_rows.sum(dim=0).cpu().numpy()
            c_total += c_rows.sum(dim=0).cpu().numpy()

            del lam, weights, h_rows, c_rows

    return h_total / n_batch, c_total / n_batch

def compute_full_gap_matrix_efficient(eigenvalues, eigenvectors, batch_size=32):
    """Batch-average the cumulative share of ``lam_k^2 * V[i,k]^2`` over the mode axis.

    For every batch item and row ``i`` the function forms the running sum over
    modes ``k`` of ``lam_k^2 * V[i,k]^2`` and divides it by the full sum over
    all modes, so each row is a curve ending at 1. The curves are then
    averaged over the batch.

    Args:
        eigenvalues: tensor indexed as (B, K).
        eigenvectors: tensor indexed as (B, N, K); its row axis is reversed
            before use, so output row 0 corresponds to input row N-1.
        batch_size: number of batch items processed per step.

    Returns:
        NumPy array of shape (N, K) with the batch-averaged curves. A row whose
        total is exactly zero yields NaN (0/0) for that batch item.
    """
    print(f"Computing Gap Matrix (Batch Size={batch_size})...")
    B, N, num_modes = eigenvectors.shape
    eigenvectors = torch.flip(eigenvectors, dims=[1])
    # Allocate the accumulator next to the data instead of on the module-level
    # `device`, so the function also works when the inputs live on another
    # device (or when no global `device` exists).
    total_gap_sum = torch.zeros((N, num_modes), device=eigenvectors.device)
    with torch.no_grad():
        for b_start in range(0, B, batch_size):
            b_end = min(b_start + batch_size, B)
            vec_chunk = eigenvectors[b_start:b_end]
            val_chunk = eigenvalues[b_start:b_end]
            # (b, 1, K) so the squared eigenvalues broadcast over the row axis.
            lam_sq = (val_chunk ** 2).unsqueeze(1)
            V_sq = vec_chunk ** 2
            numerator = torch.cumsum(lam_sq * V_sq, dim=2)
            # The last cumulative entry is the total over all modes.
            denominator = numerator[:, :, -1].unsqueeze(2)
            gap_chunk = numerator / denominator
            total_gap_sum += torch.sum(gap_chunk, dim=0)
            del vec_chunk, val_chunk, lam_sq, V_sq, numerator, denominator, gap_chunk
            torch.cuda.empty_cache()

    avg_gap_matrix = total_gap_sum / B
    return avg_gap_matrix.cpu().numpy()

def create_gap_from_real_spectrum(eigenvalues, idx = 10):
    """Return a copy of ``eigenvalues`` whose columns ``idx:`` are multiplied by 1e-5.

    The input tensor is left untouched; columns before ``idx`` are copied as is.
    """
    damped = eigenvalues.clone()
    tail = damped[:, idx:]
    damped[:, idx:] = tail * 1e-5
    return damped



def plot_comparison_loglog(ax, data_orig, data_new, label_name, CH_slice):
    """Scatter ``data_new`` against ``data_orig`` on log-log axes with a fitted line.

    Args:
        ax: Matplotlib axes to draw on.
        data_orig: reference values (x axis), indexed with ``CH_slice``.
        data_new: current values (y axis), indexed with ``CH_slice``.
        label_name: quantity name used in the axis labels and the title.
        CH_slice: index selecting the entries to compare.

    Only pairs that are strictly positive in both series are used. When at
    least two such pairs exist, the scatter, a least-squares line fitted in
    log10 space, and a dashed ``y=x`` reference are drawn.
    """
    format_boxed_ax(ax)

    ref_vals = data_orig[CH_slice]
    cur_vals = data_new[CH_slice]

    # log10 is only defined for positive numbers, so drop everything else.
    positive = (ref_vals > 0) & (cur_vals > 0)
    ref_pos = ref_vals[positive]
    cur_pos = cur_vals[positive]

    if len(ref_pos) > 1:
        log_ref = np.log10(ref_pos)
        log_cur = np.log10(cur_pos)
        slope, intercept, *_ = linregress(log_ref, log_cur)

        ax.scatter(ref_pos, cur_pos, s=3, facecolors='none', edgecolors='#1f77b4',
                   alpha=0.7, linewidth=0.5, label='Data Points')

        # Map the fitted log-space line back to data units.
        fitted = 10**(intercept + slope * log_ref)
        ax.plot(ref_pos, fitted, color='#d62728', linestyle='-', linewidth=1.0, label=f'Fit: slope={slope:.2f}')

        # Identity line spanning the joint range of both series.
        lo = min(ref_pos.min(), cur_pos.min())
        hi = max(ref_pos.max(), cur_pos.max())
        identity = np.linspace(lo, hi, 100)
        ax.plot(identity, identity, 'k--', linewidth=2.0, alpha=0.6, label='y=x')

    ax.set_xscale('log')
    ax.set_yscale('log')

    ax.set_xlabel(f'Original {label_name}')
    ax.set_ylabel(f'Current {label_name}')
    ax.set_title(f'{label_name} Comparison')

    # Opaque-ish framed legend in the lower-right corner.
    ax.legend(loc='lower right', frameon=True, framealpha=0.9, edgecolor='black', fancybox=False)


def analyze_slope_and_plot(CH_slice, H, C, gap_matrix, spectrum_data, row_title, 
                           H_ref, C_ref, 
                           ax_scatter, ax_gap, ax_spectrum, ax_comp_c, ax_comp_h,
                           end_idx=800):
    """Draw one five-panel figure row for a single spectrum variant.

    Panels, in order of the axes arguments:
      1. ``ax_scatter``: C vs H (entries ``CH_slice``) on log-log axes with a
         fitted line and slope-1 / slope-2 references through the log-mean point.
      2. ``ax_gap``: selected rows of ``gap_matrix`` plus the mean of its
         first 1500 rows, against the 1-based mode index.
      3. ``ax_spectrum``: rows of ``spectrum_data`` for 8 evenly spaced batch
         indices on log-log axes.
      4. ``ax_comp_c``: ``C`` against ``C_ref`` via ``plot_comparison_loglog``.
      5. ``ax_comp_h``: ``H`` against ``H_ref`` via ``plot_comparison_loglog``.

    Args:
        CH_slice: index selecting the H/C entries used in panels 1, 4 and 5.
        H, C: 1-D arrays for this variant.
        gap_matrix: 2-D array; rows are curves over the mode index.
        spectrum_data: torch tensor or array indexed as (batch, modes).
        row_title: text inserted into the title of panel 1.
        H_ref, C_ref: reference arrays for the comparison panels.
        end_idx: accepted for call compatibility; not referenced in this body.
    """
    
    # --- 1. C_ii vs H_ii Scatter Plot ---
    format_boxed_ax(ax_scatter)
    x_data = H[CH_slice]
    y_data = C[CH_slice]
    # Keep only pairs that are positive in both arrays so log10 is defined.
    mask = (x_data > 0) & (y_data > 0)
    x_data, y_data = x_data[mask], y_data[mask]
    
    if len(x_data) > 1:
        log_x, log_y = np.log10(x_data), np.log10(y_data)
        slope, intercept, _, _, _ = linregress(log_x, log_y)

        ax_scatter.scatter(x_data, y_data, s=3, alpha=0.6, color='#000080', label='Data', edgecolors='none')

        # Least-squares line fitted in log10 space, mapped back to data units.
        fit_line_y = 10**(intercept + slope * log_x)
        ax_scatter.plot(x_data, fit_line_y, color='#d62728', linestyle='-', linewidth=1.0, label=f'Fit: {slope:.2f}')

        # Reference lines of slope 1 and 2 pass through the mean of the log data.
        mid_x, mid_y = np.mean(log_x), np.mean(log_y)
        ref_line_x = np.logspace(min(log_x), max(log_x), 100)
        
        ref_line_y1 = 10**(1.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y1, 'k--', linewidth=2.0, label='Slope=1')
        
        ref_line_y2 = 10**(2.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y2, color='green', linestyle=':', linewidth=3.0, label='Slope=2')
    
    ax_scatter.set_xscale('log')
    ax_scatter.set_yscale('log')
    ax_scatter.set_xlabel(r'$H_{ii}$')
    ax_scatter.set_ylabel(r'$C_{ii}$')
    ax_scatter.set_title(f"Scalability ({row_title})", fontweight='bold') 
    ax_scatter.legend(loc='upper left', frameon=True, edgecolor='black', fancybox=False)
    
    # --- 2. Gap Curve Plot ---
    format_boxed_ax(ax_gap)
    real_k = gap_matrix.shape[1]
    k_values = np.arange(1, real_k + 1)
    # NOTE(review): the 1500-row window is hard-coded here and is independent
    # of CH_slice — confirm the two are meant to stay in sync.
    avg_gap_curve = np.mean(gap_matrix[:1500], axis=0)
    
    # 0-based row indices to draw; the legend shows them 1-based (idx + 1).
    indices_to_plot = [1, 10, 50, 100, 400, 1000]
    colors_gap = plt.cm.viridis(np.linspace(0, 0.9, len(indices_to_plot))) 
    
    for idx, color in zip(indices_to_plot, colors_gap):
        # Skip rows that do not exist in a smaller gap matrix.
        if idx < gap_matrix.shape[0]:
            ax_gap.semilogx(k_values, gap_matrix[idx, :], color=color, linewidth=2.0, alpha=0.8, label=f'i={idx+1}')
            
    ax_gap.semilogx(k_values, avg_gap_curve, 'k--', linewidth=3.5, alpha=1.0, label='Avg')

    ax_gap.legend(loc='lower right', ncol=2, frameon=True, edgecolor='black', fancybox=False) 
    ax_gap.set_title("Gap Metric")
    ax_gap.set_xlabel('k (Cutoff)')
    ax_gap.set_ylabel('Cum. Distribution')
    ax_gap.set_ylim(-0.05, 1.1)
    

    # --- 3. Spectrum Plot ---
    format_boxed_ax(ax_spectrum)
    if isinstance(spectrum_data, torch.Tensor):
        spec_np = spectrum_data.detach().cpu().numpy()
    else:
        spec_np = spectrum_data
        
    B_samples = spec_np.shape[0]
    # Number of batch items drawn (the variable name notwithstanding, nothing
    # here looks at class labels).
    num_classes = 8 
    colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
    ranks = np.arange(1, spec_np.shape[1] + 1)
    
    # Evenly spaced batch indices; duplicates occur when B_samples < 8.
    selected_indices = np.linspace(0, B_samples-1, num_classes, dtype=int)
    
    # NOTE(review): the y label below reads |lambda_i| but the raw values are
    # plotted; non-positive entries cannot be shown on a log axis — confirm
    # whether an absolute value is intended.
    for i, idx in enumerate(selected_indices):
        ax_spectrum.loglog(ranks, spec_np[idx], color=colors[i], alpha=0.8, linewidth=2.0)
            
    # NOTE(review): marker fixed at rank 10 regardless of the cut index used
    # to build the spectrum — confirm.
    ax_spectrum.axvline(x=10, color='gray', linestyle=':', linewidth=2.5)
    ax_spectrum.set_title('Spectrum (Sampled)')
    ax_spectrum.set_xlabel('Rank')
    ax_spectrum.set_ylabel(r'$|\lambda_i|$')


    # --- 4. Comparison of C against the reference ---
    plot_comparison_loglog(ax_comp_c, C_ref, C, r'$C_{ii}$', CH_slice)


    # --- 5. Comparison of H against the reference ---
    plot_comparison_loglog(ax_comp_h, H_ref, H, r'$H_{ii}$', CH_slice)



def main(CH_slice, h_holder_input=None, fig_title="Neural Collapse Spectral Analysis"):
    """Build and save the spectrum-cutoff figure (one row per spectrum variant).

    The input batch is eigendecomposed once; H, C and the gap matrix are then
    recomputed for three spectra that share the same eigenvectors:
    the original one, one with every mode after the first scaled by 1e-5, and
    that same cut spectrum with the leading eigenvalue replaced by its batch mean.

    Args:
        CH_slice: index selecting which H/C entries are plotted.
        h_holder_input: tensor indexed as (B, N, N). Required in practice
            despite the ``None`` default: its ``.shape`` is read immediately.
        fig_title: super-title of the figure.

    Side effects:
        Updates global Matplotlib rcParams, selects the 'magma' linalg backend
        for CUDA, writes ``Spectrum_cutoff_appendix.pdf`` and shows the figure.
    """
    set_icml_appendix_style()

    h_holder = h_holder_input
    B, N, _ = h_holder.shape
    print(f"Processing input data: B={B}, N={N}")

    torch.backends.cuda.preferred_linalg_library('magma')
    k_target_actual = min(N, 2000)
    real_eigenvalues_orig, eigenvectors = get_top_k_decomposition(h_holder, k=k_target_actual)

    # Drops only this function's reference; the caller's tensor stays alive.
    del h_holder
    torch.cuda.empty_cache()

    results = []

    # --- Baseline ---
    print("\n--- Processing Real Data (Original) ---")
    H_r_orig, C_r_orig = compute_H_C_only(real_eigenvalues_orig, eigenvectors)
    Gap_r_orig = compute_full_gap_matrix_efficient(real_eigenvalues_orig, eigenvectors)

    results.append({
        'H': H_r_orig, 'C': C_r_orig, 'Gap': Gap_r_orig,
        'Spec': real_eigenvalues_orig, 'Title': "Original"
    })

    # --- Variant: every mode after the first scaled down ---
    print("\n--- Processing Real Data (Cut 1) ---")
    real_eigenvalues_smooth2 = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=1)
    H_r_smooth2, C_r_smooth2 = compute_H_C_only(real_eigenvalues_smooth2, eigenvectors)
    Gap_r_smooth2 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth2, eigenvectors)

    results.append({
        'H': H_r_smooth2, 'C': C_r_smooth2, 'Gap': Gap_r_smooth2,
        'Spec': real_eigenvalues_smooth2, 'Title': f"Cut @ k=1"
    })

    # --- Variant: same cut, leading eigenvalue replaced by its batch mean ---
    print("\n--- Processing Real Data (Reshaped) ---")
    real_eigenvalues_smooth3 = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=1)
    real_eigenvalues_smooth3[:,0] = torch.mean(real_eigenvalues_smooth3[:,0], dim=0)

    H_r_smooth3, C_r_smooth3 = compute_H_C_only(real_eigenvalues_smooth3, eigenvectors)
    Gap_r_smooth3 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth3, eigenvectors)

    results.append({
        'H': H_r_smooth3, 'C': C_r_smooth3, 'Gap': Gap_r_smooth3,
        'Spec': real_eigenvalues_smooth3, 'Title': f"Cut @ k=1 & Mean Collapse"
    })

    # ==========================================
    # Plotting with overall title
    # ==========================================
    total_rows = len(results)
    print(f"\nGenerating Plots...")

    fig, axes = plt.subplots(total_rows, 5, figsize=(24, 5.0 * total_rows), constrained_layout=True)

    # plt.subplots returns a 1-D array for a single row; restore the row axis.
    if total_rows == 1:
        axes = np.expand_dims(axes, 0)

    for i, res in enumerate(results):
        analyze_slope_and_plot(CH_slice,
            res['H'], res['C'], res['Gap'], res['Spec'], res['Title'],
            H_r_orig, C_r_orig,
            axes[i, 0], axes[i, 1], axes[i, 2], axes[i, 3], axes[i, 4],
            end_idx=k_target_actual)

    fig.suptitle(fig_title, fontsize=28, fontweight='bold', fontfamily='serif')

    # The output folder is derived from the notebook-level `config`. Fall back
    # to a debug folder only when that lookup itself fails; a bare `except`
    # here would also hide unrelated errors and keyboard interrupts.
    try:
        save_dir = f"ICML_Figures/{config['model']}_{config['dataset']}_{config['lss_fn']}"
    except (NameError, KeyError, TypeError) as exc:
        save_dir = "ICML_Figures/Debug"
        print(f"Could not read model/dataset/loss from config ({exc!r}); using {save_dir}")

    os.makedirs(save_dir, exist_ok=True)
    filename = f"Spectrum_cutoff_appendix.pdf"
    save_path = os.path.join(save_dir, filename)

    print(f"Saving to {save_path} ...")
    plt.savefig(
        save_path,
        format='pdf',
        bbox_inches='tight',
        pad_inches=0.2,
        dpi=300
    )
    plt.show()
    print("Done.")

if __name__ == "__main__":

    # The driver needs the `h_holder` tensor produced by the loading cell.
    if 'h_holder' in locals():

        MY_TITLE = "Effect of Spectrum Cutoff on C-H Scaling Law"

        # main() indexes the length-N H/C vectors (N = h_holder.shape[1]) with
        # CH_slice, so cap the slice at N; a fixed 0..1499 range raises an
        # IndexError whenever the matrices are smaller than 1500.
        CH_slice = torch.arange(0, min(1500, h_holder.shape[1]))
        main(CH_slice, h_holder, fig_title=MY_TITLE)
    else:
        print("Wait: 'h_holder' is not defined. Please define input tensor first.") 


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import matplotlib.ticker as ticker
import os

# Choose the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_icml_appendix_style():
    """Install the large-font serif Matplotlib defaults used by this cell's figures.

    The settings are written into the global ``plt.rcParams``, so every figure
    created afterwards in the same kernel picks them up.
    """
    # Typography: serif text, STIX math, and enlarged label/title/tick fonts.
    fonts = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
        'mathtext.fontset': 'stix',
        'font.size': 18,
        'axes.labelsize': 20,
        'axes.titlesize': 22,
        'xtick.labelsize': 16,
        'ytick.labelsize': 16,
        'legend.fontsize': 16,
    }
    # Stroke weights for plotted lines, markers and the axes frame.
    strokes = {
        'lines.linewidth': 3.0,
        'lines.markersize': 8,
        'axes.linewidth': 2.0,
    }
    # Major tick geometry on both axes.
    ticks = {
        'xtick.major.size': 8,
        'xtick.major.width': 1.5,
        'ytick.major.size': 8,
        'ytick.major.width': 1.5,
    }
    # Dashed, semi-transparent background grid.
    grid = {
        'axes.grid': True,
        'grid.alpha': 0.4,
        'grid.linestyle': '--',
        'grid.linewidth': 1.0,
    }
    # Export settings: 300 dpi, font type 42 for PDF/PS output.
    export = {
        'savefig.dpi': 300,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    }
    plt.rcParams.update({**fonts, **strokes, **ticks, **grid, **export})

def format_boxed_ax(ax):
    """Give ``ax`` a closed black frame and inward-pointing ticks.

    Every spine of the axes is made visible, black and 2.0 wide; tick marks
    are drawn inward in black, and the grid alpha is set to 0.4.
    """
    tick_style = dict(direction='in', width=1.5, length=6, colors='black', grid_alpha=0.4)
    for side in ax.spines.values():
        side.set_color('black')
        side.set_linewidth(2.0)
        side.set_visible(True)
    ax.tick_params(**tick_style)



def get_top_k_decomposition(h_holder, k=800):
    """Eigendecompose a batch of symmetric matrices and keep the ``k`` largest modes.

    Args:
        h_holder: tensor indexed as (batch, N, N); it is moved to the
            module-level ``device`` before decomposition.
        k: number of leading modes to keep (capped at N).

    Returns:
        ``(vals_k, vecs_k)``: eigenvalues of shape (batch, k) in descending
        order and the matching eigenvector columns of shape (batch, N, k).
    """
    print(f"1. Performing Eigen Decomposition (Full Rank)...")
    eigvals_asc, eigvecs_asc = torch.linalg.eigh(h_holder.to(device))
    # eigh yields ascending eigenvalues; reverse the mode axis of both outputs
    # so that index 0 is the largest eigenvalue.
    eigvals_desc = eigvals_asc.flip(1)
    eigvecs_desc = eigvecs_asc.flip(2)
    del eigvals_asc, eigvecs_asc

    actual_k = min(k, eigvals_desc.shape[1])
    print(f"2. Filtering Top {actual_k} modes...")
    vals_k = eigvals_desc[:, :actual_k]
    vecs_k = eigvecs_desc[:, :, :actual_k]
    del eigvals_desc, eigvecs_desc
    torch.cuda.empty_cache()
    return vals_k, vecs_k

def compute_H_C_only(eigenvalues, eigenvectors, chunk_size=16):
    """Batch-average the spectral sums ``sum_k V[i,k]^2 * lam_k`` and ``sum_k V[i,k]^2 * lam_k^2``.

    Args:
        eigenvalues: tensor indexed as (batch, K).
        eigenvectors: tensor indexed as (batch, N, K); its row axis is
            reversed before use, so output entry 0 corresponds to input row N-1.
        chunk_size: number of batch items processed per step.

    Returns:
        ``(H, C)``: two float32 NumPy vectors of length N, the batch means of
        the lambda-weighted and lambda-squared-weighted sums respectively.
    """
    n_batch = eigenvalues.shape[0]
    n_rows = eigenvectors.shape[1]
    h_total = np.zeros(n_rows, dtype=np.float32)
    c_total = np.zeros(n_rows, dtype=np.float32)

    with torch.no_grad():
        for start in range(0, n_batch, chunk_size):
            stop = min(start + chunk_size, n_batch)
            lam = eigenvalues[start:stop]
            # Squared components of the row-reversed eigenvectors act as weights.
            weights = torch.flip(eigenvectors[start:stop], dims=[1]) ** 2

            h_rows = (weights @ lam.unsqueeze(-1)).squeeze(-1)
            c_rows = (weights @ (lam ** 2).unsqueeze(-1)).squeeze(-1)

            # Sum over the chunk on-device, then accumulate on the host.
            h_total += h_rows.sum(dim=0).cpu().numpy()
            c_total += c_rows.sum(dim=0).cpu().numpy()

            del lam, weights, h_rows, c_rows

    return h_total / n_batch, c_total / n_batch

# def compute_full_gap_matrix_efficient(eigenvalues, eigenvectors, batch_size=32):
#     print(f"Computing Gap Matrix (Batch Size={batch_size})...")
#     B, N, num_modes = eigenvectors.shape 
#     eigenvectors = torch.flip(eigenvectors, dims=[1])
#     total_gap_sum = torch.zeros((N, num_modes), device=device)
#     with torch.no_grad():
#         for b_start in range(0, B, batch_size):
#             b_end = min(b_start + batch_size, B)
#             vec_chunk = eigenvectors[b_start:b_end] 
#             val_chunk = eigenvalues[b_start:b_end]      
#             lam_sq = (val_chunk ** 2).unsqueeze(1)
#             V_sq = vec_chunk ** 2
#             numerator = torch.cumsum(lam_sq * V_sq, dim=2) 
#             total_energy = numerator[:, :, -1].unsqueeze(2)
#             denominator = total_energy 
#             gap_chunk = numerator / denominator 
#             total_gap_sum += torch.sum(gap_chunk, dim=0)
#             del vec_chunk, val_chunk, lam_sq, V_sq, numerator, denominator, gap_chunk
#             torch.cuda.empty_cache()
        
#     avg_gap_matrix = total_gap_sum / B
#     return avg_gap_matrix.cpu().numpy()


def compute_full_gap_matrix_efficient(eigenvalues, eigenvectors, batch_size=32):
    """Batch-average the cumulative share of ``lam_k^2 * V[i,k]^2`` over the mode axis.

    For every batch item and row ``i`` the function forms the running sum over
    modes ``k`` of ``lam_k^2 * V[i,k]^2`` and divides it by the full sum over
    all modes, so each row is a curve ending at 1. All arithmetic is done in
    float64; the batch-averaged result is returned as float32.

    Args:
        eigenvalues: tensor indexed as (B, K).
        eigenvectors: tensor indexed as (B, N, K); its row axis is reversed
            before use, so output row 0 corresponds to input row N-1.
        batch_size: number of batch items processed per step.

    Returns:
        float32 NumPy array of shape (N, K). A row whose total is exactly zero
        yields NaN (0/0) for that batch item.
    """
    print(f"Computing Gap Matrix (High Precision, Batch Size={batch_size})...")
    B, N, num_modes = eigenvectors.shape
    eigenvectors = torch.flip(eigenvectors, dims=[1])

    # Allocate the accumulator next to the data instead of on the module-level
    # `device`, so the function also works when the inputs live on another
    # device (or when no global `device` exists).
    total_gap_sum = torch.zeros((N, num_modes), device=eigenvectors.device, dtype=torch.float64)

    with torch.no_grad():
        for b_start in range(0, B, batch_size):
            b_end = min(b_start + batch_size, B)

            # Promote each chunk to float64 before squaring and accumulating.
            vec_chunk = eigenvectors[b_start:b_end].to(torch.float64)
            val_chunk = eigenvalues[b_start:b_end].to(torch.float64)

            # (b, 1, K) so the squared eigenvalues broadcast over the row axis.
            lam_sq = (val_chunk ** 2).unsqueeze(1)
            V_sq = vec_chunk ** 2

            numerator = torch.cumsum(lam_sq * V_sq, dim=2)

            # The last cumulative entry is the total over all modes.
            total_energy = numerator[:, :, -1].unsqueeze(2)

            gap_chunk = numerator / total_energy

            total_gap_sum += torch.sum(gap_chunk, dim=0)

            del vec_chunk, val_chunk, lam_sq, V_sq, numerator, total_energy, gap_chunk
            torch.cuda.empty_cache()

    avg_gap_matrix = total_gap_sum / B
    return avg_gap_matrix.cpu().numpy().astype(np.float32)




def create_gap_from_real_spectrum(eigenvalues, idx = 10):
    """Return a copy of ``eigenvalues`` whose columns ``idx:`` are multiplied by 1e-5.

    The input tensor is left untouched; columns before ``idx`` are copied as is.
    """
    damped = eigenvalues.clone()
    tail = damped[:, idx:]
    damped[:, idx:] = tail * 1e-5
    return damped

def create_stiff_mean_spectrum(eigenvalues, idx=1, threshold_ratio=0.1, meancollapse=False,
                               tail_scale=1e-5, non_stiff_scale=1e-3):
    """Cut the spectrum tail and treat "stiff" and "non-stiff" samples differently.

    Steps, applied to a copy of ``eigenvalues`` (indexed as (batch, modes)):
      1. Columns ``idx:`` are multiplied by ``tail_scale``.
      2. A sample is "stiff" when its column-0 value exceeds
         ``threshold_ratio * max(column 0)``.
      3. With ``meancollapse=True`` every stiff sample's column-0 value is
         replaced by the mean over the stiff samples; otherwise it is kept.
      4. Every non-stiff sample's column-0 value is multiplied by
         ``non_stiff_scale``.

    Args:
        eigenvalues: tensor indexed as (batch, modes); not modified.
        idx: first column of the tail that is scaled down.
        threshold_ratio: fraction of the largest column-0 value used as the
            stiff / non-stiff threshold.
        meancollapse: whether stiff samples are collapsed to their mean.
        tail_scale: factor applied to the tail columns (default 1e-5).
        non_stiff_scale: factor applied to non-stiff leading values
            (default 1e-3).

    Returns:
        The modified copy of ``eigenvalues``.
    """
    new_spectrum = eigenvalues.clone()
    new_spectrum[:, idx:] = new_spectrum[:, idx:] * tail_scale

    # View of column 0, taken after the tail cut (which touches it when idx == 0).
    lam_0 = new_spectrum[:, 0]

    threshold = torch.max(lam_0) * threshold_ratio

    # Both masks are fixed here, before any column-0 value is rewritten.
    stiff_mask = lam_0 > threshold
    non_stiff_mask = ~stiff_mask

    if stiff_mask.sum() > 0 and meancollapse:
        stiff_mean = torch.mean(lam_0[stiff_mask])
        print(f"   [Stiff Analysis] Threshold: {threshold:.4f}, Count: {stiff_mask.sum().item()}/{len(lam_0)}, Mean Val: {stiff_mean:.4f}")

        new_spectrum[stiff_mask, 0] = stiff_mean
    if non_stiff_mask.sum() > 0:
        # Report the factor that is actually applied; the message used to say
        # 1e-5 while the code multiplied by 1e-3.
        print(f"   [Stiff Analysis] Suppressing {non_stiff_mask.sum().item()} Non-Stiff samples by {non_stiff_scale:g}.")
        new_spectrum[non_stiff_mask, 0] = new_spectrum[non_stiff_mask, 0] * non_stiff_scale

    return new_spectrum



def plot_comparison_loglog(ax, data_orig, data_new, label_name, CH_slice):
    """Scatter ``data_new`` against ``data_orig`` on log-log axes with a fitted line.

    Args:
        ax: Matplotlib axes to draw on.
        data_orig: reference values (x axis), indexed with ``CH_slice``.
        data_new: current values (y axis), indexed with ``CH_slice``.
        label_name: quantity name used in the axis labels and the title.
        CH_slice: index selecting the entries to compare.

    Only pairs that are strictly positive in both series are used. When at
    least two such pairs exist, the scatter, a least-squares line fitted in
    log10 space, and a dashed ``y=x`` reference are drawn.
    """
    format_boxed_ax(ax)

    ref_vals = data_orig[CH_slice]
    cur_vals = data_new[CH_slice]

    # log10 is only defined for positive numbers, so drop everything else.
    positive = (ref_vals > 0) & (cur_vals > 0)
    ref_pos = ref_vals[positive]
    cur_pos = cur_vals[positive]

    if len(ref_pos) > 1:
        log_ref = np.log10(ref_pos)
        log_cur = np.log10(cur_pos)
        slope, intercept, *_ = linregress(log_ref, log_cur)

        ax.scatter(ref_pos, cur_pos, s=3, facecolors='none', edgecolors='#1f77b4',
                   alpha=0.7, linewidth=0.5, label='Data Points')

        # Map the fitted log-space line back to data units.
        fitted = 10**(intercept + slope * log_ref)
        ax.plot(ref_pos, fitted, color='#d62728', linestyle='-', linewidth=1.0, label=f'Fit: slope={slope:.2f}')

        # Identity line spanning the joint range of both series.
        lo = min(ref_pos.min(), cur_pos.min())
        hi = max(ref_pos.max(), cur_pos.max())
        identity = np.linspace(lo, hi, 100)
        ax.plot(identity, identity, 'k--', linewidth=2.0, alpha=0.6, label='y=x')

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(f'Original {label_name}')
    ax.set_ylabel(f'Current {label_name}')
    ax.set_title(f'{label_name} Comparison')
    ax.legend(loc='lower right', frameon=True, framealpha=0.9, edgecolor='black', fancybox=False)


def analyze_slope_and_plot(CH_slice, H, C, gap_matrix, spectrum_data, row_title, 
                           H_ref, C_ref, 
                           ax_scatter, ax_gap, ax_spectrum, ax_comp_c, ax_comp_h,
                           end_idx=800):
    """Draw one five-panel figure row for a single spectrum variant.

    Panels, in order of the axes arguments:
      1. ``ax_scatter``: C vs H (entries ``CH_slice``) on log-log axes with a
         fitted line and slope-1 / slope-2 references through the log-mean point.
      2. ``ax_gap``: selected rows of ``gap_matrix`` plus the mean of its
         first 1500 rows, against the 1-based mode index.
      3. ``ax_spectrum``: rows of ``spectrum_data`` for 8 evenly spaced batch
         indices on log-log axes.
      4. ``ax_comp_c``: ``C`` against ``C_ref`` via ``plot_comparison_loglog``.
      5. ``ax_comp_h``: ``H`` against ``H_ref`` via ``plot_comparison_loglog``.

    Args:
        CH_slice: index selecting the H/C entries used in panels 1, 4 and 5.
        H, C: 1-D arrays for this variant.
        gap_matrix: 2-D array; rows are curves over the mode index.
        spectrum_data: torch tensor or array indexed as (batch, modes).
        row_title: text inserted into the title of panel 1.
        H_ref, C_ref: reference arrays for the comparison panels.
        end_idx: accepted for call compatibility; not referenced in this body.
    """
    
    # --- 1. C_ii vs H_ii Scatter ---
    format_boxed_ax(ax_scatter)
    x_data = H[CH_slice]
    y_data = C[CH_slice]
    # Keep only pairs that are positive in both arrays so log10 is defined.
    mask = (x_data > 0) & (y_data > 0)
    x_data, y_data = x_data[mask], y_data[mask]
    
    if len(x_data) > 1:
        log_x, log_y = np.log10(x_data), np.log10(y_data)
        slope, intercept, _, _, _ = linregress(log_x, log_y)
        ax_scatter.scatter(x_data, y_data, s=3, alpha=0.6, color='#000080', label='Data', edgecolors='none')
        # Least-squares line fitted in log10 space, mapped back to data units.
        fit_line_y = 10**(intercept + slope * log_x)
        ax_scatter.plot(x_data, fit_line_y, color='#d62728', linestyle='-', linewidth=1.0, label=f'Fit: {slope:.2f}')
        
        # Reference lines of slope 1 and 2 pass through the mean of the log data.
        mid_x, mid_y = np.mean(log_x), np.mean(log_y)
        ref_line_x = np.logspace(min(log_x), max(log_x), 100)
        ref_line_y1 = 10**(1.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y1, 'k--', linewidth=2.0, label='Slope=1')
        ref_line_y2 = 10**(2.0 * (np.log10(ref_line_x) - mid_x) + mid_y)
        ax_scatter.plot(ref_line_x, ref_line_y2, color='green', linestyle=':', linewidth=3.0, label='Slope=2')
    
    ax_scatter.set_xscale('log')
    ax_scatter.set_yscale('log')
    ax_scatter.set_xlabel(r'$H_{ii}$')
    ax_scatter.set_ylabel(r'$C_{ii}$')
    ax_scatter.set_title(f"Scalability ({row_title})", fontweight='bold') 
    ax_scatter.legend(loc='upper left', frameon=True, edgecolor='black', fancybox=False)
    
    # --- 2. Gap Curve ---
    format_boxed_ax(ax_gap)
    real_k = gap_matrix.shape[1]
    k_values = np.arange(1, real_k + 1)
    # NOTE(review): the 1500-row window is hard-coded here and is independent
    # of CH_slice — confirm the two are meant to stay in sync.
    avg_gap_curve = np.mean(gap_matrix[:1500], axis=0)
    
    # 0-based row indices to draw; the legend shows them 1-based (idx + 1).
    indices_to_plot = [1, 10, 50, 100, 400, 1000]
    colors_gap = plt.cm.viridis(np.linspace(0, 0.9, len(indices_to_plot))) 
    for idx, color in zip(indices_to_plot, colors_gap):
        # Skip rows that do not exist in a smaller gap matrix.
        if idx < gap_matrix.shape[0]:
            ax_gap.semilogx(k_values, gap_matrix[idx, :], color=color, linewidth=2.0, alpha=0.8, label=f'i={idx+1}')
    ax_gap.semilogx(k_values, avg_gap_curve, 'k--', linewidth=3.5, alpha=1.0, label='Avg')
    ax_gap.legend(loc='lower right', ncol=2, frameon=True, edgecolor='black', fancybox=False) 
    ax_gap.set_title("Gap Metric")
    ax_gap.set_xlabel('k (Cutoff)')
    ax_gap.set_ylabel('Cum. Distribution')
    ax_gap.set_ylim(-0.05, 1.1)
    
    # --- 3. Spectrum Plot ---
    format_boxed_ax(ax_spectrum)
    if isinstance(spectrum_data, torch.Tensor):
        spec_np = spectrum_data.detach().cpu().numpy()
    else:
        spec_np = spectrum_data
    B_samples = spec_np.shape[0]
    # Number of batch items drawn (the variable name notwithstanding, nothing
    # here looks at class labels).
    num_classes = 8
    colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
    ranks = np.arange(1, spec_np.shape[1] + 1)
    # Evenly spaced batch indices; duplicates occur when B_samples < 8.
    selected_indices = np.linspace(0, B_samples-1, num_classes, dtype=int)
    # NOTE(review): the y label below reads |lambda_i| but the raw values are
    # plotted; non-positive entries cannot be shown on a log axis — confirm
    # whether an absolute value is intended.
    for i, idx in enumerate(selected_indices):
        ax_spectrum.loglog(ranks, spec_np[idx], color=colors[i], alpha=0.8, linewidth=2.0)
    # NOTE(review): marker fixed at rank 10 regardless of the cut index used
    # to build the spectrum — confirm.
    ax_spectrum.axvline(x=10, color='gray', linestyle=':', linewidth=2.5)
    ax_spectrum.set_title('Spectrum (Sampled)')
    ax_spectrum.set_xlabel('Rank')
    ax_spectrum.set_ylabel(r'$|\lambda_i|$')

    # --- 4. Comparison C ---
    plot_comparison_loglog(ax_comp_c, C_ref, C, r'$C_{ii}$', CH_slice)

    # --- 5. Comparison H ---
    plot_comparison_loglog(ax_comp_h, H_ref, H, r'$H_{ii}$', CH_slice)


def main(CH_slice, h_holder_input=None, fig_title="Neural Collapse Spectral Analysis"):
    """Build and save the stiff-sample ablation figure (one row per spectrum variant).

    The input batch is eigendecomposed once; H, C and the gap matrix are then
    recomputed for four spectra that share the same eigenvectors:
      1. the original spectrum;
      2. every mode after the first scaled down (``create_gap_from_real_spectrum``);
      3. the same cut with ``create_stiff_mean_spectrum(meancollapse=False)``;
      4. the same cut with ``create_stiff_mean_spectrum(meancollapse=True)``.

    Args:
        CH_slice: index selecting which H/C entries are plotted.
        h_holder_input: tensor indexed as (B, N, N). Required in practice
            despite the ``None`` default: its ``.shape`` is read immediately.
        fig_title: super-title of the figure.

    Side effects:
        Updates global Matplotlib rcParams, selects the 'magma' linalg backend
        for CUDA, writes ``ICML_Figures/Ablation/Spectrum_stiff_mean_ablation.pdf``
        and shows the figure.
    """
    set_icml_appendix_style()
    h_holder = h_holder_input
    B, N, _ = h_holder.shape
    print(f"Processing input data: B={B}, N={N}")

    # Baseline: Real Data
    torch.backends.cuda.preferred_linalg_library('magma')
    k_target_actual = min(N, 2000)
    real_eigenvalues_orig, eigenvectors = get_top_k_decomposition(h_holder, k=k_target_actual)
    # Drops only this function's reference; the caller's tensor stays alive.
    del h_holder
    torch.cuda.empty_cache()

    results = []

    # 1. Original
    print("\n--- Processing Real Data (Original) ---")
    H_r_orig, C_r_orig = compute_H_C_only(real_eigenvalues_orig, eigenvectors)
    Gap_r_orig = compute_full_gap_matrix_efficient(real_eigenvalues_orig, eigenvectors)
    results.append({'H': H_r_orig, 'C': C_r_orig, 'Gap': Gap_r_orig, 'Spec': real_eigenvalues_orig, 'Title': "Original"})

    # 2. Cut @ k=1 (tail compression)
    print("\n--- Processing Real Data (Cut 1) ---")
    real_eigenvalues_smooth2 = create_gap_from_real_spectrum(real_eigenvalues_orig, idx=1)
    H_r_smooth2, C_r_smooth2 = compute_H_C_only(real_eigenvalues_smooth2, eigenvectors)
    Gap_r_smooth2 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth2, eigenvectors)
    results.append({'H': H_r_smooth2, 'C': C_r_smooth2, 'Gap': Gap_r_smooth2, 'Spec': real_eigenvalues_smooth2, 'Title': f"Cut @ k=1"})

    # 3. Cut @ k=1, stiff samples kept as they are (meancollapse=False).
    #    The progress message now names this variant the same way its plot
    #    title does; it used to announce a mean collapse that does not happen.
    print("\n--- Processing Real Data (Cut 1 & Stiff Preserve) ---")
    real_eigenvalues_smooth3 = create_stiff_mean_spectrum(real_eigenvalues_orig, idx=1, threshold_ratio=0.001, meancollapse=False)
    H_r_smooth3, C_r_smooth3 = compute_H_C_only(real_eigenvalues_smooth3, eigenvectors)
    Gap_r_smooth3 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth3, eigenvectors)
    results.append({'H': H_r_smooth3, 'C': C_r_smooth3, 'Gap': Gap_r_smooth3, 'Spec': real_eigenvalues_smooth3, 'Title': f"Cut @ k=1 & Stiff Preserve"})

    # 4. Cut @ k=1 with stiff samples collapsed to their mean (meancollapse=True).
    #    The cut index is 1 (idx=1), so the message says "Cut 1", not "Cut 10".
    print("\n--- Processing Real Data (Cut 1 & Stiff Mean Collapse) ---")
    real_eigenvalues_smooth4 = create_stiff_mean_spectrum(real_eigenvalues_orig, idx=1, threshold_ratio=0.001, meancollapse=True)
    H_r_smooth4, C_r_smooth4 = compute_H_C_only(real_eigenvalues_smooth4, eigenvectors)
    Gap_r_smooth4 = compute_full_gap_matrix_efficient(real_eigenvalues_smooth4, eigenvectors)
    results.append({'H': H_r_smooth4, 'C': C_r_smooth4, 'Gap': Gap_r_smooth4, 'Spec': real_eigenvalues_smooth4, 'Title': f"Cut @ k=1 & Stiff Mean Collapse"})

    # Plotting
    total_rows = len(results)
    print(f"\nGenerating Plots...")
    fig, axes = plt.subplots(total_rows, 5, figsize=(24, 5.0 * total_rows), constrained_layout=True)
    # plt.subplots returns a 1-D array for a single row; restore the row axis.
    if total_rows == 1:
        axes = np.expand_dims(axes, 0)

    for i, res in enumerate(results):
        analyze_slope_and_plot(CH_slice,
            res['H'], res['C'], res['Gap'], res['Spec'], res['Title'],
            H_r_orig, C_r_orig,
            axes[i, 0], axes[i, 1], axes[i, 2], axes[i, 3], axes[i, 4],
            end_idx=k_target_actual)

    fig.suptitle(fig_title, fontsize=28, fontweight='bold', fontfamily='serif')

    save_dir = "ICML_Figures/Ablation"
    os.makedirs(save_dir, exist_ok=True)
    filename = f"Spectrum_stiff_mean_ablation.pdf"
    save_path = os.path.join(save_dir, filename)
    print(f"Saving to {save_path} ...")
    plt.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0.2, dpi=300)
    plt.show()
    print("Done.")

if __name__ == "__main__":
    # The driver needs the `h_holder` tensor produced by the loading cell.
    if 'h_holder' in locals():
        MY_TITLE = "Effect of Stiff Sample Homogenization"
        # main() indexes the length-N H/C vectors (N = h_holder.shape[1]) with
        # CH_slice, so cap the slice at N; a fixed 0..1499 range raises an
        # IndexError whenever the matrices are smaller than 1500.
        CH_slice = torch.arange(0, min(1500, h_holder.shape[1]))
        main(CH_slice, h_holder, fig_title=MY_TITLE)
    else:
        print("Wait: 'h_holder' is not defined.")