In [None]:
%load_ext autoreload
%autoreload 2
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
import matplotlib.gridspec as gridspec
from tqdm.notebook import tqdm  # Use notebook version for better display
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Custom imports - make sure these are in your path
import utils.plot_evaluate as plot
from utils.preprocessing_utils import read_hdf5_file2

In [None]:
# Publication-style matplotlib defaults, applied globally to every figure
# created after this cell runs. Settings are grouped by rc category.

# Fonts: serif family; the candidates are tried in the listed order.
plt.rc(
    'font',
    family='serif',
    serif=['DejaVu Serif', 'Liberation Serif', 'Computer Modern Roman', 'Bitstream Vera Serif'],
    size=12,
)

# Axes text sizes and frame width (slightly thinner axes lines).
# Grid drawing stays off by default ('axes.grid' is intentionally not set).
plt.rc('axes', labelsize=14, titlesize=16, linewidth=0.8)
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
plt.rc('legend', fontsize=12)

# On-screen resolution versus saved-file settings (higher DPI, PDF output,
# tight bounding box with a small padding).
plt.rc('figure', dpi=300)
plt.rc('savefig', dpi=600, format='pdf', bbox='tight', pad_inches=0.1)

# Slightly thicker plot lines, slightly smaller markers, faint grid lines.
plt.rc('lines', linewidth=1.5, markersize=4)
plt.rc('grid', alpha=0.3)

# final distribution plots for the paper

## Complete models (diffusion + ShowerFlow)

In [None]:
# ========================================
# USER CONFIGURATION SECTION
# ========================================
# Everything a reader may want to tune for the final comparison plots lives in
# this cell: result locations, which dataset sizes to show, and the chain of
# name mappings
#   dataset key ('1k_1-1000')
#     -> SHOWER_PATHS display name ('D = 1 x 10^3')
#     -> simple storage key ('1,000')
#     -> LaTeX legend label.

# Import paths from the cleaned paths file
from utils.paths_trainings_cleaned import (
    PW_ENTRIES, 
    BASE_PATH_SHOWERS, 
    GEANT4_PATH,
    DISPLAY_NAME_MAP,
    build_showers_paths
)

# Base paths
BASE_PATH = BASE_PATH_SHOWERS
RESULTS_DIR = './results/final_comparison'

# Dataset sizes to include in final plots (must match keys in PW_ENTRIES)
SELECTED_DATASET_SIZES = ['100_1-1000', '1k_1-1000', '10k_1-1000', '100k_1-1000']

# Dataset size order and colors for plotting.
# NOTE(review): `dataset_sizes_order` is reassigned (in reversed order) by the
# plotting code further down, so the order given here is not the one used there.
dataset_sizes_order = ['100', '1,000', '10,000', '100,000']
dataset_colors = {
    '100': '#3C493F',      
    '1,000': '#6320EE',     
    '10,000': '#5C9EAD',   
    '100,000': '#EF767A',  
}

# Map dataset keys to the display names used as keys of SHOWER_PATHS[strategy]
DATASET_TO_DISPLAY_MAP = {
    '100_1-1000': 'D = 1 x 10^2',   # Note: must match SHOWER_PATHS keys exactly
    '1k_1-1000': 'D = 1 x 10^3',
    '10k_1-1000': 'D = 1 x 10^4',
    '100k_1-1000': 'D = 1 x 10^5'
}

# Map SHOWER_PATHS display names to the simple keys under which loaded data is
# stored (these are also the keys of dataset_colors and PLOT_DISPLAY_LABELS)
DISPLAY_NAME_SIMPLE = {
    'D = 1 x 10^2': '100',
    'D = 1 x 10^3': '1,000',
    'D = 1 x 10^4': '10,000',
    'D = 1 x 10^5': '100,000'
}

# Plot display labels - use LaTeX for better compatibility
PLOT_DISPLAY_LABELS = {
    '100': r'$D = 10^{2}$',        # 100 samples
    '1,000': r'$D = 10^{3}$',      # 1,000 samples  
    '10,000': r'$D = 10^{4}$',     # 10,000 samples
    '100,000': r'$D = 10^{5}$'     # 100,000 samples
}

# Alternative label set: show the actual sample count.
# Not referenced by the visible plotting code; kept as a drop-in replacement
# for PLOT_DISPLAY_LABELS.
PLOT_DISPLAY_LABELS_WITH_COUNT = {
    '100': r'100 samples',
    '1,000': r'1k samples',
    '10,000': r'10k samples',
    '100,000': r'100k samples'
}

# Alternative label set: more descriptive labels.
# Not referenced by the visible plotting code; kept as a drop-in replacement
# for PLOT_DISPLAY_LABELS.
PLOT_DISPLAY_LABELS_DESCRIPTIVE = {
    '100': r'$N = 10^{2}$ training samples',
    '1,000': r'$N = 10^{3}$ training samples',
    '10,000': r'$N = 10^{4}$ training samples',
    '100,000': r'$N = 10^{5}$ training samples'
}

