In [None]:
import pickle
import os
from collections import defaultdict
import matplotlib.pyplot as plt

# Auto-load all metrics data
def load_all_metrics(metrics_dir='./metrics'):
    """
    Automatically load all .pkl files from metrics directory.

    Expected layout::

        <metrics_dir>/<dataset>/<mode>/<mode>:<model>:<gamma>.pkl

    Non-directory entries at the dataset and mode levels are skipped, as are
    files that do not end in ``.pkl`` or whose name (minus the extension) does
    not split into exactly three ``:``-separated parts. The mode part of the
    filename is ignored; the enclosing directory name is used as the mode.

    Each pickle is expected to hold a dict. Its keys are converted to ``str``
    so callers can index uniformly with strings such as ``'0.001'``.

    Directory entries are visited in sorted order, so the insertion order of
    the returned dicts (and therefore any printed listing or iteration order
    downstream) is reproducible across machines and filesystems.

    Args:
        metrics_dir: Root directory to scan.

    Returns:
        Nested dict: data[dataset][model][mode][gamma] = metrics_dict, where
        ``gamma`` is the raw string from the filename. The outer three levels
        are ``defaultdict``s, so *indexing* a missing key creates an empty
        entry; use ``in`` to test membership without side effects.

    Raises:
        FileNotFoundError: if ``metrics_dir`` does not exist (from os.listdir).
    """
    suffix = '.pkl'
    data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

    # os.listdir order is arbitrary; sort for reproducible output.
    for dataset in sorted(os.listdir(metrics_dir)):
        dataset_path = os.path.join(metrics_dir, dataset)
        if not os.path.isdir(dataset_path):
            continue

        for mode in sorted(os.listdir(dataset_path)):
            mode_path = os.path.join(dataset_path, mode)
            if not os.path.isdir(mode_path):
                continue

            for filename in sorted(os.listdir(mode_path)):
                if not filename.endswith(suffix):
                    continue

                # Parse filename: mode:model:gamma.pkl
                # Strip only the trailing extension; str.replace would also
                # delete '.pkl' appearing anywhere else in the name.
                parts = filename[:-len(suffix)].split(':')
                if len(parts) != 3:
                    continue

                _, model, gamma = parts

                filepath = os.path.join(mode_path, filename)
                # SECURITY: pickle.load can execute arbitrary code. Only point
                # this loader at metrics files you produced yourself.
                with open(filepath, 'rb') as f:
                    metrics = pickle.load(f)

                # Convert keys to strings for consistency
                metrics = {str(k): v for k, v in metrics.items()}
                data[dataset][model][mode][gamma] = metrics

    return data

# Load every metrics pickle once; the analysis cells below all read ALL_DATA.
ALL_DATA = load_all_metrics()

# Print an outline of what was found: datasets, their models, and each
# model's modes.
print('Available datasets:', list(ALL_DATA))
for ds, models in ALL_DATA.items():
    print(f'  {ds} models:', list(models))
    for model, modes in models.items():
        print(f'    {model} modes:', list(modes))

In [None]:
# ============================================
# CONFIGURE YOUR ANALYSIS HERE
# ============================================

DATASET = 'cifar100'
MODEL = 'vit_tiny_patch16_224'

# Pick which methods to compare (plot/legend order follows this list)
METHODS = [
    'L1',
    'spatial',
    'spatial-swap',
    'block-4',
    'block-16',
]

# Validate with membership tests rather than indexing. ALL_DATA is a nested
# defaultdict, so ALL_DATA[DATASET][MODEL] on a typo would silently insert an
# empty entry (polluting ALL_DATA) and every chart below would render blank.
if DATASET not in ALL_DATA:
    raise KeyError(f'Dataset {DATASET!r} not found; available: {sorted(ALL_DATA)}')
if MODEL not in ALL_DATA[DATASET]:
    raise KeyError(
        f'Model {MODEL!r} not found for dataset {DATASET!r}; '
        f'available: {sorted(ALL_DATA[DATASET])}'
    )

# Filter to only methods that exist in the data, reporting any that were dropped
available_methods = list(ALL_DATA[DATASET][MODEL].keys())
missing_methods = [m for m in METHODS if m not in available_methods]
METHODS = [m for m in METHODS if m in available_methods]
print(f'Using methods: {METHODS}')
if missing_methods:
    print(f'Requested but not found (skipped): {missing_methods}')
# Sorted so the listing is stable across runs (set iteration order is not).
print(f'Available but not selected: {sorted(set(available_methods) - set(METHODS))}')

In [None]:
# ============================================
# CHART A: Sparsity vs Accuracy at Fixed Thresholds
# One chart per threshold (0.01, 0.001, 0.0001)
# ============================================

