In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt

import os
import utils.plot_finals as finals
from typing import Dict, Any

In [None]:
# Publication-style matplotlib defaults, applied globally to every figure in this notebook.
plt.rcParams.update({
    # Fonts: serif family with a fallback chain of faces that are likely to be installed
    'font.family': 'serif',
    'font.serif': ['DejaVu Serif', 'Liberation Serif', 'Computer Modern Roman', 'Bitstream Vera Serif'],
    'font.size': 12,
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,

    # Line, marker and axes weights
    'axes.linewidth': 0.8,   # slightly thinner axes lines
    'lines.linewidth': 1.5,  # slightly thicker plot lines
    'lines.markersize': 4,   # slightly smaller markers

    # Light background grid on every axes
    'axes.grid': True,
    'grid.alpha': 0.3,

    # On-screen resolution versus saved-file settings
    'figure.dpi': 300,
    'savefig.dpi': 600,       # higher DPI for saved figures
    'savefig.format': 'pdf',  # default file format when none is given
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

## PointWise net + Showerflow

In [None]:
# Training steps whose checkpoints are evaluated, starting at 10k.
training_steps = (
    10_000, 50_000, 100_000,
    150_000, 200_000, 250_000,
    500_000, 750_000, 1_000_000,
)
print(training_steps)
print("len(training_steps):", len(training_steps))

In [None]:
# Steps iterated when loading metrics; an alias of the tuple defined in the previous cell.
grad_step = training_steps  # or use: np.array([...]).astype(int)

# Root directory that holds one sub-directory of results per strategy.
# The default is the original cluster location; set CALOTRANSFER_RESULTS_DIR to point the
# notebook at a different copy of the results without editing this cell.
# os.path.join(path, '') guarantees a trailing separator, which the string concatenations
# below (BASE_DIR + strategy_name + '/') rely on.
BASE_DIR = os.path.join(
    os.environ.get(
        'CALOTRANSFER_RESULTS_DIR',
        '/data/dust/user/valentel/maxwell.merged/MyCaloTransfer/CaloTransfer/results/diffusion/',
    ),
    '',
)

# Maps each strategy name to the run timestamp that build_strategy_config embeds in the
# checkpoint filename ('CaloChallange_CD<timestamp>_ckpt_...'). Only the timestamp differs
# between strategies; the rest of the filename is generated.
CHECKPOINT_PATTERNS = {
    ###### 10-90 GeV ######
    # Vanilla models
    'vanilla':    '2025_02_04__18_52_56',
    'vanilla_v1': '2025_03_14__12_48_20',
    'vanilla_v2': '2025_03_13__10_48_14',
    'vanilla_v3': '2025_03_14__05_16_47',
    'vanilla_v4': '2025_03_29__11_12_02',
    'vanilla_v5': '2025_03_30__13_17_42',
    'vanilla_v6': '2025_03_31__14_22_22',
    
    # Full finetune models
    'full_finetune':    '2025_02_04__18_15_56',
    'full_finetune_v1': '2025_03_16__11_00_43',
    'full_finetune_v2': '2025_03_17__13_49_51',
    'full_finetune_v3': '2025_03_18__16_36_36',
    
    # 3-layer frozen finetune
    '3frozen_finetune_v6': '2025_03_09__19_00_21',
    '3frozen_finetune_v7': '2025_03_10__18_20_44',
    '3frozen_finetune_v8': '2025_03_11__17_41_36',
    
    # LoRA models
    'lora_full_v1': '2025_03_19__23_39_27',
    'lora_full_v2': '2025_03_21__10_53_31',
    'lora_full_v3': '2025_03_23__13_07_39',
    
    ###### 1-1000 GeV ######
    # From-scratch (vanilla) models
    'vanilla_full_v1_1_1000': '2025_05_28__15_44_22',
    'vanilla_full_v2_1_1000': '2025_05_08__16_28_38',
    'vanilla_full_v3_1_1000': '2025_05_12__10_56_50',
    'vanilla_full_v4_1_1000': '2025_08_18__09_32_53',
    'vanilla_full_v5_1_1000': '2025_08_19__04_59_13',

    # Full fine-tune models
    'finetune_full_v1_1_1000': '2025_05_29__18_12_31',
    'finetune_full_v2_1_1000': '2025_05_09__16_58_31',
    'finetune_full_v3_1_1000': '2025_05_26__11_35_26',
    'finetune_full_v4_1_1000': '2025_08_20__16_24_57',
    'finetune_full_v5_1_1000': '2025_08_21__11_49_02',    
    # BitFit models
    'finetune_bitfit_v1_1_1000': '2025_06_12__13_09_43',
    'finetune_bitfit_v2_1_1000': '2025_06_13__13_56_59',
    'finetune_bitfit_v3_1_1000': '2025_06_14__12_59_27',
    'finetune_bitfit_v4_1_1000': '2025_08_25__09_50_18',
    'finetune_bitfit_v5_1_1000': '2025_08_26__01_21_48',

    # Top-3 layers
    'finetune_top3_v1_1_1000': '2025_06_18__16_15_13',
    'finetune_top3_v2_1_1000': '2025_07_22__20_32_56',
    'finetune_top3_v3_1_1000': '2025_07_23__14_50_23',
    'finetune_top3_v4_1_1000': '2025_08_22__10_03_26',
    'finetune_top3_v5_1_1000': '2025_08_23__03_00_04',

    # LoRA (the rN in the key is the rank label of the run)
    'lora_r1_v1_1_1000': '2025_07_21__19_20_51',
    'lora_r2_v1_1_1000': '2025_07_18__17_45_10',
    'lora_r4_v1_1_1000': '2025_07_19__04_03_45',
    
    'lora_r8_v1_1_1000': '2025_06_24__16_10_32',
    'lora_r8_v2_1_1000': '2025_07_24__17_42_07',  # own timestamp, distinct from lora_r8_v1
    'lora_r8_v3_1_1000': '2025_07_25__12_25_39',  # own timestamp, distinct from lora_r8_v1 and v2

    'lora_r16_v1_1_1000': '2025_06_26__13_29_31',
    'lora_r32_v1_1_1000': '2025_07_17__02_02_46',
    'lora_r32a48_v1_1_1000': '2025_07_17__02_06_49',
    'lora_r48_v1_1_1000': '2025_07_15__20_55_42',
    'lora_r64_v1_1_1000': '2025_07_15__09_42_59',
    
    'lora_r106_v1_1_1000': '2025_07_25__14_36_14',  # New model with r=106, alpha=106
    'lora_r106_v2_1_1000': '2025_07_29__10_10_18',  # own timestamp, distinct from lora_r106_v1
    'lora_r106_v3_1_1000': '2025_07_30__13_50_31',  # own timestamp, distinct from the other r106 runs
    'lora_r106_v4_1_1000': '2025_08_28__11_56_05',  # own timestamp, distinct from the other r106 runs
    'lora_r106_v5_1_1000': '2025_08_29__15_01_22',  # own timestamp, distinct from the other r106 runs

    'lora_r204_v1_1_1000': '2025_07_26__11_44_31',
}

# Per-strategy overrides, looked up by build_strategy_config with SPECIAL_CONFIGS.get(name, {}).
# Strategies without an entry here use the standard directory layout.
SPECIAL_CONFIGS = {
    # 3-layer-frozen runs: dated parent directory plus a nested folder named after the run
    **{
        f'3frozen_finetune_v{version}': {
            'base_dir': f'{BASE_DIR}11_03_finetune_3layers_v{version}/finetune_3layers_v{version}/',
            'for_loading': f'finetune_3layers_v{version}',
        }
        for version in (6, 7, 8)
    },
    # Explicit directory for the first top-3 run
    'finetune_top3_v1_1_1000': {
        'base_dir': f'{BASE_DIR}finetune_top3_v1_1_1000/',
        'for_loading': 'finetune_top3_v1_1_1000',
    },
    # Filename-suffix entries
    'full_finetune': {'suffix': '_pretrained'},
    'lora': {'suffix': '_pretrained_lora'},
}

# Strategies whose checkpoint filename ends in '_pretrained'.
PRETRAINED_MODELS = [
    # 10-90 GeV: the unversioned run followed by v1..v3
    'full_finetune',
    *(f'full_finetune_v{version}' for version in (1, 2, 3)),
    # 1-1000 GeV: five versions each of full / BitFit / top-3 fine-tuning
    *(
        f'finetune_{kind}_v{version}_1_1000'
        for kind in ('full', 'bitfit', 'top3')
        for version in range(1, 6)
    ),
]

# Strategies whose checkpoint filename ends in '_pretrained_lora'.
LORA_MODELS = [
    # 10-90 GeV LoRA runs
    *(f'lora_full_v{version}' for version in (1, 2, 3)),
    # 1-1000 GeV LoRA runs: (rank tag, number of versions), in listing order
    *(
        f'lora_{tag}_v{version}_1_1000'
        for tag, n_versions in (
            ('r1', 1), ('r2', 1), ('r4', 1),
            ('r8', 3),
            ('r16', 1), ('r32', 1), ('r32a48', 1), ('r48', 1),
            ('r64', 1),
            ('r106', 5),
            ('r204', 1),
        )
        for version in range(1, n_versions + 1)
    ),
]

# Strategies that are actually loaded and plotted in this session (uncomment to enable).
# Every name listed here must be a key of CHECKPOINT_PATTERNS, otherwise
# build_strategy_config raises ValueError.
ACTIVE_STRATEGIES = [
    # === 1-1000GeV Active Models ===
    'vanilla_full_v1_1_1000',
    'vanilla_full_v2_1_1000',
    'vanilla_full_v3_1_1000',
    'vanilla_full_v4_1_1000',
    'vanilla_full_v5_1_1000',

    'finetune_full_v1_1_1000',
    'finetune_full_v2_1_1000',
    'finetune_full_v3_1_1000',
    'finetune_full_v4_1_1000',
    'finetune_full_v5_1_1000',
    
    'finetune_bitfit_v1_1_1000',
    'finetune_bitfit_v2_1_1000',
    'finetune_bitfit_v3_1_1000',
    'finetune_bitfit_v4_1_1000',
    'finetune_bitfit_v5_1_1000',

    'finetune_top3_v1_1_1000',
    'finetune_top3_v2_1_1000',
    'finetune_top3_v3_1_1000',
    'finetune_top3_v4_1_1000',
    'finetune_top3_v5_1_1000',


    
    # 'lora_r1_v1_1_1000',
    # 'lora_r2_v1_1_1000',
    # 'lora_r4_v1_1_1000',

    # 'lora_r8_v1_1_1000',
    # 'lora_r8_v2_1_1000',
    # 'lora_r8_v3_1_1000',

    # 'lora_r16_v1_1_1000',
    # 'lora_r32_v1_1_1000',
    # # 'lora_r32a48_v1_1_1000',
    # 'lora_r48_v1_1_1000',
    # 'lora_r64_v1_1_1000',

    # The five r=106 runs; each has its own timestamp in CHECKPOINT_PATTERNS.
    'lora_r106_v1_1_1000', 
    'lora_r106_v2_1_1000',
    'lora_r106_v3_1_1000',
    'lora_r106_v4_1_1000',
    'lora_r106_v5_1_1000',

    # 'lora_r204_v1_1_1000',
]


def build_strategy_config(strategy_name: str) -> Dict[str, Any]:
    """Build the loading configuration for one strategy.

    Args:
        strategy_name: Key of CHECKPOINT_PATTERNS identifying the run.

    Returns:
        A dict with the keys
        - 'for_loading': name used when loading (SPECIAL_CONFIGS override or the strategy name),
        - 'base_dir': results directory (SPECIAL_CONFIGS override or BASE_DIR + name + '/'),
        - 'wd_per_epoch' / 'kl_per_epoch': empty dicts to be filled per training step,
        - 'ckpt_filename': filename template containing a literal '{step}' placeholder
          to be filled with str.format(step=...).

    Raises:
        ValueError: If strategy_name has no entry in CHECKPOINT_PATTERNS.
    """
    if strategy_name not in CHECKPOINT_PATTERNS:
        raise ValueError(f"Strategy {strategy_name} not found in CHECKPOINT_PATTERNS")

    timestamp = CHECKPOINT_PATTERNS[strategy_name]

    # Optional per-strategy overrides (directory, loading name, filename suffix)
    special = SPECIAL_CONFIGS.get(strategy_name, {})

    # Filename suffix: an explicit 'suffix' override in SPECIAL_CONFIGS wins; otherwise it
    # is derived from membership in the pretrained / LoRA model lists.
    if 'suffix' in special:
        suffix = special['suffix']
    elif strategy_name in PRETRAINED_MODELS:
        suffix = '_pretrained'
    elif strategy_name in LORA_MODELS:
        suffix = '_pretrained_lora'
    else:
        suffix = ''

    return {
        'for_loading': special.get('for_loading', strategy_name),
        'base_dir': special.get('base_dir', BASE_DIR + strategy_name + '/'),
        'wd_per_epoch': {},
        'kl_per_epoch': {},
        # '{{step}}' keeps a literal '{step}' in the template for later .format(step=...)
        'ckpt_filename': f'CaloChallange_CD{timestamp}_ckpt_0.000000_{{step}}.pt{suffix}',
    }


In [None]:
# Build the configuration of every active strategy; the metric dicts start out empty.
strategy_configs = {
    strategy: build_strategy_config(strategy) 
    for strategy in ACTIVE_STRATEGIES
}

# Print configuration summary
print(f"Loading data for {len(strategy_configs)} strategies...")

# Running counters for the summary printed at the end of the cell
loading_stats = {
    'total_attempts': 0,
    'wd_loaded': 0,
    'kl_loaded': 0,
    'failed_strategies': set(),
    'partially_failed': {}
}

# For each strategy and training step, load both metrics from the checkpoint's output path
for strategy, config in strategy_configs.items():
    strategy_failures = []  # (step, [missing metric tags]) for steps with at least one miss
    
    for step in grad_step:
        loading_stats['total_attempts'] += 1
        
        # Fill the '{step}' placeholder of the filename template
        ckpt_filename = config['ckpt_filename'].format(step=step)
        ckpt_path = os.path.join(config['base_dir'], ckpt_filename)
        
        # A None result is treated as "metric not available for this step"
        wasserstein_dist = finals.load_metric(output_dir=ckpt_path, strategy=strategy, metric_name='wasserstein_dist')
        quantile_kl = finals.load_metric(output_dir=ckpt_path, strategy=strategy, metric_name='kl_divergences')
        
        failures = []
        if wasserstein_dist is None:
            failures.append('WD')
        else:
            config['wd_per_epoch'][step] = wasserstein_dist
            loading_stats['wd_loaded'] += 1
            
        if quantile_kl is None:
            failures.append('KL')
        else:
            config['kl_per_epoch'][step] = quantile_kl
            loading_stats['kl_loaded'] += 1
        
        if failures:
            strategy_failures.append((step, failures))
    
    # Nothing to report for strategies that loaded everything
    if not strategy_failures:
        continue

    # A strategy counts as completely failed only when NO metric was loaded at any step.
    # Counting steps with at least one miss is not enough: a run whose WD loaded at every
    # step but whose KL is missing everywhere still has usable data and is a partial failure.
    if not config['wd_per_epoch'] and not config['kl_per_epoch']:
        loading_stats['failed_strategies'].add(strategy)
        print(f"\n❌ {strategy}: Failed to load any data")
    else:
        loading_stats['partially_failed'][strategy] = strategy_failures
        print(f"\n⚠️  {strategy}: Failed {len(strategy_failures)}/{len(grad_step)} steps")
        for failed_step, missing in strategy_failures[:5]:  # Show first 5 failures
            print(f"   Step {failed_step}: {', '.join(missing)} missing")
        if len(strategy_failures) > 5:
            print(f"   ... and {len(strategy_failures) - 5} more failures")

# Final summary
print("\n" + "="*60)
print("LOADING SUMMARY:")
print(f"Total attempts: {loading_stats['total_attempts']}")
print(f"Successful loads: WD={loading_stats['wd_loaded']}, KL={loading_stats['kl_loaded']}")
print(f"Completely failed strategies: {len(loading_stats['failed_strategies'])}")
print(f"Partially failed strategies: {len(loading_stats['partially_failed'])}")

if loading_stats['failed_strategies']:
    # sorted(): set iteration order is arbitrary; keep the report stable between runs
    print(f"\nStrategies with no data: {', '.join(sorted(loading_stats['failed_strategies']))}")

# Strategies that have at least one Wasserstein entry and are not completely failed
successful_strategies = []
for strategy, config in strategy_configs.items():
    if strategy not in loading_stats['failed_strategies'] and config['wd_per_epoch']:
        successful_strategies.append(strategy)

if successful_strategies:
    print(f"\n✓ Successfully loaded data for {len(successful_strategies)} strategies")
print("="*60)

In [None]:
# # Loop over each strategy configuration to plot its Wasserstein distances
# for strategy, config in strategy_configs.items():
#     # Create a subtitle by formatting the strategy name (customize if needed)
#     subtitle = strategy.replace("_", " ").title()
    
#     fig = finals.plot_wasserstein_distances_features(
#         config['wd_per_epoch'],
#         title='WD_Features_EWC_every1k_epoch',
#         save_plots=False,
#         show_legend=True,
#         subtitle=subtitle,
#         figsize=(20, 20),
#     )

#     fig = finals.plot_wasserstein_distances_features(
#         config['kl_per_epoch'],
#         title='KL_Features_EWC_every1k_epoch',
#         save_plots=False,
#         show_legend=True,
#         subtitle=subtitle,
#         figsize=(20, 20),
#     )

In [None]:
# Gather, in strategy order, the per-step metric dictionaries of every strategy.
wd_dicts = [cfg_entry['wd_per_epoch'] for cfg_entry in strategy_configs.values()]
kl_dicts = [cfg_entry['kl_per_epoch'] for cfg_entry in strategy_configs.values()]

# One display label per strategy: an explicit 'label' entry if the config has one,
# otherwise the strategy name with underscores replaced by spaces, in title case.
labels = [
    cfg_entry.get('label', name.replace('_', ' ').title())
    for name, cfg_entry in strategy_configs.items()
]
labels_kl = list(labels)  # same labels, kept as a separate list object

# Run the same analysis on both metrics; each call receives one dict per strategy.
results_wd = finals.analyze_and_plot_metrics(
    *wd_dicts,
    labels=labels,
    avg_method='geometric',  # geometric averaging for the Wasserstein distances
)
results_kl = finals.analyze_and_plot_metrics(
    *kl_dicts,
    labels=labels_kl,
    avg_method='geometric',  # and for the KL divergences as well
)


In [None]:
# Names of the six compared features, handed to the reorganisation and table helpers.
metric_names = [
    'Voxel Energy Spectrum', 'Energy Ratio', 'Visible Energy',
    'Occupancy', 'Longitudinal Profile', 'Radial Profile',
]

# Strategy names in the same order in which the analysis results were produced
config_names = list(strategy_configs)

# Wasserstein distances
reorganized_data_wd = finals.reorganize_metrics(
    config_names, results_wd, metric_names, 'Wasserstein Distances'
)

# KL divergences
reorganized_data_kl = finals.reorganize_metrics(
    config_names, results_kl, metric_names, 'KL Divergences'
)

# finals plots

In [None]:
# Maps a display group name to the list of strategy names that belong to it.
# Passed as `group_prefixes` to the finals plotting and table helpers; the group names are
# the keys that `selected_group` lists must refer to.
# Commented-out entries are alternative groupings kept for quick toggling.
custom_groups = {
    # "From Scratch": ["vanilla_v4_w5k_10-90", "vanilla_v5_w5k_10-90", "vanilla_v6_w5k_10-90"],
    # "Full Finetuned": ["finetune_full_v1_w5k_10-90", "finetune_full_v2_w5k_10-90", "finetune_full_v3_w5k_10-90"],
   
    # "From Scratch": ["vanilla"],
    # "Full Finetuned": ["full_finetune"],
    # "Top Layers Finetuned": ["3frozen_finetune"],
    # "LoRA (16)": ["lora_full_v1", "lora_full_v2"],
    # "LoRA (4)": ["lora_full_v3"],

    # == 1_1000GeV ===
    "From scratch": [
        "vanilla_full_v1_1_1000",
        "vanilla_full_v2_1_1000", 
        "vanilla_full_v3_1_1000",
        "vanilla_full_v4_1_1000",
        "vanilla_full_v5_1_1000",
    ],
    "Full fine-tuned": [
        "finetune_full_v1_1_1000",
        "finetune_full_v2_1_1000", 
        "finetune_full_v3_1_1000",
        "finetune_full_v4_1_1000",
        "finetune_full_v5_1_1000",
    ],
    "BitFit": [
        "finetune_bitfit_v1_1_1000",
        "finetune_bitfit_v2_1_1000",
        "finetune_bitfit_v3_1_1000",
        "finetune_bitfit_v4_1_1000",
        "finetune_bitfit_v5_1_1000",
    ],

    # NOTE(review): the group is labelled "Top2" but lists the finetune_top3_* runs —
    # confirm which label is intended.
    "Top2": ["finetune_top3_v1_1_1000",
             "finetune_top3_v2_1_1000", 
             "finetune_top3_v3_1_1000",
             "finetune_top3_v4_1_1000",
             "finetune_top3_v5_1_1000",
             ],

    # "LoRA (r=1)": ["lora_r1_v1_1_1000"],
    # "LoRA (r=2)": ["lora_r2_v1_1_1000"],
    # "LoRA (r=4)": ["lora_r4_v1_1_1000"],
    # "LoRA (r=8)": ["lora_r8_v1_1_1000",
    #                "lora_r8_v2_1_1000", 
    #                "lora_r8_v3_1_1000"
    #                ],
    # "LoRA (r=16)": ["lora_r16_v1_1_1000"],
    # "LoRA (r=32)": ["lora_r32_v1_1_1000"],
    # # "LoRA (r=32)": ["lora_r32a48_v1_1_1000"],
    # "LoRA (r=48)": ["lora_r48_v1_1_1000"],
    # "LoRA (r=64)": ["lora_r64_v1_1_1000"],

    "LoRA R106": ["lora_r106_v1_1_1000",
                    "lora_r106_v2_1_1000", 
                    "lora_r106_v3_1_1000",
                    "lora_r106_v4_1_1000",
                    "lora_r106_v5_1_1000",
                    ],
    # "LoRA (r=204)": ["lora_r204_v1_1_1000"],
}

In [None]:
from matplotlib.colors import to_hex

# Upper x-axis limit handed to both plots (1.05e5).
x_max = 1.05e5
# Figure title; currently only referenced by the disabled main_title argument.
title = 'Incident Energy: 1 - 1000 GeV'

# Eight hex colours sampled from the plasma colormap, available for LoRA variants.
lora_colors = [to_hex(rgba) for rgba in plt.cm.plasma(np.linspace(0.1, 0.9, 8))]

# Groups to draw, in plotting order, each with a hand-picked hex colour.
# Every name must be a key of custom_groups.
selected_group = [
    {"name": "From scratch", "color": "#0D3B66"},
    {"name": "Full fine-tuned", "color": "#C03221"},
    {"name": "BitFit", "color": "#00A676"},
    {"name": "Top2", "color": "#995FA3"},

    {"name": "LoRA R106", "color": "#393A10"},  # New model with r=106, alpha=106
]

# The plotting helper takes group names and colours as two parallel lists.
names = [entry["name"] for entry in selected_group]
colors = [entry["color"] for entry in selected_group]

# Arguments common to the Wasserstein and the KL figure
_shared_plot_kwargs = dict(
    # main_title=title,
    save_plot=False,
    x_max=x_max,
    group_prefixes=custom_groups,
    selected_group=names,   # list of group keys
    colors=colors,          # one colour per selected group
    use_weighted_average=True,
    weight_method='Geometric',
)

# Wasserstein distances
finals.plot_reorganized_data(
    reorganized_data_wd,
    strategy_configs,
    ylabel='Normalized WD',
    filename='for_paper/finals_1-1000GeV_normalised_wd.pdf',
    **_shared_plot_kwargs,
)

# KL divergences
finals.plot_reorganized_data(
    reorganized_data_kl,
    strategy_configs,
    ylabel='KLD',
    filename='for_paper/finals_1-1000GeV_quantile_kl.pdf',
    **_shared_plot_kwargs,
)

### PEFT analysis

In [None]:
# Earlier variant of the table rows, with parameter counts and percentages (kept for reference):
# selected_group = [
#     {"name": "From Scratch", "params": "524.0K (100.0%)"},
#     {"name": "Full Finetuned", "params": "524.0K (100.0%)"},  
#     {"name": "BitFit", "params": "87.38K (17.30%)"},             
#     {"name": "Top3", "params": "220.50K (43.67%)"},             

#     # {"name": "LoRA (r=1)", "params": "2.67K (0.51%)"},          
#     # {"name": "LoRA (r=2)", "params": "5.14K (1.01%)"},    
#     # {"name": "LoRA (r=4)", "params": "10.27K (1.99%)"},                
#     {"name": "LoRA (r=8)", "params": "20.54K (3.91%)"},          
#     # {"name": "LoRA (r=16)", "params": "41.1K (7.53%)"},        
#     # {"name": "LoRA (r=32)", "params": "82.18K (14.00%)"},        
#     # {"name": "LoRA (r=48)", "params": "123.26K (19.62%)"},        
#     # {"name": "LoRA (r=64)", "params": "164.35K (24.56%)"},        
# ]

# Rows of the PEFT table: group name plus the parameter-count text shown next to it.
# This rebinds `selected_group` (previously the name/colour list of the plotting cell).
# NOTE(review): the names below ("From Scratch", "Full Finetuned", "LoRA (r=8)",
# "LoRA (r=106)") do not match the active custom_groups keys ("From scratch",
# "Full fine-tuned", "LoRA R106"), and no r=8 group is active — confirm how the finals
# helpers resolve `selected_group` against `group_prefixes`.
# NOTE(review): "272,21K " uses a decimal comma unlike the other entries — confirm the
# value and how the helper parses this text.
selected_group = [
    {"name": "From Scratch", "params": "504.92K "},
    {"name": "Full Finetuned", "params": "504.92K "},  
    # {"name": "BitFit", "params": "87.38K "},             
    # {"name": "Top3", "params": "220.50K "},             

    # {"name": "LoRA (r=1)", "params": "2.67K"},          
    # {"name": "LoRA (r=2)", "params": "5.14K"},    
    # {"name": "LoRA (r=4)", "params": "10.27K"},                
    {"name": "LoRA (r=8)", "params": "20.54K"},          
    # {"name": "LoRA (r=16)", "params": "41.1K"},        
    # {"name": "LoRA (r=32)", "params": "82.18K "},        
    # {"name": "LoRA (r=48)", "params": "123.26K "},        
    # {"name": "LoRA (r=64)", "params": "164.35K "},     
    {"name": "LoRA (r=106)", "params": "272,21K "},     
    # {"name": "LoRA (r=204)", "params": "523.87K "},     
]

# Group names in row order, as expected by the `selected_group` argument of the helpers
names = [g["name"] for g in selected_group]

# Create a mapping from name to params for the table
name_to_params = {g["name"]: g["params"] for g in selected_group}

In [None]:
# Values passed as `cols` to the overall-table helper (also reused by the KL table cell).
cols = [100, 1_000, 10_000, 100_000]
# cols = [500, 5_000, 50_000]

# The LaTeX table and the figure below are both written under results/for_paper/.
# Create the directory here so this cell does not depend on a later cell having run.
os.makedirs("results/for_paper", exist_ok=True)

# Aggregate the Wasserstein results per selected group
agg_mean, agg_std = finals.generate_overall_table(
    reorganized_data_wd,
    strategy_configs,
    group_prefixes=custom_groups,   # dict: group name -> list of strategy names
    selected_group=names,           # list of group names to include
    cols = cols,
    weight_method='Geometric',
    metric_names=metric_names,
    use_weighted_average=True
)

# Print the table and also save it as LaTeX
finals.print_overall_table(agg_mean, agg_std, name_to_params=name_to_params,
                            save_latex=True,
                            latex_path="results/for_paper/finals_1-1000GeV_overall_table_PEFT.tex")

# Performance versus parameter count, based on the aggregated Wasserstein values
fig, ax = finals.plot_performance_vs_params(
    agg_mean, 
    agg_std, 
    name_to_params,
    # colors=colors,          
    metric_name="1 - Normalized Wasserstein Distance",  # or "Accuracy" if higher is better
    save_path="results/for_paper/performance_vs_params.pdf"
)


In [None]:
# Now pass them in correctly:
# The LaTeX table below is written under results/for_paper/; make sure the directory
# exists so this cell also works when run on its own.
os.makedirs("results/for_paper", exist_ok=True)

# Aggregate the KL-divergence results per selected group.
# Note: this rebinds agg_mean / agg_std, which held the Wasserstein aggregates before.
agg_mean, agg_std = finals.generate_overall_table(
    reorganized_data_kl,
    strategy_configs,
    group_prefixes=custom_groups,   # dict: group name -> list of strategy names
    selected_group=names,           # list of group names to include
    cols = cols,
    weight_method='Simple',
    metric_names=metric_names,
    use_weighted_average=True
)

# fig, ax = finals.plot_performance_vs_params(
#     agg_mean, 
#     agg_std, 
#     name_to_params,
#     # colors=colors,          

#     metric_name="Kullback–Leibler divergence",  # or "Accuracy" if higher is better
#     # save_path="performance_vs_params.pdf"
# )

# Print the table and also save it as LaTeX
finals.print_overall_table(agg_mean, agg_std, name_to_params=name_to_params,
                            save_latex=True,
                            latex_path="results/for_paper/finals_1-1000GeV_overall_table_PEFT_kl.tex")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
# Output directory of this figure
os.makedirs("./results/for_paper/lora", exist_ok=True)

# Two panels side by side: Wasserstein distance (left) and KL divergence (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Hand-entered summary values per method: x position ('params', plotted on an axis
# labelled in thousands of trainable parameters), the two metric values, and marker style.
methods_data = {
    'From Scratch': {'params': 504.92, 'wasserstein': 0.110, 'kl': 0.218, 'color': 'gray', 'marker': 'v'},
    'Full FT': {'params': 504.92, 'wasserstein': 0.094, 'kl': 0.138, 'color': 'purple', 'marker': 's'},
    'BitFit': {'params': 87.38, 'wasserstein': 0.103, 'kl': 0.146, 'color': 'green', 'marker': 's'},
    'Top3': {'params': 220.50, 'wasserstein': 0.099, 'kl': 0.147, 'color': 'blue', 'marker': 's'},
    'LoRA r=8': {'params': 20.54, 'wasserstein': 0.135, 'kl': 0.199, 'color': 'red', 'marker': 'o'},
    'LoRA r=48': {'params': 123.26, 'wasserstein': 0.137, 'kl': 0.190, 'color': 'red', 'marker': 'o'},
    'LoRA r=64': {'params': 164.35, 'wasserstein': 0.158, 'kl': 0.257, 'color': 'red', 'marker': 'o'},
    'LoRA r=106': {'params': 300.35, 'wasserstein': 0.119, 'kl': 0.215, 'color': 'red', 'marker': 'o'},
    'LoRA r=204': {'params': 523.87, 'wasserstein': 0.123, 'kl': np.nan, 'color': 'darkred', 'marker': 'o'},
}

# (axes, metric key, draw the boxed 'LoRA r=204' callout?) for each panel
panel_specs = (
    (ax1, 'wasserstein', True),
    (ax2, 'kl', False),
)

for panel_ax, metric_key, callout_r204 in panel_specs:
    for method, data in methods_data.items():
        value = data.get(metric_key, np.nan)
        if np.isnan(value):
            continue  # no value for this method on this panel

        # LoRA points are drawn slightly smaller and lighter than the baselines
        is_lora = 'LoRA' in method
        panel_ax.scatter(
            data['params'], value,
            s=200 if is_lora else 250,
            c=data['color'],
            alpha=0.7 if is_lora else 0.8,
            marker=data['marker'],
            edgecolors='black',
            linewidth=1.5 if is_lora else 2,
        )

        # Label placement is tuned per method to avoid overlaps
        point = (data['params'], value)
        if callout_r204 and method == 'LoRA r=204':
            panel_ax.annotate(method, point,
                              xytext=(-80, 15), textcoords='offset points', fontsize=11, fontweight='bold',
                              bbox=dict(boxstyle="round,pad=0.3", facecolor='yellow', alpha=0.7),
                              arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.3'))
        elif method == 'Full FT':
            panel_ax.annotate(method, point,
                              xytext=(-80, -25), textcoords='offset points', fontsize=11, fontweight='bold')
        elif method == 'From Scratch':
            panel_ax.annotate(method, point,
                              xytext=(-100, 10), textcoords='offset points', fontsize=10)
        else:
            # Remaining labels drop the 'LoRA ' prefix ('r=8', 'r=48', ...)
            panel_ax.annotate(method.replace('LoRA ', ''), point,
                              xytext=(10, 5), textcoords='offset points', fontsize=10)

# Left panel only: vertical guides at the Top3 and LoRA r=204 x positions
ax1.axvline(x=220.5, color='blue', linestyle='--', alpha=0.3, linewidth=1)
ax1.axvline(x=523.87, color='red', linestyle='--', alpha=0.3, linewidth=1)

# Left panel only: double arrow between the Top3 and LoRA r=204 points, with a text box
ax1.annotate('', xy=(523.87, 0.123), xytext=(220.5, 0.099),
             arrowprops=dict(arrowstyle='<->', color='black', lw=2, alpha=0.5))
ax1.text(370, 0.111, 'More params,\nworse performance!',
         ha='center', fontsize=12, fontweight='bold',
         bbox=dict(boxstyle="round,pad=0.5", facecolor='white', edgecolor='red', linewidth=2))

# Axis cosmetics, left panel
ax1.set_xlabel('Trainable Parameters (K)', fontsize=14, fontweight='bold')
ax1.set_ylabel('Wasserstein Distance (↓ better)', fontsize=14, fontweight='bold')
ax1.set_title('Parameter Efficiency: Wasserstein Distance', fontsize=16, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(-50, 600)
ax1.set_ylim(0.08, 0.17)

# Axis cosmetics, right panel
ax2.set_xlabel('Trainable Parameters (K)', fontsize=14, fontweight='bold')
ax2.set_ylabel('KL Divergence (↓ better)', fontsize=14, fontweight='bold')
ax2.set_title('Parameter Efficiency: KL Divergence', fontsize=16, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.set_xlim(-50, 600)
ax2.set_ylim(0.12, 0.27)

# Main title
# plt.suptitle('The LoRA Paradox: More Parameters, Worse Performance', fontsize=18, fontweight='bold', y=1.02)

# Hand-built legend (proxy markers), shown on the left panel only
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], marker=marker, color='w', markerfacecolor=face, markersize=10, label=label)
    for marker, face, label in (
        ('o', 'red', 'LoRA variants'),
        ('s', 'blue', 'Top3'),
        ('s', 'green', 'BitFit'),
        ('s', 'purple', 'Full fine-tuned'),
        ('v', 'gray', 'From scratch'),
    )
]
ax1.legend(handles=legend_elements, loc='upper right', fontsize=10)

plt.tight_layout()
plt.savefig('./results/for_paper/lora/lora_paradox_performance_vs_params.pdf', dpi=300, bbox_inches='tight')
plt.show()



### Analysis of weights

In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import utils.finetune as ft
import configs

from models.epicVAE_nflows_kDiffusion import epicVAE_nFlow_kDiffusion

def load_pretrained_model(cfg, checkpoint_path, use_ema=True):
    """Instantiate the model on cfg.device and load weights from a checkpoint file.

    Args:
        cfg: Configuration object passed to the model constructor; cfg.device selects
            the device for both the model and the loaded tensors.
        checkpoint_path: Path of the checkpoint to read with torch.load.
        use_ema: If True and the checkpoint contains checkpoint['others']['model_ema'],
            those weights are loaded; otherwise checkpoint['state_dict'] is used.

    Returns:
        The model with the selected weights loaded (strict key matching).
    """
    model = epicVAE_nFlow_kDiffusion(cfg).to(cfg.device)
    checkpoint = torch.load(checkpoint_path, map_location=cfg.device, weights_only=False)

    # EMA weights are used only when requested AND actually stored in the checkpoint
    ema_selected = use_ema and 'others' in checkpoint and 'model_ema' in checkpoint['others']
    weights = checkpoint['others']['model_ema'] if ema_selected else checkpoint['state_dict']

    model.load_state_dict(weights, strict=True)
    return model

def load_finetuned_models(cfg, base_path, use_ema=True, num_layers=6):
    """Load the pretrained model and every available fine-tuned variant.

    Args:
        cfg: Configuration object passed to the model constructor. cfg.device selects
            the device and cfg.diffusion_pretrained_model_path the pretrained checkpoint.
        base_path: Directory that contains the 'finetune/' checkpoint tree.
        use_ema: If True, load checkpoint['others']['model_ema'] when a checkpoint has it;
            otherwise (or when it is absent) load checkpoint['state_dict'].
        num_layers: Number of leading entries of model.diffusion.inner_model.layers that
            are wrapped with LoRA for the LoRA variants (default 6, as before).

    Returns:
        (pretrained_model, models) where models maps a method name ('Full FT',
        'LoRA r=8', ...) to its loaded model. Methods whose checkpoint file does not
        exist are reported and left out.
    """
    # NOTE(review): the LoRA checkpoints are taken at step 50000 while the other methods
    # use step 500000 — confirm that this difference is intended.
    checkpoints = {
        'Full FT': f'{base_path}/finetune/CaloChallange_CD2025_05_09__16_58_31/ckpt_0.000000_500000.pt',
        'BitFit': f'{base_path}/finetune/CaloChallange_CD2025_06_12__13_09_43/ckpt_0.000000_500000.pt',
        'Top3':   f'{base_path}/finetune/CaloChallange_CD2025_07_22__20_32_56/ckpt_0.000000_500000.pt',
        
        'LoRA r=204': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_26__11_44_31/ckpt_0.000000_50000.pt',
        'LoRA r=106': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_25__14_36_14/ckpt_0.000000_50000.pt',

        'LoRA r=64': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_15__09_42_59/ckpt_0.000000_50000.pt',

        'LoRA r=48': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_15__20_55_42/ckpt_0.000000_50000.pt',
        'LoRA r=16': f'{base_path}/finetune/lora/CaloChallange_CD2025_06_26__13_29_31/ckpt_0.000000_50000.pt',

        'LoRA r=8': f'{base_path}/finetune/lora/CaloChallange_CD2025_06_24__16_10_32/ckpt_0.000000_50000.pt',
        # 'LoRA r=8': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_24__13_39_58/ckpt_0.000000_140000.pt',
        'LoRA r=4': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_19__04_03_45/ckpt_0.000000_50000.pt',
        'LoRA r=2': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_18__17_45_10/ckpt_0.000000_50000.pt',
        'LoRA r=1': f'{base_path}/finetune/lora/CaloChallange_CD2025_07_21__19_20_51/ckpt_0.000000_50000.pt',

    }

    def _read_checkpoint(path):
        """Read a checkpoint file onto cfg.device."""
        # weights_only=False unpickles arbitrary objects: only use with trusted checkpoints.
        return torch.load(path, map_location=cfg.device, weights_only=False)

    def _pick_weights(checkpoint, source):
        """Return the EMA weights when requested and stored, else the plain state dict."""
        if use_ema:
            if 'others' in checkpoint and 'model_ema' in checkpoint['others']:
                return checkpoint['others']['model_ema']
            # Same fallback as load_pretrained_model, but reported so it is not silent
            print(f"  EMA weights not found in {source}; using 'state_dict' instead")
        return checkpoint['state_dict']

    def _fresh_model(weights):
        """Build a new model on cfg.device and load the given weights strictly."""
        model = epicVAE_nFlow_kDiffusion(cfg).to(cfg.device)
        model.load_state_dict(weights, strict=True)
        return model

    # Read the pretrained checkpoint once; it initialises the reference model and every
    # LoRA variant (load_state_dict copies values, so the dict can be reused safely).
    pretrained_weights = _pick_weights(
        _read_checkpoint(cfg.diffusion_pretrained_model_path), 'the pretrained checkpoint'
    )
    pretrained_model = _fresh_model(pretrained_weights)

    models = {}
    for method_name, ckpt_path in checkpoints.items():
        if not Path(ckpt_path).exists():
            print(f"Checkpoint not found: {ckpt_path}")
            continue
            
        print(f"\nLoading {method_name}...")
        finetuned_weights = _pick_weights(_read_checkpoint(ckpt_path), ckpt_path)

        if 'LoRA' in method_name:
            # The rank is encoded in the method name, e.g. 'LoRA r=8' -> 8
            rank = int(method_name.split('=')[1])

            # Start from the pretrained weights, then wrap the layers with LoRA adapters
            model = _fresh_model(pretrained_weights)
            layers = model.diffusion.inner_model.layers
            for i in range(num_layers):
                layers[i] = ft.apply_lora(layers[i], rank=rank, alpha=rank)

            # Copy only the LoRA A/B matrices from the fine-tuned checkpoint
            missing = []
            for name, param in model.named_parameters():
                if '_layer.A' in name or '_layer.B' in name:
                    if name in finetuned_weights:
                        param.data.copy_(finetuned_weights[name])
                    else:
                        missing.append(name)
            if missing:
                # These adapters keep the values apply_lora initialised them with
                print(f"  Warning: {len(missing)} LoRA tensors missing from checkpoint, "
                      f"e.g. {missing[0]}")
        else:
            # Regular fine-tuned model: the checkpoint holds the full set of weights
            model = _fresh_model(finetuned_weights)

        models[method_name] = model

    return pretrained_model, models

def _layer_delta_norm(pretrained_layer, finetuned_layer, include_weight=True):
    """Return ||W_ft - W_pre|| + ||b_ft - b_pre|| for one pair of layers.

    If the pretrained layer exposes a ``_layer`` attribute, ``weight``/``bias``
    are read from ``_layer`` on both layers; otherwise they are read from the
    layers themselves. The weight term is skipped when ``include_weight`` is
    False; the bias term is skipped when either bias is ``None``.
    """
    if hasattr(pretrained_layer, '_layer'):
        pre, ft = pretrained_layer._layer, finetuned_layer._layer
    else:
        pre, ft = pretrained_layer, finetuned_layer

    total = 0.0
    if include_weight:
        total += torch.norm(ft.weight - pre.weight).item()
    if pre.bias is not None and ft.bias is not None:
        total += torch.norm(ft.bias - pre.bias).item()
    return total


def _lora_delta_norm(finetuned_layer):
    """Return the norm of the effective LoRA update ``(B @ A) * alpha / rank``.

    Returns 0.0 when the layer has no ``_layer`` attribute or ``_layer`` has no ``A`` factor.
    """
    lora_layer = getattr(finetuned_layer, '_layer', None)
    if lora_layer is None or not hasattr(lora_layer, 'A'):
        return 0.0
    lora_weight = (lora_layer.B @ lora_layer.A) * (lora_layer.alpha / lora_layer.rank)
    return torch.norm(lora_weight).item()


def calculate_modification_magnitude(pretrained_model, finetuned_models, num_layers=6):
    """Calculate layer-wise modification magnitudes for each fine-tuning method.

    Layers are read from ``model.diffusion.inner_model.layers[layer_idx]``.

    Args:
        pretrained_model: Reference model the fine-tuned models are compared against.
        finetuned_models: Mapping of method name -> fine-tuned model. The name
            selects how the magnitude is measured:
              * ``"BitFit"``: norm of the bias difference only.
              * any name containing ``"LoRA"``: norm of ``(B @ A) * alpha / rank``.
              * ``"Top3"``: 0.0 for layer indices 0-2, weight + bias difference otherwise.
              * anything else (including ``"Full FT"``): weight + bias difference.
        num_layers: Number of leading layers to inspect.

    Returns:
        Dict mapping method name -> list of ``num_layers`` floats.
    """
    modifications = {}

    with torch.no_grad():
        for method_name, model_ft in finetuned_models.items():
            layer_mods = []

            for layer_idx in range(num_layers):
                pretrained_layer = pretrained_model.diffusion.inner_model.layers[layer_idx]
                finetuned_layer = model_ft.diffusion.inner_model.layers[layer_idx]

                if method_name == "BitFit":
                    mod = _layer_delta_norm(pretrained_layer, finetuned_layer, include_weight=False)
                elif "LoRA" in method_name:
                    mod = _lora_delta_norm(finetuned_layer)
                elif method_name == "Top3" and layer_idx < 3:
                    # Reported as untouched without comparing parameters.
                    mod = 0.0
                else:
                    # "Top3" (layers >= 3), "Full FT", and any other method name.
                    # Previously an unrecognised name left `mod` unassigned, so it
                    # either raised UnboundLocalError or reused a stale value.
                    mod = _layer_delta_norm(pretrained_layer, finetuned_layer)

                layer_mods.append(mod)

            modifications[method_name] = layer_mods
            print(f"{method_name}: {layer_mods}")

    return modifications

def create_modification_heatmap(modifications, save_path=None, normalize_by='method'):
    """
    Create layer-wise modification magnitude heatmap

    Args:
        modifications: dict with method names as keys and lists of layer modifications as values
            (all lists are expected to have the same length)
        save_path: path to save the figure (optional)
        normalize_by: 'method' (divide each row by its maximum), 'layer' (divide each
            column by its maximum); any other value (e.g. None) plots the raw magnitudes

    Returns:
        The DataFrame that was plotted (methods as rows, layers plus an
        'Average' column as columns).
    """

    # Create DataFrame: one row label per layer, then transpose so methods are rows
    n_layers = len(next(iter(modifications.values())))
    df = pd.DataFrame(modifications, index=[f"Layer {i}" for i in range(n_layers)])
    df = df.T
    # Add each method's average over its layers as an extra column
    df['Average'] = df.mean(axis=1)

    # Normalize based on choice; 0/0 (all-zero row or column) becomes 0 via fillna
    is_normalized = normalize_by in ('method', 'layer')
    if normalize_by == 'method':
        df_norm = df.div(df.max(axis=1), axis=0).fillna(0)
    elif normalize_by == 'layer':
        df_norm = df.div(df.max(axis=0), axis=1).fillna(0)
    else:
        # No normalization
        df_norm = df

    # Create figure
    plt.figure(figsize=(10, 6))

    # Create custom colormap (white to red)
    cmap = sns.color_palette("Reds", as_cmap=True)

    # The colour scale tops out at 1 only when the data were actually
    # normalised; raw magnitudes use their own maximum. (Previously any truthy
    # but unrecognised `normalize_by` plotted raw data against vmax=1.)
    vmax = 1.0 if is_normalized else df.values.max()
    cbar_label = 'Relative Modification Magnitude' if is_normalized else 'Modification Magnitude'

    # Create heatmap
    ax = sns.heatmap(df_norm,
                     cmap=cmap,
                     cbar_kws={'label': cbar_label},
                     linewidths=0.5,
                     annot=True,
                     fmt='.2f',
                     vmin=0,
                     vmax=vmax)
    # Dashed separator between the per-layer columns and the 'Average' column
    ax.vlines(len(df_norm.columns) - 1, *ax.get_ylim(), colors='black', linewidth=2, linestyles='dashed')
    # Styling
    plt.xlabel('Layer', fontsize=22, fontweight='bold')
    plt.ylabel('Method', fontsize=22, fontweight='bold')
    plt.title('Layer-wise Modification Magnitude', fontsize=24, fontweight='bold')

    # Keep y-axis (method) labels horizontal for readability
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Heatmap saved to {save_path}")

    plt.show()

    return df_norm

def analyze_modifications(modifications):
    """Print summary statistics of the layer-wise modification magnitudes.

    Args:
        modifications: dict of method name -> list of per-layer magnitudes.

    Prints, in order: the summed magnitude per method (largest first), the
    mean magnitude per layer across methods, and std/mean per method, tagged
    "(concentrated)" when the ratio exceeds 1 and "(distributed)" otherwise.
    Returns nothing.
    """
    table = pd.DataFrame(modifications)  # rows = layers, columns = methods

    print("\n=== Modification Analysis ===")

    print("\nTotal modification per method:")
    per_method_total = table.sum(axis=0).sort_values(ascending=False)
    for name in per_method_total.index:
        print(f"  {name}: {per_method_total[name]:.4f}")

    print("\nAverage modification per layer:")
    for position, layer_mean in enumerate(table.mean(axis=1).tolist()):
        print(f"  Layer {position}: {layer_mean:.4f}")

    print("\nModification concentration (std/mean):")
    for method in table.columns:
        column = table[method]
        # Small epsilon keeps the ratio finite for all-zero columns
        ratio = column.std() / (column.mean() + 1e-8)
        if ratio > 1:
            tag = '(concentrated)'
        else:
            tag = '(distributed)'
        print(f"  {method}: {ratio:.2f} {tag}")

# Main execution
if __name__ == "__main__":
    # Configuration
    cfg = configs.Configs()
    # Use the first GPU when one is visible, otherwise fall back to CPU so the
    # analysis still runs on CPU-only hosts.
    cfg.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Paths
    BASE_PATH = '/data/dust/user/valentel/beegfs.migration/dust/logs/MyCaloTransfer_diffusionweights'

    # Load all models
    print("Loading models...")
    pretrained_model, finetuned_models = load_finetuned_models(cfg, BASE_PATH, use_ema=True)

    # Calculate modifications
    print("\nCalculating modification magnitudes...")
    modifications = calculate_modification_magnitude(pretrained_model, finetuned_models)

    # Analyze modifications
    analyze_modifications(modifications)

    # Create heatmap
    print("\nCreating modification heatmap...")
    df_norm = create_modification_heatmap(
        modifications,
        # save_path='layer_modification_heatmap.pdf',
        normalize_by=None  # 'method', 'layer', or None for raw magnitudes
    )

    # Optional: Save raw modification values (rows = layer index, columns = method)
    df_raw = pd.DataFrame(modifications)
    df_raw.to_csv('modification_magnitudes.csv')
    print("\nRaw modification values saved to modification_magnitudes.csv")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Tuple, Optional
import os

def _default_lora_test_ranks(max_rank: int) -> List[int]:
    """Return the default ranks to test: powers of two, ~1.5x intermediates, and ``max_rank``."""
    ranks = []
    r = 1
    while r <= max_rank:
        ranks.append(r)
        if r < max_rank:
            # Add intermediate point for smoother visualization
            intermediate = min(int(r * 1.5), max_rank)
            if intermediate not in ranks and intermediate < r * 2:
                ranks.append(intermediate)
        r *= 2
    if max_rank not in ranks:
        ranks.append(max_rank)
    return sorted(ranks)


def inverse_lora_experiment(pretrained_model, finetuned_models, num_layers: int = 6,
                           save_path: str = './results/for_paper/lora/',
                           error_threshold: Optional[float] = None,
                           custom_ranks: Optional[List[int]] = None):
    """
    Calculate the best LoRA approximation of known weight deltas using SVD.
    Tests representation capacity vs optimization difficulty.

    For every layer the update ``delta_W = W_ft - W_pre`` between the 'Full FT'
    model and the pretrained model is decomposed with a reduced SVD; truncating
    it to rank r gives the best rank-r approximation in Frobenius norm, i.e.
    the best update a rank-r LoRA could represent.

    Args:
        pretrained_model: The pretrained base model
        finetuned_models: Dictionary of finetuned model variants; must contain 'Full FT'
        num_layers: Number of layers to analyze
        save_path: Directory to save figures
        error_threshold: Optional error threshold to display (e.g., 0.05 for 5%)
        custom_ranks: Optional list of ranks to test (defaults to powers of 2
            plus intermediate values); entries outside [1, max_rank] are ignored

    Returns:
        Dictionary keyed by 'Layer i'; each value holds 'reconstruction_errors'
        (rank -> relative Frobenius error), 'optimal_BA' (rank -> (B, A) tensors
        with B @ A equal to the truncated SVD), 'weight_shape' and
        'singular_values' (NumPy array). Empty dict if 'Full FT' is missing.
    """
    print("\n=== Inverse LoRA Experiment ===")

    if 'Full FT' not in finetuned_models:
        print("Full FT model not found, skipping inverse LoRA experiment")
        return {}

    full_ft_model = finetuned_models['Full FT']
    inverse_results = {}

    with torch.no_grad():
        for layer_idx in range(num_layers):
            pretrained_layer = pretrained_model.diffusion.inner_model.layers[layer_idx]
            finetuned_layer = full_ft_model.diffusion.inner_model.layers[layer_idx]

            # Extract weights (unwrap `_layer` when the pretrained layer has one)
            if hasattr(pretrained_layer, '_layer'):
                W_pre = pretrained_layer._layer.weight
                W_ft = finetuned_layer._layer.weight
            else:
                W_pre = pretrained_layer.weight
                W_ft = finetuned_layer.weight

            # Compute true weight update
            delta_W = W_ft - W_pre
            delta_norm = torch.norm(delta_W)

            # Reduced SVD: delta_W = U @ diag(S) @ Vh. torch.linalg.svd replaces
            # the deprecated torch.svd and returns V already transposed.
            U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)

            layer_results = {
                'reconstruction_errors': {},
                'optimal_BA': {},
                'weight_shape': delta_W.shape,
                'singular_values': S.cpu().numpy()  # Store for analysis
            }

            # Ranks to evaluate, capped by the smaller matrix dimension
            max_rank = min(delta_W.shape)
            if custom_ranks is not None:
                test_ranks = [r for r in custom_ranks if 1 <= r <= max_rank]
            else:
                test_ranks = _default_lora_test_ranks(max_rank)

            for rank in test_ranks:
                U_r, S_r, Vh_r = U[:, :rank], S[:rank], Vh[:rank, :]

                # Optimal low-rank approximation using truncated SVD
                reconstruction = (U_r * S_r) @ Vh_r

                # Relative Frobenius norm error; an all-zero update is
                # reproduced exactly, so report 0 instead of 0/0 = NaN.
                if delta_norm.item() > 0:
                    error = (torch.norm(reconstruction - delta_W) / delta_norm).item()
                else:
                    error = 0.0
                layer_results['reconstruction_errors'][rank] = error

                # Store the LoRA factors, splitting sqrt(S) evenly between B and A
                sqrt_S = torch.sqrt(S_r)
                B_optimal = U_r * sqrt_S.unsqueeze(0)
                A_optimal = sqrt_S.unsqueeze(1) * Vh_r
                layer_results['optimal_BA'][rank] = (B_optimal, A_optimal)

            inverse_results[f'Layer {layer_idx}'] = layer_results

            print(f"\nLayer {layer_idx} (shape: {delta_W.shape}) - Optimal LoRA reconstruction errors:")
            # Print key ranks only to avoid clutter
            key_ranks = [r for r in [1, 4, 16, 64, 256, max_rank] if r in layer_results['reconstruction_errors']]
            for rank in key_ranks:
                error = layer_results['reconstruction_errors'][rank]
                # Very small errors read better in scientific notation
                if error < 0.0001:
                    print(f"  Rank {rank}: {error:.2e}")
                else:
                    print(f"  Rank {rank}: {error:.2%}")

    # Create publication-quality visualization
    create_reconstruction_plot(inverse_results, num_layers, save_path, error_threshold)

    # Singular value decay plot
    create_singular_value_plot(inverse_results, num_layers, save_path)

    return inverse_results

def create_reconstruction_plot(inverse_results: Dict, num_layers: int, 
                              save_path: str, error_threshold: Optional[float] = None):
    """Plot relative reconstruction error against LoRA rank, one curve per layer.

    Args:
        inverse_results: Mapping 'Layer i' -> dict with a 'reconstruction_errors'
            entry (rank -> relative error); layers missing from it are skipped.
        num_layers: Number of layer indices (0 .. num_layers - 1) to look up.
        save_path: Directory (created if missing) that receives
            'inverse_lora_reconstruction_error.pdf'.
        error_threshold: If given, draw a dashed horizontal line at this error
            with a percentage label.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # One viridis colour per layer, avoiding the extreme ends of the map
    colors = plt.cm.viridis(np.linspace(0.15, 0.85, num_layers))

    # Plot each layer's reconstruction error (log y-axis)
    for layer_idx in range(num_layers):
        layer_key = f'Layer {layer_idx}'
        if layer_key not in inverse_results:
            continue
        errors = inverse_results[layer_key]['reconstruction_errors']
        ranks = sorted(errors)
        ax.semilogy(ranks, [errors[r] for r in ranks],
                    marker='o',
                    label=layer_key,
                    linewidth=2.5,
                    markersize=6,
                    color=colors[layer_idx],
                    markeredgewidth=0.5,
                    markeredgecolor='white',
                    alpha=0.9)

    # Use Unicode for y-axis label (more reliable than LaTeX)
    ax.set_xlabel('LoRA Rank', fontsize=24, fontweight='normal')
    ax.set_ylabel('Reconstruction Error [εᵣ]', fontsize=24, fontweight='normal')

    # No grid on this figure
    ax.grid(False)

    # Add subtle threshold line if specified
    if error_threshold is not None:
        ax.axhline(y=error_threshold,
                   color='#C03221',
                   linestyle='--',
                   linewidth=1.5,
                   alpha=0.7,
                   zorder=0)

        # The label is anchored at rank 8 in data coordinates, slightly above the line
        threshold_label_rank = 8
        ax.text(threshold_label_rank, error_threshold * 1.5,
                f'{error_threshold:.0%} error',
                horizontalalignment='center',
                fontsize=16,
                color='#C03221',
                alpha=0.9)

    # Configure legend
    ax.legend(loc='best',
              frameon=False,
              fontsize=20,
              ncol=2 if num_layers > 6 else 1,
              columnspacing=1.0,
              handlelength=1.5)

    # Ranks are spaced roughly geometrically, so use a log2 x-axis
    ax.set_xscale('log', base=2)

    # Format tick labels
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.tick_params(axis='both', which='minor', labelsize=14)

    # Ensure tight layout
    fig.tight_layout()

    # Save figure
    os.makedirs(save_path, exist_ok=True)
    filepath = os.path.join(save_path, 'inverse_lora_reconstruction_error.pdf')
    fig.savefig(filepath, dpi=300, bbox_inches='tight')
    print(f"Saved reconstruction error plot to {filepath}")
    plt.show()

def create_singular_value_plot(inverse_results: Dict, num_layers: int, save_path: str):
    """
    Create plot showing singular value decay for each layer.

    Each layer's singular values are divided by its largest one so the curves
    are comparable across layers; dotted reference lines mark 10% and 1%.

    Args:
        inverse_results: Mapping 'Layer i' -> dict with a 'singular_values'
            array (as stored in descending order by the SVD); layers missing
            from it are skipped.
        num_layers: Number of layer indices (0 .. num_layers - 1) to look up.
        save_path: Directory (created if missing) that receives
            'singular_value_decay.pdf'.
    """

    plt.figure(figsize=(10, 6))
    colors = plt.cm.viridis(np.linspace(0.15, 0.85, num_layers))

    for layer_idx in range(num_layers):
        if f'Layer {layer_idx}' in inverse_results:
            singular_values = inverse_results[f'Layer {layer_idx}']['singular_values']

            # An empty spectrum or a zero leading value (layer unchanged by
            # fine-tuning) cannot be normalised; skip instead of dividing by 0.
            if len(singular_values) == 0 or singular_values[0] == 0:
                print(f"Layer {layer_idx}: no non-zero singular values, skipping")
                continue

            # Normalize by first singular value for comparison across layers
            normalized_sv = singular_values / singular_values[0]

            plt.semilogy(range(1, len(normalized_sv) + 1), 
                        normalized_sv,
                        linewidth=2.5,
                        label=f'Layer {layer_idx}',
                        color=colors[layer_idx],
                        alpha=0.9)

    plt.xlabel('Singular Value Index', fontsize=24)
    plt.ylabel('Normalized Singular Value', fontsize=24)

    # No grid on this figure
    plt.grid(False)

    # Add reference lines for interpretation
    plt.axhline(y=0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)
    plt.axhline(y=0.01, color='gray', linestyle=':', linewidth=1, alpha=0.7)

    # Label the reference lines just inside the right edge of the axes
    ax = plt.gca()
    xlims = ax.get_xlim()
    plt.text(xlims[1] * 0.98, 0.11, '10%', 
            horizontalalignment='right', fontsize=12, color='gray', alpha=0.7)
    plt.text(xlims[1] * 0.98, 0.011, '1%', 
            horizontalalignment='right', fontsize=12, color='gray', alpha=0.7)

    plt.legend(loc='best', frameon=False, fontsize=14)

    ax.tick_params(axis='both', which='major', labelsize=14)

    plt.tight_layout()

    # Save figure; create the directory so the function also works when it is
    # called without create_reconstruction_plot having run first.
    os.makedirs(save_path, exist_ok=True)
    filepath = os.path.join(save_path, 'singular_value_decay.pdf')
    plt.savefig(filepath, dpi=300, bbox_inches='tight')
    print(f"Saved singular value decay plot to {filepath}")
    plt.show()

def analyze_effective_rank(inverse_results: Dict, energy_threshold: float = 0.95):
    """
    Analyze the effective rank needed to capture a given percentage of the update energy.

    Energy is the sum of squared singular values; the effective rank is the
    smallest k whose leading k singular values reach ``energy_threshold`` of it.

    Args:
        inverse_results: Results from inverse_lora_experiment; entries without a
            'singular_values' key are ignored
        energy_threshold: Fraction of energy to capture (0.95 = 95%)

    Returns:
        Dictionary mapping layer name -> dict with 'effective_rank',
        'total_rank', 'energy_at_effective_rank' and 'compression_ratio'.
        If the threshold is never reached (e.g. round-off at thresholds near 1)
        the full rank is reported. A layer with zero total energy gets
        effective rank 0 and an infinite compression ratio.
    """
    effective_ranks = {}

    for layer_name, layer_data in inverse_results.items():
        if 'singular_values' not in layer_data:
            continue

        # Accumulate in float64: float32 input can make the last cumulative
        # fraction fall just short of 1.0.
        sv = np.asarray(layer_data['singular_values'], dtype=np.float64)
        total_rank = len(sv)

        # Calculate cumulative energy (sum of squared singular values)
        sv_squared = sv ** 2
        total_energy = sv_squared.sum()

        if total_rank == 0 or total_energy == 0:
            effective_ranks[layer_name] = {
                'effective_rank': 0,
                'total_rank': total_rank,
                'energy_at_effective_rank': 0.0,
                'compression_ratio': float('inf')
            }
            print(f"{layer_name}: update has zero energy, effective rank 0/{total_rank}")
            continue

        cumulative_energy = np.cumsum(sv_squared) / total_energy

        # First index reaching the threshold. np.argmax on an all-False mask
        # returns 0, which would wrongly report rank 1, so fall back to the
        # full rank when nothing qualifies.
        reached = np.flatnonzero(cumulative_energy >= energy_threshold)
        effective_rank = int(reached[0]) + 1 if reached.size else total_rank

        effective_ranks[layer_name] = {
            'effective_rank': effective_rank,
            'total_rank': total_rank,
            'energy_at_effective_rank': cumulative_energy[effective_rank - 1],
            'compression_ratio': total_rank / effective_rank
        }

        print(f"{layer_name}: Rank {effective_rank}/{total_rank} captures {energy_threshold:.0%} of energy "
              f"(compression ratio: {total_rank/effective_rank:.1f}x)")

    return effective_ranks


# Example usage
if __name__ == "__main__":
    # Settings for the inverse-LoRA run; `pretrained_model` and
    # `finetuned_models` must already exist in the notebook namespace.
    experiment_options = {
        'pretrained_model': pretrained_model,
        'finetuned_models': finetuned_models,
        'num_layers': 6,
        'save_path': './results/for_paper/lora/',
        'error_threshold': 0.05,  # Show 5% error line if justified
        'custom_ranks': None,  # Use automatic rank selection
    }
    results = inverse_lora_experiment(**experiment_options)

    # Effective rank at 95% of the update energy
    print("\n=== Effective Rank Analysis ===")
    effective_ranks = analyze_effective_rank(results, 0.95)

    # Same analysis at the stricter 99% level
    print("\n=== Effective Rank Analysis (99% energy) ===")
    effective_ranks_99 = analyze_effective_rank(results, 0.99)

# For ShowerFlow only

In [None]:
# ShowerFlow comparison: build a config from two sets of run folders
# ("vanilla" vs. "finetune") and run the project's analysis on two metric files.
from utils.plot_finals_sf import (
    run_complete_analysis, 
    create_showerflow_config
)

# Define your experiment folders
# Both dicts share the integer keys 42-46 and map each key to a run-folder name.
# NOTE(review): the keys look like random seeds — confirm against create_showerflow_config.
vanilla_folders = {
    42: 'ShowerFlow_2025_08_08__13_30_54',
    43: 'ShowerFlow_2025_05_06__18_00_45',
    44: 'ShowerFlow_2025_08_08__16_37_03', 
    45: 'ShowerFlow_2025_08_09__12_41_00',
    46: 'ShowerFlow_2025_08_09__17_09_12',
}

finetune_folders = {
    42: 'ShowerFlow_2025_08_08__11_48_39',
    43: 'ShowerFlow_2025_05_06__18_00_43',
    44: 'ShowerFlow_2025_08_08__12_34_58',
    45: 'ShowerFlow_2025_08_09__12_27_34',
    46: 'ShowerFlow_2025_08_09__16_37_25',
}

# Create configuration
config = create_showerflow_config(
    vanilla_folders=vanilla_folders,
    finetune_folders=finetune_folders
)

# Run the complete analysis on the KL and WD per-epoch metric files.
# The original note says this creates exactly 4 plots; that is not verifiable
# from this cell — TODO confirm in utils.plot_finals_sf.
run_complete_analysis(
    config=config,
    metric_files=['KL_Features_all_epochs.json', 'WD_Features_all_epochs.json'],
    output_dir="./results/for_paper/showerflow_analysis"
)

# Baseline Evaluation

In [None]:
import utils.paths_trainings_cleaned as paths
from utils.preprocessing_utils import read_hdf5_file

In [None]:
# Load the validation (reference) dataset from the path configured in the project.
val_path = paths.GEANT4_PATH

# read_hdf5_file returns a mapping; its 'showers' and 'incident' entries are
# materialised as NumPy arrays and their shapes printed as a sanity check.
val_ds = read_hdf5_file(val_path)
val_showers = np.array(val_ds['showers'])  # Convert to NumPy array
val_incidents = np.array(val_ds['incident'])
print(val_showers.shape)
print(val_incidents.shape)

In [None]:
import gc

import numpy as np
from utils.preprocessing_utils import cylindrical_histogram, plt_scatter_2, plt_visible_e

# Load data
train_path = '/data/dust/user/valentel/maxwell.merged/MyCaloTransfer/CaloTransfer/data/calo-challenge/preprocessing/reduced_datasets/10-90GeV/47k_dset1-2-3_prep_10-90GeV.hdf5'
train_ds = read_hdf5_file(train_path)
# Indexed below as [shower, channel, point]: channels 0-2 are the x, y, z
# coordinates and channel 3 is read as the per-point energy.
train_showers = np.array(train_ds['showers'])
train_incidents = np.array(train_ds['incident'])

# Define coordinate ranges
Xmin, Xmax = -18, 18
Ymin, Ymax = 0, 45
Zmin, Zmax = -18, 18

# Preview the last shower in physical coordinates. Only a copy is rescaled
# from [-1, 1] to the ranges above, so `train_showers` itself stays in
# normalized coordinates (no transform-and-invert round trip on the full array).
preview_shower = train_showers[-1].copy()
preview_shower[0, :] = (preview_shower[0, :] + 1) * (Xmax - Xmin) / 2 + Xmin
preview_shower[1, :] = (preview_shower[1, :] + 1) * (Ymax - Ymin) / 2 + Ymin
preview_shower[2, :] = (preview_shower[2, :] + 1) * (Zmax - Zmin) / 2 + Zmin
plt_scatter_2(preview_shower)

# Visible energy (channel 3, positive entries only) of the full dataset
visible_energy = train_showers[:, 3, :][train_showers[:, 3, :] > 0]
plt_visible_e(visible_energy, log_scale=True, title=' Before Rescaling')

# Randomly sample 10,000 showers (without replacement)
n_samples = 10_000
if len(train_showers) < n_samples:
    raise ValueError(f"Dataset has only {len(train_showers)} showers, but {n_samples} were requested.")

# Randomly select indices (fixed seed for reproducibility)
np.random.seed(42)
random_indices = np.random.choice(len(train_showers), size=n_samples, replace=False)
sampled_showers = train_showers[random_indices]  # Shape: (n_samples, channels, M)
train_incidents = train_incidents[random_indices]  # Same indices, so incidents stay aligned with sampled_showers

visible_energy = sampled_showers[:, 3, :][sampled_showers[:, 3, :] > 0]
plt_visible_e(visible_energy, log_scale=True, title=' random sampled showers')

# Process each sampled shower into a cylindrical histogram (45 × 50 × 18) and flatten
train_showers_processed = np.zeros((n_samples, 45 * 50 * 18))  # Pre-allocate array
for i in range(n_samples):
    # Use the sampled showers: indexing `train_showers[i]` here would histogram
    # the first n_samples showers and pair them with the wrong incident energies.
    cyl_hist = cylindrical_histogram(sampled_showers[i])  # Expected output: (45, 50, 18)
    train_showers_processed[i] = cyl_hist.reshape(45 * 50 * 18)  # Flatten to (40500,)

print("Final shape:", train_showers_processed.shape)  # Output: (10000, 40500)
visible_energy = train_showers_processed[:, :][train_showers_processed[:, :] > 0]
plt_visible_e(visible_energy, log_scale=True, title='After Rescaling')
del train_showers, train_ds  # Free up memory
gc.collect()

In [None]:
import numpy as np

# Inputs expected from earlier cells:
# val_showers, train_showers_processed, val_incidents, train_incidents


def _as_scaled_float32(showers):
    """Copy `showers` into a float32 array and divide by the 0.033 scale factor."""
    return np.array(showers, dtype=np.float32) / 0.033


# Normalized showers for both splits
val_showers_norm = _as_scaled_float32(val_showers)
train_showers_norm = _as_scaled_float32(train_showers_processed)

# Splits are kept in a fixed order (no shuffling): validation first, training second
split_showers = [val_showers_norm, train_showers_norm]
split_incidents = [
    np.array(incidents, dtype=np.float32)
    for incidents in (val_incidents, train_incidents)
]

# Stack the splits along a new leading axis
showers_baseline_np = np.array(split_showers)  # Shape: (2, N_samples, 40500)
incidents_baseline_np = np.array(split_incidents)

# Verify results
print("Final outputs:")
print(f"showers_baseline_np shape: {showers_baseline_np.shape}")  # Should be (2, N, 40500)
print(f"incidents_baseline_np shape: {incidents_baseline_np.shape}")

print("\nSplit details:")
for split_no, split_name in enumerate(("Validation", "Training"), start=1):
    print(f"Split {split_no} ({split_name}): {showers_baseline_np[split_no - 1].shape} showers, "
          f"{incidents_baseline_np[split_no - 1].shape} incidents")

In [None]:
# showers_baseline_np = np.array([showers_numpy_baseline])
# incidents_baseline_np = np.array([incidents_numpy_baseline])
# showers_baseline_np.shape, incidents_baseline_np.shape

In [None]:
# Start with empty metric containers; the plotting cells below pass them in
# and rebind both names to whatever each plot function returns.
kl_divergences_baseline, wasserstein_dist_baseline = {}, {}

In [None]:
import utils.plot_evaluate as plot

# Visible-energy comparison on the baseline splits; both metric dicts are
# passed in and rebound to the returned values.
kl_divergences_baseline, wasserstein_dist_baseline = plot.plot_visible_energy(
    showers_baseline_np,  # Stacked [validation, training] splits
    kl_divergences=kl_divergences_baseline,
    wasserstein=wasserstein_dist_baseline,
    log_scale=True
)
# Display the updated metrics as the cell output
kl_divergences_baseline, wasserstein_dist_baseline

In [None]:
# Calibration histograms: the only metric cell that also passes the incident
# energies. Rebinds both metric dicts to the returned values and displays them.
kl_divergences_baseline, wasserstein_dist_baseline = plot.plot_calibration_histograms(showers_baseline_np, incidents_baseline_np,
                                            kl_divergences=kl_divergences_baseline, wasserstein=wasserstein_dist_baseline)
kl_divergences_baseline, wasserstein_dist_baseline

In [None]:
# Energy-sum comparison on the baseline splits; rebinds both metric dicts to
# the returned values and displays them.
kl_divergences_baseline, wasserstein_dist_baseline = plot.plot_energy_sum(showers_baseline_np, 
                                                                        kl_divergences=kl_divergences_baseline, 
                                                                        wasserstein=wasserstein_dist_baseline)
kl_divergences_baseline, wasserstein_dist_baseline


In [None]:
# Occupancy comparison on the baseline splits; rebinds both metric dicts to
# the returned values and displays them.
kl_divergences_baseline, wasserstein_dist_baseline = plot.plot_occupancy(showers_baseline_np,
                                                                          kl_divergences=kl_divergences_baseline, wasserstein=wasserstein_dist_baseline)
kl_divergences_baseline, wasserstein_dist_baseline

In [None]:
# Per-layer energy comparison on the baseline splits; rebinds both metric
# dicts to the returned values and displays them.
kl_divergences_baseline, wasserstein_dist_baseline= plot.plot_energy_layer(showers_baseline_np,
                                                             kl_divergences=kl_divergences_baseline, wasserstein=wasserstein_dist_baseline)
kl_divergences_baseline, wasserstein_dist_baseline

In [None]:
# Radial-energy comparison on the baseline splits; rebinds both metric dicts
# to the returned values and displays them.
kl_divergences_baseline, wasserstein_dist_baseline = plot.plot_radial_energy(showers_baseline_np,
                                                            kl_divergences=kl_divergences_baseline, 
                                                            wasserstein=wasserstein_dist_baseline)
kl_divergences_baseline, wasserstein_dist_baseline

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def plot_with_error_propagation(kl_divergences_baseline):
    """
    Plot per-column means with standard-deviation error bars, plus the mean of
    those means with a shaded standard-error band.

    Parameters:
    kl_divergences_baseline :
        Data handed to ``pd.DataFrame``; rows are labelled "Shower 2",
        "Shower 3", ... with one label per ``len(kl_divergences_baseline)``.
        NOTE(review): this notebook passes a dict of metrics, in which case
        ``len`` counts keys, not rows — confirm the labels match the row count.

    The band half-width is sqrt(std(column means)**2 / (n_rows * n_columns)).
    The y-axis is logarithmic. Also prints the standard deviation of the
    column means. Returns nothing.
    """
    # Tabulate the input with "Shower i" row labels
    row_labels = [f"Shower {i}" for i in range(2, len(kl_divergences_baseline) + 2)]
    table = pd.DataFrame(kl_divergences_baseline, index=row_labels)

    # Column-wise mean and standard deviation (the latter drawn as error bars)
    column_means = table.mean()
    column_stds = table.std()

    fig, ax = plt.subplots(figsize=(14, 8))
    ax.errorbar(table.columns, column_means, yerr=column_stds,
                fmt='o', capsize=5, label="means ± Std Dev", color='b')

    # Mean of the column means and the spread between them
    grand_mean = column_means.mean()
    spread_of_means = column_means.std()
    print(f"Total Standard Deviation: {spread_of_means:.4f}")
    n_values = len(table) * len(table.columns)

    # Standard error attached to the mean of means
    band_halfwidth = np.sqrt(spread_of_means**2 / n_values)

    ax.axhline(y=grand_mean, color='r', linestyle='--',
               label=f"Average of Averages: {grand_mean:.2f}")
    ax.fill_between(table.columns, grand_mean - band_halfwidth, grand_mean + band_halfwidth,
                    color='r', alpha=0.2, label="Overall SEM")

    # Formatting
    plt.xticks(rotation=45)
    ax.set_ylabel("Value")
    ax.set_yscale('log')
    ax.set_title("Average Plot with Error Propagation and Average of Averages")
    ax.legend()
    ax.grid(True)

    # Show plot
    plt.show()


In [None]:
# Display the collected KL-divergence metrics as the cell output
kl_divergences_baseline

In [None]:
# Summarise the KL metrics: the project's dataframe plot (result kept in
# df_kl), then the mean / error-bar figure defined in this notebook.
df_kl = plot.plot_dataframe(kl_divergences_baseline, 'KL Divergence')
plot_with_error_propagation(kl_divergences_baseline)


In [None]:
# Same summary for the Wasserstein distances (result kept in df_wass).
# plot_with_error_propagation is metric-agnostic despite its parameter name.
df_wass = plot.plot_dataframe(wasserstein_dist_baseline,'Wasserstein distance')
plot_with_error_propagation(wasserstein_dist_baseline)

In [None]:
# Display the collected Wasserstein-distance metrics as the cell output
wasserstein_dist_baseline