# ========================================
# MANUALLY SELECT OPTIMAL TRAINING STEPS HERE
# ========================================
# Layout: {strategy key: {dataset key: training step}}.
# The outer keys are the strategy keys listed in `selected_strategies`, the
# inner keys are the entries of SELECTED_DATASET_SIZES. The loading code
# substitutes each value for `{training_step}` in the matching SHOWER_PATHS
# path template, so every value must correspond to an existing output folder.
# NOTE(review): the values were chosen by hand; how "optimal" was determined
# is not recorded in this notebook - confirm before changing them.
OPTIMAL_TRAINING_STEPS = {
    'vanilla_full_v1_1_1000': {
        '100_1-1000': 250_000,    # Updated based on your output
        '1k_1-1000': 1_000_000,   # Updated based on your output
        '10k_1-1000': 500_000,    # Updated based on your output
        '100k_1-1000': 750_000,   # Updated based on your output
    },
    'finetune_full_v1_1_1000': {
        '100_1-1000': 100_000,
        '1k_1-1000': 50_000,
        '10k_1-1000': 100_000,
        '100k_1-1000': 250_000,
    },
    'finetune_top3_v1_1_1000': {
        '100_1-1000': 1_000_000,
        '1k_1-1000': 500_000,
        '10k_1-1000': 750_000,
        '100k_1-1000': 750_000,
    },
    'finetune_bitfit_v1_1_1000': {
        '100_1-1000': 1_000_000,
        '1k_1-1000': 750_000,
        '10k_1-1000': 500_000,
        '100k_1-1000': 500_000,
    },
    'lora_r8_v1_1_1000': {
        '100_1-1000': 250_000,
        '1k_1-1000': 10_000,
        '10k_1-1000': 100_000,
        '100k_1-1000': 200_000,
    },
    'lora_r106_v1_1_1000': {
        '100_1-1000': 100_000,
        '1k_1-1000': 100_000,
        '10k_1-1000': 10_000,
        '100k_1-1000': 50_000,
    }
}

# Strategies to evaluate. Each key must exist in PW_ENTRIES; keys missing there
# are silently dropped by the filter below.
selected_strategies = [
    'vanilla_full_v1_1_1000',
    'finetune_full_v1_1_1000',
    'finetune_top3_v1_1_1000',
    'finetune_bitfit_v1_1_1000',
    'lora_r8_v1_1_1000',
    'lora_r106_v1_1_1000',
]

# Keep only the selected entries (in PW_ENTRIES order) and let the helper from
# paths_trainings_cleaned turn them into per-strategy path templates.
selected_pw_entries = dict(
    filter(lambda item: item[0] in selected_strategies, PW_ENTRIES.items())
)
SHOWER_PATHS = build_showers_paths(
    BASE_PATH,
    selected_pw_entries,
    GEANT4_PATH,
    DISPLAY_NAME_MAP,
)

# EMA configuration: the suffix is empty when EMA weights are used.
USE_EMA = True
if USE_EMA:
    EMA_SUFFIX = ''
else:
    EMA_SUFFIX = '_no_ema'

# GEANT4 reference path (alias of the imported constant).
GEANT4_PATH_USED = GEANT4_PATH

# Human-readable names per strategy key. The plotting code also derives the
# output sub-directory from the lower-cased display name.
strategy_names = {
    'vanilla_full_v1_1_1000': 'From scratch',
    'finetune_full_v1_1_1000': 'Full fine-tuned',
    'finetune_top3_v1_1_1000': 'Top-3 fine-tuned',
    'finetune_bitfit_v1_1_1000': 'BitFit',
    'lora_r8_v1_1_1000': 'LoRA R8',
    'lora_r106_v1_1_1000': 'LoRA R106',
}

# Short report of what was configured.
print("Configuration loaded successfully")
print(f"Selected strategies: {selected_strategies}")
print(f"Dataset sizes: {SELECTED_DATASET_SIZES}")
print("Available shower paths:")
for strategy, paths in SHOWER_PATHS.items():
    available_sizes = list(paths.keys())
    print(f"  {strategy}: {available_sizes}")

### 3. Helper Functions