thresholds = ['0.01', '0.001', '0.0001']

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, threshold in zip(axes, thresholds):
    for method in METHODS:
        # Each gamma run contributes one (sparsity, accuracy) point, provided
        # it reported both values at this threshold.
        points = []
        for run_metrics in ALL_DATA[DATASET][MODEL][method].values():
            if threshold not in run_metrics:
                continue
            stats = run_metrics[threshold]
            acc = stats.get('final_acc')
            sparsity = stats.get('percent_below_t')
            if acc is None or sparsity is None:
                continue
            points.append((sparsity, acc))

        if not points:
            continue
        # Order by sparsity (stable) so the connecting line runs left to right.
        points.sort(key=lambda pt: pt[0])
        sparsities, accuracies = zip(*points)
        ax.plot(sparsities, accuracies, 'o-', label=method, markersize=6)

    ax.set(xlabel='Sparsity (%)', ylabel='Accuracy (%)', title=f'Threshold = {threshold}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle(f'{DATASET} / {MODEL} - Accuracy vs Sparsity at Fixed Thresholds', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ============================================
# CHART B: Best Accuracy per Sparsity Level
# For each fixed sparsity level, find the best gamma for each method
# ============================================

sparsity_levels = [100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 5, 3, 2, 1]

fig, ax = plt.subplots(figsize=(10, 6))

for method in METHODS:
    runs = ALL_DATA[DATASET][MODEL][method]
    best_points = []

    for level in sparsity_levels:
        level_key = str(level)
        # Accuracy of every gamma run evaluated at this level; the best gamma
        # is whichever run attains the maximum.
        candidate_accs = [
            run_metrics[level_key].get('final_acc')
            for run_metrics in runs.values()
            if level_key in run_metrics
        ]
        candidate_accs = [acc for acc in candidate_accs if acc is not None]
        if not candidate_accs:
            continue
        # x-axis: 100 - level = percent pruned
        best_points.append((100 - level, max(candidate_accs)))

    if best_points:
        best_points.sort(key=lambda pt: pt[0])
        pruned_pcts, best_accs = zip(*best_points)
        ax.plot(pruned_pcts, best_accs, 'o-', label=method, markersize=8)

ax.set(
    xlabel='Sparsity (% weights pruned)',
    ylabel='Best Accuracy (%)',
    title=f'{DATASET} / {MODEL} - Best Accuracy at Each Sparsity Level\n(Best gamma selected per method per sparsity level)',
)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ============================================
# CHART C: Modularity vs Accuracy
# CHART D: Modularity vs Sparsity
# Using threshold 0.001
# ============================================

threshold = '0.001'


def _plot_vs_modularity(ax, threshold, y_key, name):
    """Draw one line per method of ``metrics[threshold][y_key]`` against modularity.

    Each gamma run of a method contributes a point when both ``'modularity'``
    and ``y_key`` are present (not None) under ``threshold``. Points are
    sorted by modularity so each line runs left to right.

    Args:
        ax: Matplotlib axes to draw on.
        threshold: Metrics key (string), e.g. ``'0.001'``.
        y_key: Metric plotted on the y-axis, e.g. ``'final_acc'``.
        name: Display name of the y metric; used for the y label
            (``'<name> (%)'``) and the panel title.
    """
    for method in METHODS:
        points = []
        for metrics in ALL_DATA[DATASET][MODEL][method].values():
            if threshold not in metrics:
                continue
            value = metrics[threshold].get(y_key)
            modularity = metrics[threshold].get('modularity')
            if value is not None and modularity is not None:
                points.append((modularity, value))

        if points:
            points.sort(key=lambda pt: pt[0])
            ax.plot([pt[0] for pt in points], [pt[1] for pt in points],
                    'o-', label=method, markersize=6)

    ax.set_xlabel('Modularity (Q)')
    ax.set_ylabel(f'{name} (%)')
    ax.set_title(f'Modularity vs {name} (threshold={threshold})')
    # Skip the legend on empty panels: with no labelled lines ax.legend()
    # draws nothing and only emits a warning.
    if ax.get_lines():
        ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))
_plot_vs_modularity(axes[0], threshold, 'final_acc', 'Accuracy')        # Chart C
_plot_vs_modularity(axes[1], threshold, 'percent_below_t', 'Sparsity')  # Chart D

plt.suptitle(f'{DATASET} / {MODEL}', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ============================================
# CHART E: Block Sparsity at Same Accuracy Level
# For each method, find the best block sparsity achievable at each accuracy level
# ============================================

accuracy_levels = [80, 75, 70, 65, 60, 55, 50]

fig, ax = plt.subplots(figsize=(10, 6))

for method in METHODS:
    # Gather (accuracy, block_sparsity) from every dict-valued entry of every
    # gamma run, regardless of which metrics key it sits under.
    raw_pairs = [
        (entry.get('final_acc'), entry.get('block_sparsity_reordered'))
        for run_metrics in ALL_DATA[DATASET][MODEL][method].values()
        for entry in run_metrics.values()
        if isinstance(entry, dict)
    ]
    # Block sparsity is multiplied by 100 for a percent axis (presumably it is
    # stored as a fraction — TODO confirm).
    all_pairs = [(acc, bs * 100) for acc, bs in raw_pairs if acc is not None and bs is not None]

    # For each accuracy floor, keep the highest block sparsity among entries
    # that reach at least that accuracy.
    points = []
    for target_acc in accuracy_levels:
        eligible = [bs for acc, bs in all_pairs if acc >= target_acc]
        if eligible:
            points.append((target_acc, max(eligible)))

    if points:
        acc_floors, best_block_sparsity = zip(*points)
        ax.plot(acc_floors, best_block_sparsity, 'o-', label=method, markersize=8)

ax.set(
    xlabel='Minimum Accuracy (%)',
    ylabel='Best Block Sparsity Achieved (%)',
    title=f'{DATASET} / {MODEL} - Block Sparsity at Same Accuracy Level\n(For each accuracy threshold, best block sparsity among runs with acc >= threshold)',
)
ax.legend()
ax.grid(True, alpha=0.3)
ax.invert_xaxis()  # Higher accuracy on left
plt.tight_layout()
plt.show()