In [None]:
class ShowerAnalysisNotebook:
    """Notebook-friendly loader for generated showers and the GEANT4 reference.

    Parameters
    ----------
    base_path : str
        Stored on the instance as ``base_path``; not read by the methods below.
    energy_scaling : float, optional
        Divisor applied to the GEANT4 reference showers when they are loaded
        (default 0.033).
    """

    def __init__(self, base_path: str, energy_scaling: float = 0.033):
        self.base_path = base_path
        self.energy_scaling = energy_scaling

    @staticmethod
    def _require_file(path: str) -> None:
        """Raise ``FileNotFoundError`` with a uniform message if ``path`` is missing."""
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")

    def load_shower_data(self, path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load generated shower data from an HDF5 file.

        Returns
        -------
        (incidents, showers)
            ``incidents`` is squeezed; ``showers`` is returned as read.

        Raises
        ------
        FileNotFoundError
            If ``path`` does not exist.
        """
        self._require_file(path)
        dataset_names, incidents, showers = read_hdf5_file2(path)
        return incidents.squeeze(), showers

    def load_geant4_reference(self, path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Load the GEANT4 reference data and divide the showers by ``energy_scaling``.

        Returns
        -------
        (incidents, showers)
            ``incidents`` is squeezed; ``showers`` has been rescaled.

        Raises
        ------
        FileNotFoundError
            If ``path`` does not exist (same check and message as
            ``load_shower_data``, so both loaders fail the same way).
        """
        self._require_file(path)
        dataset_names, incidents, showers = read_hdf5_file2(path)
        # In-place division: rescales the array returned by the reader without
        # allocating a second copy of the (potentially large) shower data.
        showers /= self.energy_scaling
        return incidents.squeeze(), showers

### 4. Load and process Data

In [None]:
# Initialize analyzer
analyzer = ShowerAnalysisNotebook(base_path=BASE_PATH)

# Storage for all loaded data, laid out as [strategy key][simple size label],
# e.g. all_shower_data['lora_r8_v1_1_1000']['1,000'].
all_shower_data = {}
all_incident_data = {}

# Load GEANT4 reference data first
print("Loading GEANT4 reference data...")
geant4_incidents, geant4_showers = analyzer.load_geant4_reference(GEANT4_PATH)
print(f"GEANT4 data loaded: {geant4_showers.shape}")

# One progress tick per (strategy, dataset size) combination.
total_configs = len(selected_strategies) * len(SELECTED_DATASET_SIZES)
progress_bar = tqdm(total=total_configs, desc="Loading data")

try:
    for strategy_name in selected_strategies:
        if strategy_name not in SHOWER_PATHS:
            print(f"WARNING: Strategy {strategy_name} not found in SHOWER_PATHS")
            # Account for every dataset size of the skipped strategy so the
            # bar still reaches 100%.
            progress_bar.update(len(SELECTED_DATASET_SIZES))
            continue

        print(f"\n{'='*60}")
        print(f"Processing strategy: {strategy_name}")
        print(f"{'='*60}")

        all_shower_data[strategy_name] = {}
        all_incident_data[strategy_name] = {}

        # A strategy without any configured steps is reported per dataset
        # below instead of aborting the whole cell with a KeyError.
        strategy_steps = OPTIMAL_TRAINING_STEPS.get(strategy_name, {})

        for dataset_key in SELECTED_DATASET_SIZES:
            try:
                optimal_step = strategy_steps.get(dataset_key)
                if optimal_step is None:
                    print(f"  WARNING: No optimal step for {dataset_key} in {strategy_name}")
                    continue

                # Display name used as key inside SHOWER_PATHS[strategy]
                scientific_display_name = DATASET_TO_DISPLAY_MAP[dataset_key]
                if scientific_display_name not in SHOWER_PATHS[strategy_name]:
                    print(f"  WARNING: Display size {scientific_display_name} not found in SHOWER_PATHS for {strategy_name}")
                    continue

                # Fill the training step into the path template and point at
                # the showers file inside that folder.
                shower_path_template = SHOWER_PATHS[strategy_name][scientific_display_name]
                shower_path = shower_path_template.format(training_step=optimal_step)
                shower_path = f"{shower_path}/showers.hdf5"

                print(f"  Dataset {dataset_key}: step {optimal_step:,}")
                print(f"    Path: {shower_path}")

                try:
                    gen_incidents, gen_showers = analyzer.load_shower_data(shower_path)

                    # Store under the simple label ('100', '1,000', ...)
                    simple_display_name = DISPLAY_NAME_SIMPLE[scientific_display_name]
                    all_shower_data[strategy_name][simple_display_name] = gen_showers
                    all_incident_data[strategy_name][simple_display_name] = gen_incidents

                    print(f"    Loaded shower data: {gen_showers.shape}")
                except Exception as e:
                    # Best effort: a missing or unreadable file must not stop
                    # the remaining configurations from loading.
                    print(f"    ERROR: {str(e)}")
            finally:
                # Exactly one tick per dataset, whichever branch was taken.
                progress_bar.update(1)
finally:
    # Release the progress widget even if something above raised.
    progress_bar.close()

print("\n" + "="*60)
print("Data loading summary:")
print(f"GEANT4: {geant4_showers.shape}")
for strategy, data in all_shower_data.items():
    # Every entry is a dict created above, keyed by simple size label.
    print(f"{strategy}: {list(data.keys())}")

# ============================================================================
# CREATE DISTRIBUTION PLOTS FOR EACH STRATEGY WITH ALL DATASET SIZES TOGETHER
# ============================================================================
# For every selected strategy: stack the GEANT4 reference with the generated
# showers of each available dataset size, run each plot function once, and
# free memory before moving to the next strategy.

import gc  # For garbage collection
import matplotlib
# NOTE(review): switches the backend for the rest of the kernel session, so
# cells executed after this one will not render figures inline - confirm this
# is intended.
matplotlib.use('Agg')  # Use non-interactive backend to save memory

# Dataset size order used for the curves in this section (largest first).
# This overrides the ascending order defined in the configuration cell.
dataset_sizes_order = ['100,000', '10,000', '1,000', '100']


print(f"\n{'='*60}")
print("GENERATING DISTRIBUTION PLOTS")
print(f"{'='*60}")

# Create one output directory per selected strategy, named after the
# lower-cased display name (may contain spaces and hyphens).
for strategy_name in selected_strategies:
    strategy_display_name = strategy_names[strategy_name]
    save_dir = f'{RESULTS_DIR}/{strategy_display_name.lower()}'
    os.makedirs(save_dir, exist_ok=True)
    print(f"Created directory: {save_dir}")

# Process each strategy separately to manage memory
for strategy_idx, strategy_name in enumerate(selected_strategies):
    if strategy_name not in all_shower_data:
        print(f"No data for strategy {strategy_name}")
        continue
        
    print(f"\nGenerating plots for {strategy_names[strategy_name]} strategy...")
    
    # Prepare data for plotting; index 0 is always the GEANT4 reference and
    # the four lists stay aligned entry by entry.
    showers_list = [geant4_showers]  # Start with GEANT4
    incidents_list = [geant4_incidents]
    labels = ['Geant4']
    colors = ['gray']  # GEANT4 in gray
    
    # Add each dataset size that was actually loaded for this strategy
    for dataset_size in dataset_sizes_order:
        if dataset_size in all_shower_data[strategy_name]:
            showers_list.append(all_shower_data[strategy_name][dataset_size])
            incidents_list.append(all_incident_data[strategy_name][dataset_size])
            
            # Use the display label for the plot
            labels.append(PLOT_DISPLAY_LABELS[dataset_size])
            colors.append(dataset_colors[dataset_size])
    
    # Only the reference is present: nothing to compare against.
    if len(showers_list) <= 1:
        print(f"Not enough data for {strategy_name}")
        continue
    
    # Convert to numpy arrays.
    # NOTE(review): this copies all datasets into one stacked array and
    # presumably requires every dataset to have the same shape - confirm that
    # all generated files match the GEANT4 reference in size.
    showers_numpy = np.array(showers_list)
    incidents_numpy = np.array(incidents_list)
    
    # Create simulation labels dict (keys carry the labels, values are unused here)
    simulation_labels = {label: None for label in labels}
    
    print(f"Plotting {len(showers_list)} datasets: {labels}")
    
    # Each entry: (plot function, whether it also takes incident energies,
    # short name used in log messages).
    plot_configurations = [
        (plot.plot_calibration_histograms, True, 'calibration'),
        (plot.plot_energy_sum, False, 'energy_sum'),
        (plot.plot_occupancy, False, 'occupancy'),
        (plot.plot_energy_layer, False, 'energy_layer'),
        (plot.plot_radial_energy, False, 'radial_energy'),
        (plot.plot_visible_energy, False, 'visible_energy'),
    ]
    
    # Set save directory for this strategy (same naming as the directory
    # creation loop above)
    save_dir = f'{RESULTS_DIR}/{strategy_names[strategy_name].lower()}'
    
    for plot_idx, (plot_func, needs_incidents, plot_name) in enumerate(plot_configurations):
        try:
            print(f"  Creating {plot_name} plot...")
            
            args = [showers_numpy]
            if needs_incidents:
                args.append(incidents_numpy)
            
            # Call plot function; the returned metrics are not used further
            # in this cell.
            kl_divergences, wasserstein_dist = plot_func(
                *args,
                simulation_labels=simulation_labels,
                colors=colors,
                kl_divergences={},
                wasserstein={},
                training_strategy=strategy_names[strategy_name],
                save_plot=True,
                save_dir=save_dir
            )
            
            # Close all open figures after each plot
            plt.close('all')
            
            # Force garbage collection every 2 plots to free memory
            if plot_idx % 2 == 1:
                gc.collect()
                
        except Exception as e:
            # Best effort: report the failure and continue with the next plot.
            print(f"    ERROR creating {plot_name}: {str(e)}")
            # Close any remaining plots and clear memory
            plt.close('all')
            gc.collect()
    
    # After completing all plots for this strategy, clear memory
    print(f"  Completed all plots for {strategy_names[strategy_name]}")
    
    # Clear the strategy data from memory after plotting (except GEANT4).
    # After this loop only the first strategy's data remains in
    # all_shower_data / all_incident_data.
    if strategy_idx == 0:  # Keep data for first strategy in case needed
        pass
    else:  # Clear data for subsequent strategies
        if strategy_name in all_shower_data:
            del all_shower_data[strategy_name]
        if strategy_name in all_incident_data:
            del all_incident_data[strategy_name]
    
    # Drop the stacked copies and force garbage collection
    del showers_numpy, incidents_numpy
    plt.close('all')
    gc.collect()
    
    print(f"  Memory cleared for {strategy_names[strategy_name]} strategy")

print(f"\n{'='*60}")
print("ALL PLOTS GENERATED!")
print(f"{'='*60}")

# Final memory cleanup
plt.close('all')
gc.collect()
print("Final memory cleanup completed")

### composition of generated plots

In [None]:
# Narrow the strategy selection for the composition step; uncomment entries to
# include more strategies. This rebinds the list defined in the configuration
# cell, so re-running earlier cells afterwards will only see these entries.
# NOTE(review): the create_all_composite_figures(RESULTS_DIR) call further
# down does not pass this list on - confirm whether it is still needed.
selected_strategies = [
    'vanilla_full_v1_1_1000', 
    # 'finetune_full_v1_1_1000',
    # 'finetune_top3_v1_1_1000',
    # 'finetune_bitfit_v1_1_1000',
    # 'lora_r8_v1_1_1000',
    # 'lora_r106_v1_1_1000',

]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
import os
import numpy as np
from PIL import Image

def combine_plots_for_paper(strategy_name, source_dir, output_dir=None, output_format='png'):
    """
    Build the two composite figures (energy and spatial) for one strategy.

    Parameters:
    -----------
    strategy_name : str
        Name of the training strategy (e.g., 'from scratch', 'lora r8', 'bitfit').
        Lower-cased with spaces replaced by underscores, it becomes the prefix
        of both output file names.
    source_dir : str
        Directory containing the individual plot files
    output_dir : str, optional
        Directory to save the composite figures (defaults to source_dir)
    output_format : str, optional
        Output format ('png' or 'pdf'), defaults to 'png'. Any non-empty value
        is used as given; only an empty/None value falls back to a
        strategy-dependent choice ('png' for 'from scratch' and
        'full fine-tuned', 'pdf' otherwise).
    """
    target_dir = source_dir if output_dir is None else output_dir

    # An explicit format always wins over the strategy-based default.
    if output_format:
        format_ext = output_format
    elif strategy_name.lower() in ['from scratch', 'full fine-tuned']:
        format_ext = 'png'
    else:
        format_ext = 'pdf'

    # File-name prefix shared by both composites.
    file_prefix = strategy_name.lower().replace(" ", "_")

    # (group name, individual plots that make up that composite)
    composite_groups = (
        # Figure 1: Energy-related plots
        ('energy', [
            'Voxel_Energy_Spectrum_with_ratio.png',
            'Sampling_Fraction_with_ratio.png',
            'Visible_Energy_with_ratio.png',
        ]),
        # Figure 2: Spatial distribution plots
        ('spatial', [
            'Occupancy_with_ratio.png',
            'Longitudinal_Profile_with_ratio.png',
            'Radial_Profile_with_ratio.png',
        ]),
    )

    for group_name, group_plots in composite_groups:
        create_composite_figure(
            group_plots,
            source_dir,
            target_dir,
            f'{file_prefix}_{group_name}_plots.{format_ext}',
            strategy_name,
            format_ext,
        )

    print(f"Composite figures created for {strategy_name} strategy:")
    for group_name, _ in composite_groups:
        print(f"  - {target_dir}/{file_prefix}_{group_name}_plots.{format_ext}")

def combine_all_strategies(base_dir, strategies, output_dir=None):
    """
    Create composite figures for several strategies in one go.

    Parameters:
    -----------
    base_dir : str
        Base directory containing strategy subdirectories
    strategies : list or dict
        List of strategy names or dict mapping strategy_key -> display_name.
        Only the display name is used to locate the strategy directory.
    output_dir : str, optional
        Output directory for composite figures
        (defaults to ``<base_dir>/composite_figures``)
    """
    if output_dir is None:
        output_dir = os.path.join(base_dir, 'composite_figures')
    os.makedirs(output_dir, exist_ok=True)

    # Normalise both accepted input shapes to (key, display name) pairs.
    if isinstance(strategies, dict):
        name_pairs = strategies.items()
    else:
        name_pairs = [(name, name) for name in strategies]

    for _strategy_key, display_name in name_pairs:
        lowered = display_name.lower()

        # Preferred layout: lowercase with spaces and hyphens as underscores.
        primary_dir = os.path.join(base_dir, lowered.replace(' ', '_').replace('-', '_'))
        if os.path.exists(primary_dir):
            print(f"\nProcessing {display_name} strategy...")
            combine_plots_for_paper(display_name, primary_dir, output_dir)
            continue

        print(f"Warning: Directory {primary_dir} not found!")

        # Fallback layout: the plain lower-cased display name.
        fallback_dir = os.path.join(base_dir, lowered)
        if not os.path.exists(fallback_dir):
            print(f"  Alternative directory {fallback_dir} also not found!")
            continue

        print(f"  Found alternative directory: {fallback_dir}")
        combine_plots_for_paper(display_name, fallback_dir, output_dir)

    print(f"\nAll composite figures saved to: {output_dir}")

def create_comparison_figure(strategy_dirs, output_dir, plot_type='energy', strategy_names=None):
    """
    Creates a comparison figure showing multiple strategies side by side.

    The figure is a grid with one row per strategy and one column per plot of
    the chosen type. It is saved as ``comparison_energy_plots.pdf`` or
    ``comparison_spatial_plots.pdf`` inside ``output_dir``.

    Parameters:
    -----------
    strategy_dirs : dict
        Dictionary mapping strategy_name -> directory_path
    output_dir : str
        Output directory for comparison figure
    plot_type : str
        'energy' selects the energy plots; any other value selects the
        spatial plots
    strategy_names : dict, optional
        Mapping of directory keys to display names. Keys without an entry are
        shown as the key with underscores replaced by spaces, title-cased.

    Returns:
    --------
    None. Missing plot files are reported and leave their grid cell empty;
    an empty ``strategy_dirs`` is reported and no figure is written.
    """
    if plot_type == 'energy':
        plot_files = [
            'Voxel_Energy_Spectrum_with_ratio.png',
            'Sampling_Fraction_with_ratio.png',
            'Visible_Energy_with_ratio.png'
        ]
        output_name = 'comparison_energy_plots.pdf'
    else:  # spatial
        plot_files = [
            'Occupancy_with_ratio.png',
            'Longitudinal_Profile_with_ratio.png',
            'Radial_Profile_with_ratio.png'
        ]
        output_name = 'comparison_spatial_plots.pdf'

    num_strategies = len(strategy_dirs)
    num_plots = len(plot_files)

    # Zero rows would give a zero-height figure; there is nothing to compare.
    if num_strategies == 0:
        print("Warning: no strategy directories given; comparison figure not created.")
        return

    # Create grid (num_strategies rows, num_plots columns)
    fig = plt.figure(figsize=(8*num_plots, 6*num_strategies))

    for row, (strategy_key, source_dir) in enumerate(strategy_dirs.items()):
        # Get display name
        if strategy_names and strategy_key in strategy_names:
            display_name = strategy_names[strategy_key]
        else:
            display_name = strategy_key.replace('_', ' ').title()

        for col, plot_file in enumerate(plot_files):
            filepath = os.path.join(source_dir, plot_file)
            if not os.path.exists(filepath):
                print(f"Warning: {filepath} not found!")
                continue

            ax = fig.add_subplot(num_strategies, num_plots, row*num_plots + col + 1)
            # imshow copies the pixel data, so the file can be closed right away.
            with Image.open(filepath) as img:
                ax.imshow(img)

            # Hide ticks and frame only. axis('off') would also suppress the
            # axis label, making the row label below invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            # Add strategy label for leftmost plots
            if col == 0:
                ax.set_ylabel(display_name, fontsize=16, fontweight='bold')
                ax.yaxis.set_label_coords(-0.05, 0.5)

    fig.subplots_adjust(left=0.05, right=0.99, top=0.98, bottom=0.02,
                        wspace=0.02, hspace=0.02)

    # Save
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_name)
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Comparison figure saved: {output_path}")

def create_composite_figure(plot_files, source_dir, output_dir, output_filename, strategy_name, output_format='png'):
    """
    Creates a composite figure that places the given plots side by side in one row.

    Parameters:
    -----------
    plot_files : list
        List of plot filenames to combine (any number; one column per file)
    source_dir : str
        Directory containing the plot files
    output_dir : str
        Directory to save the composite figure (created if missing)
    output_filename : str
        Name of the output file
    strategy_name : str
        Strategy name, used in the warning printed for an empty file list
    output_format : str
        Output format ('png' or 'pdf')

    Returns:
    --------
    None. If any input file is missing, or no files are given, a warning is
    printed and no figure is written.
    """
    # Load images; every file is closed again as soon as its pixels are copied.
    images = []
    for plot_file in plot_files:
        filepath = os.path.join(source_dir, plot_file)
        if not os.path.exists(filepath):
            print(f"Warning: {filepath} not found!")
            return
        with Image.open(filepath) as img:
            images.append(np.array(img))

    if not images:
        print(f"Warning: no plot files given for {strategy_name}; nothing to combine.")
        return

    # Use the maximum height and the sum of widths (in pixels) to keep the
    # aspect ratio of the row of images.
    max_height = max(img.shape[0] for img in images)
    total_width = sum(img.shape[1] for img in images)

    fig_width = 24  # inches
    # Slightly reduce height for better proportions
    fig_height = fig_width * max_height / total_width * 0.95

    fig = plt.figure(figsize=(fig_width, fig_height))

    # One column per image, so the grid matches the number of inputs.
    num_columns = len(images)
    for position, img in enumerate(images, start=1):
        ax = fig.add_subplot(1, num_columns, position)
        ax.imshow(img)
        ax.axis('off')

    # Adjust spacing
    fig.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01, wspace=0.02)

    # Save the composite figure
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    # Set DPI based on format
    dpi = 300 if output_format.lower() == 'pdf' else 150

    fig.savefig(output_path, dpi=dpi, bbox_inches='tight', pad_inches=0.1, format=output_format)
    plt.close(fig)

    print(f"  Saved: {output_path}")

def create_all_composite_figures(results_dir, strategy_names=None, selected_strategies=None):
    """
    Create composite figures for every strategy directory found in ``results_dir``.

    All composite figures are saved in PDF format to the ``composite_figures``
    subdirectory. A strategy is skipped when both of its composite PDFs are
    already present there.

    Parameters:
    -----------
    results_dir : str
        Base directory containing all strategy results
    strategy_names : dict, optional
        Accepted for interface compatibility; not used in this version
    selected_strategies : list, optional
        Accepted for interface compatibility; not used in this version
        (every directory found is processed)
    """

    print(f"{'='*60}")
    print("CREATING COMPOSITE FIGURES")
    print(f"{'='*60}")

    # Create composite_figures directory
    composite_dir = os.path.join(results_dir, 'composite_figures')
    os.makedirs(composite_dir, exist_ok=True)

    # Find all strategy directories (actual directory names), skipping output
    # folders. Sorted so the processing order is reproducible.
    all_dirs = sorted(
        d for d in os.listdir(results_dir)
        if os.path.isdir(os.path.join(results_dir, d))
        and d not in ['comparisons', 'composite_figures']
    )

    print(f"Found directories: {all_dirs}")
    print(f"Saving all composite figures to: {composite_dir}")

    # Process each directory
    for dir_name in all_dirs:
        strategy_dir = os.path.join(results_dir, dir_name)

        # Use directory name as display name
        display_name = dir_name

        print(f"\nCreating composite figures for: {display_name}")

        # Build the file-name prefix exactly as combine_plots_for_paper does
        # (lowercase, spaces -> underscores, hyphens kept). Otherwise existing
        # outputs of hyphenated strategies such as 'full fine-tuned' are never
        # recognised.
        file_prefix = display_name.lower().replace(' ', '_')
        pdf_energy = os.path.join(composite_dir, f"{file_prefix}_energy_plots.pdf")
        pdf_spatial = os.path.join(composite_dir, f"{file_prefix}_spatial_plots.pdf")

        if os.path.exists(pdf_energy) and os.path.exists(pdf_spatial):
            print(f"  Composite figures already exist for {display_name}, skipping...")
            continue

        # Always use PDF format and save to composite_figures directory.
        # Best effort: a failure for one strategy must not stop the others.
        try:
            combine_plots_for_paper(display_name, strategy_dir,
                                    output_dir=composite_dir,
                                    output_format='pdf')
        except Exception as e:
            print(f"  Error creating composite figures for {display_name}: {e}")

    print(f"\n{'='*60}")
    print("COMPOSITE FIGURES COMPLETED!")
    print(f"All composite figures saved to: {composite_dir}")
    print(f"{'='*60}")

# Entry point: build the composite PDFs for every strategy directory found
# under the results folder (outputs go to its composite_figures subdirectory).
if __name__ == "__main__":
    # Results directory produced by the plotting step above
    RESULTS_DIR = './results/final_comparison'

    create_all_composite_figures(results_dir=RESULTS_DIR)

# Showerflow

In [None]:
# ShowerFlow Cluster Distribution Analysis
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional, List
import json

# Import project modules
import utils.eval_utils_showerflow as eval_utils
from utils.preprocessing_utils import read_hdf5_file2
from configs import Configs

# Configuration
class AnalysisConfig:
    """Settings for the ShowerFlow cluster-distribution analysis.

    Every setting can be overridden by keyword; calling ``AnalysisConfig()``
    without arguments yields the same values as before.

    Parameters
    ----------
    pretrained : bool
        True for fine-tuned models, False for vanilla (from-scratch) models.
    energy_range : str
        Energy range label, e.g. '1-1000GeV'.
    selected_models : sequence of str, optional
        Simplified model names to analyse; defaults to
        ['100k', '10k', '1k', '100']. Stored as a new list per instance.
    output_dir : str
        Directory for the analysis outputs.
    device : torch.device, optional
        Device to use; defaults to 'cuda:0' when CUDA is available, else 'cpu'.
    """

    def __init__(
        self,
        pretrained: bool = False,
        energy_range: str = '1-1000GeV',
        selected_models: Optional[List[str]] = None,
        output_dir: str = './results/for_paper/showerflow_analysis/appendix',
        device: Optional[torch.device] = None,
    ):
        self.pretrained = pretrained
        self.energy_range = energy_range
        # Copy so that instances never share (or mutate) a common default list.
        if selected_models is None:
            self.selected_models = ['100k', '10k', '1k', '100']
        else:
            self.selected_models = list(selected_models)
        # 8 evenly spaced layer indices between 0 and 41, for a 2x4 grid
        self.display_layers = np.linspace(0, 41, 8).astype(int)
        self.output_dir = output_dir
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.device = device

# Plot configuration
PLOT_DISPLAY_LABELS = {
    '100': r'$D = 10^{2}$', '1,000': r'$D = 10^{3}$', 
    '10,000': r'$D = 10^{4}$', '100,000': r'$D = 10^{5}$'
}

dataset_colors = {
    '100': '#3C493F', '1,000': '#6320EE', 
    '10,000': '#5C9EAD', '100,000': '#EF767A'
}

DISPLAY_NAME_TO_SIMPLE = {
    'D = 1 x 10^2': '100', 'D = 1 x 10^3': '1,000',
    'D = 1 x 10^4': '10,000', 'D = 1 x 10^5': '100,000'
}

config = AnalysisConfig()
os.makedirs(config.output_dir, exist_ok=True)

def extract_sf_paths_from_entries(config: AnalysisConfig) -> Tuple[Dict[str, str], Dict[str, int]]:
    """Collect ShowerFlow checkpoint paths and best epochs for the selected models.

    Reads ``SF_ENTRIES`` from ``paths_trainings_cleaned`` for the training type
    ('finetune' when ``config.pretrained`` else 'vanilla') and energy range of
    ``config``, and keeps the entries whose dataset key belongs to one of
    ``config.selected_models``.

    Returns
    -------
    (model_paths, model_epochs)
        Both keyed by display name: checkpoint path and best epoch (as int).

    Raises
    ------
    ValueError
        If the derived key is not present in ``SF_ENTRIES``.
    """
    try:
        from utils.paths_trainings_cleaned import SF_ENTRIES, BASE_PATH_SF, DATASET_DIR_MAP, DISPLAY_NAME_MAP
    except ImportError:
        # Fallback for kernels where `utils` is not importable as a package.
        import sys
        sys.path.append('./utils')
        from paths_trainings_cleaned import SF_ENTRIES, BASE_PATH_SF, DATASET_DIR_MAP, DISPLAY_NAME_MAP

    training_type = 'finetune' if config.pretrained else 'vanilla'
    if '1-1000' in config.energy_range:
        sf_key = f'{training_type}_1-1000'
    else:
        sf_key = f'{training_type}_10-90'

    if sf_key not in SF_ENTRIES:
        raise ValueError(f"Key {sf_key} not found in SF_ENTRIES")

    # Simplified model name -> dataset key as used in SF_ENTRIES.
    # NOTE(review): '100k' is pinned to the 1-1000 range while the other sizes
    # follow config.energy_range - confirm this is intentional.
    range_suffix = config.energy_range.replace("GeV", "").lower()
    simplified_to_ds = {'100k': '100k_1-1000'}
    for size_name in ('10k', '1k', '100'):
        simplified_to_ds[size_name] = f'{size_name}_{range_suffix}'

    # Dataset keys belonging to the selected models (unknown names are ignored).
    wanted_keys = {
        simplified_to_ds[model_name]
        for model_name in config.selected_models
        if model_name in simplified_to_ds
    }

    model_paths = {}
    model_epochs = {}
    for ds_key, date_folder, best_epoch in SF_ENTRIES[sf_key]:
        if ds_key not in wanted_keys:
            continue
        display_name = DISPLAY_NAME_MAP.get(ds_key, ds_key)
        dir_name = DATASET_DIR_MAP.get(ds_key, ds_key)
        model_paths[display_name] = (
            f"{BASE_PATH_SF}/Shower_flow_weights/{dir_name}/{training_type}"
            f"/ShowerFlow_{date_folder}/ShowerFlow_{best_epoch}.pth"
        )
        model_epochs[display_name] = int(best_epoch)

    return model_paths, model_epochs

def load_original_data(energy_range: str = '1-1000GeV', dim: str = '10k'):
    """Load the original CaloChallenge data for one energy range and dataset size.

    Looks up the file locations in ``sf_eval_paths`` (from
    ``utils.paths_configs``) and returns ``(energy, clusters_per_layer)``: the
    second item read from the HDF5 data file and the array loaded from the
    clusters-per-layer file.

    Raises
    ------
    ValueError
        If the energy range / size combination is not listed in ``sf_eval_paths``.
    """
    from utils.paths_configs import sf_eval_paths as dataset_paths

    # Anything that does not mention '1-1000' is treated as the 10-90 GeV set.
    eval_ds = '10-90GeV'
    if '1-1000' in energy_range:
        eval_ds = '1-1000GeV'

    is_known = eval_ds in dataset_paths and dim in dataset_paths[eval_ds]
    if not is_known:
        raise ValueError(f"Invalid dataset: {eval_ds}, {dim}")

    location = dataset_paths[eval_ds][dim]
    # Only the energies are needed from the HDF5 file.
    _, energy, _ = read_hdf5_file2(location['data_path'])
    clusters_per_layer = np.load(location['clusters_per_layer_path'])
    return energy, clusters_per_layer

def plot_cluster_distributions(original_clusters: pd.Series, samples_dict: Dict[str, np.ndarray], 
                              display_layers: list, norm: float = 1.0, save_path: Optional[str] = None):
    """Plot per-layer point-count distributions in a 2x4 grid.

    For every layer in ``display_layers`` the reference distribution is drawn
    as a filled histogram labelled ``GEANT4`` and each model in ``samples_dict``
    is overlaid as a stairs outline on the same bin edges.

    Args:
        original_clusters: Series whose elements are indexable by layer
            (``element[layer]`` is the reference value for that layer).
        samples_dict: Maps a model display name to a 2-D array; column
            ``2 + layer`` is multiplied by ``norm`` and rounded before use.
        display_layers: Layer indices to plot; at most 8 (one per panel).
        norm: Scale factor applied to the generated sample columns.
        save_path: If given, the figure is written there before being shown.

    Returns:
        Dict mapping each model name to the mean, over the displayed layers,
        of the Wasserstein distance between reference and generated values.

    Raises:
        ValueError: if more layers are requested than the grid has panels.

    Reads the notebook-level ``config``, ``dataset_colors``,
    ``DISPLAY_NAME_TO_SIMPLE`` and ``PLOT_DISPLAY_LABELS``.
    """
    # Imported once per call instead of once per (layer, model) iteration.
    from scipy.stats import wasserstein_distance

    n_rows, n_cols = 2, 4
    if len(display_layers) > n_rows * n_cols:
        raise ValueError(
            f"display_layers has {len(display_layers)} entries but the grid "
            f"only has {n_rows * n_cols} panels"
        )

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(24, 12))
    axes = axes.flatten()

    training_strategy = 'Fine-tuned' if config.pretrained else 'From scratch'
    wasserstein_distances = {model: [] for model in samples_dict}

    # Adaptive binning: aim for ~18 units per bin, clamped to 15-50 bins.
    target_width = 18
    min_bins, max_bins = 15, 50

    for idx, layer in enumerate(display_layers):
        ax = axes[idx]
        original_layer_data = original_clusters.apply(lambda x: x[layer])

        # Scale and round each model's column for this layer once; the result
        # is reused for the bin count, the histogram and the distance.
        layer_samples = {
            model_name: np.round(samples[:, 2 + layer] * norm)
            for model_name, samples in samples_dict.items()
        }

        # The bin count is driven by the largest value in reference or samples.
        max_val = original_layer_data.max()
        for layer_sample_data in layer_samples.values():
            max_val = max(max_val, layer_sample_data.max())
        n_bins = max(min_bins, min(max_bins, int(max_val / target_width)))

        # NOTE(review): the bin edges come from the GEANT4 histogram, so
        # generated values outside the GEANT4 range are not counted in the
        # model curves (they still enter the Wasserstein distance). Confirm
        # this is intended before changing the binning.
        _, bin_edges, _ = ax.hist(original_layer_data, bins=n_bins, color='lightgrey',
                                  label='GEANT4', alpha=0.7)

        for model_name, layer_sample_data in layer_samples.items():
            simple_key = DISPLAY_NAME_TO_SIMPLE.get(model_name, model_name)
            color = dataset_colors.get(simple_key, 'black')
            plot_label = PLOT_DISPLAY_LABELS.get(simple_key, model_name)

            counts, _ = np.histogram(layer_sample_data, bins=bin_edges)

            # Prepend an empty bin starting at x=0 so the outline starts at (0, 0).
            stair_edges = np.concatenate([[0], bin_edges])
            stair_counts = np.concatenate([[0], counts])
            ax.stairs(stair_counts, stair_edges, linewidth=3, color=color, label=plot_label)

            wd = wasserstein_distance(original_layer_data, layer_sample_data)
            wasserstein_distances[model_name].append(wd)

        ax.set_xlabel(f'Points in Layer {layer + 1}', fontsize=24)
        ax.set_yscale('log')
        ax.set_xlim(0, 670)
        ax.set_ylim(1e1, 1e4)
        ax.tick_params(axis='x', labelsize=24)
        ax.tick_params(axis='y', labelsize=24)

        if idx == 0:
            # Title and legend appear only on the first panel and are set once.
            ax.set_title(f'{training_strategy}', loc='right', fontsize=30, pad=20, weight='bold')
            ax.legend(loc='upper right', fontsize=24, frameon=False)

    # Hide panels that received no layer so the figure has no empty frames.
    for unused_ax in axes[len(display_layers):]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to {save_path}")
    plt.show()

    return {model: np.mean(distances) for model, distances in wasserstein_distances.items()}


# Extract paths and load data
model_paths, model_epochs = extract_sf_paths_from_entries(config)
energy, original_clusters_per_layer = load_original_data(energy_range=config.energy_range)

print(f"Using {'finetuned' if config.pretrained else 'vanilla'} models for {config.energy_range}")
print(f"Loaded data: energy shape {energy.shape}, clusters shape {original_clusters_per_layer.shape}")

# Load models and generate samples
_, distributions, model_name_to_index = eval_utils.load_models_and_distributions(
    model_paths, config.device, num_blocks=2, num_inputs=92
)

cfg = Configs()
samples_dict = eval_utils.generate_samples(
    distributions, model_name_to_index, cfg, energy, config.device,
    min_energy=energy.min(), max_energy=energy.max()
)

# Generate plots
# One Series element per row of the loaded clusters array, as expected by
# plot_cluster_distributions (it indexes each element by layer).
clusters_series = pd.Series(list(original_clusters_per_layer))
training_type = 'Fine-tuned' if config.pretrained else 'Vanilla'

# Create the output directory up front so saving the figure does not fail
# when the directory does not exist yet. An empty output_dir means the
# current directory, which needs no creation.
if config.output_dir:
    os.makedirs(config.output_dir, exist_ok=True)
save_path = os.path.join(config.output_dir, 
                        f"cluster_distributions_{training_type.lower()}_{config.energy_range.replace('GeV', '')}.pdf")

mean_distances = plot_cluster_distributions(
    original_clusters=clusters_series, samples_dict=samples_dict,
    display_layers=config.display_layers, norm=cfg.sf_norm_points, save_path=save_path
)

# Print summary
print(f"\n{'='*60}\nANALYSIS SUMMARY\n{'='*60}")
print(f"Training Type: {training_type}")
print(f"Energy Range: {config.energy_range}")
print(f"Models Analyzed: {list(samples_dict.keys())}")
print("Mean Wasserstein Distances:")
for model, distance in mean_distances.items():
    # Map the model display name to its plot label, falling back to the name itself.
    display_label = PLOT_DISPLAY_LABELS.get(DISPLAY_NAME_TO_SIMPLE.get(model, model), model)
    print(f"  {display_label}: {distance:.4f}")