In [None]:
import h5py
import numpy as np
from tqdm import tqdm

def create_progressive_heatmaps(input_filepath, output_filepath):
    """
    Create progressive average heatmaps from noisy realizations.

    For each config with N realizations, creates N progressive averages:
    - heatmap_1 = image_1
    - heatmap_2 = (image_1 + image_2) / 2
    - heatmap_3 = (image_1 + image_2 + image_3) / 3
    - ...
    - heatmap_N = (image_1 + ... + image_N) / N

    N and the number of configs are read from the ``n_realizations_per_config``
    and ``n_configs`` attributes of the input file's ``metadata`` group.

    Input layout read by this function:
    - ``metadata`` (attributes) and, optionally, an ``axes`` group of datasets
    - ``parameter_configs/config_XXXXXX`` (attributes)
    - ``simulated_data/config_XXXXXX/NNNNNN`` (zero-based realization datasets)

    Output layout written by this function:
    - ``metadata``, ``axes`` and ``parameter_configs`` copied from the input
    - ``progressive_heatmaps/config_XXXXXX/heatmap_NNNNNN`` (one-based,
      gzip-compressed progressive averages)

    Realizations are read one at a time and summed in a float64 (or wider)
    accumulator, so integer inputs cannot wrap around and low-precision float
    inputs do not accumulate rounding error. Floating-point and complex inputs
    keep their dtype in the output; every other dtype is saved as float64.

    Parameters
    ----------
    input_filepath : str
        Path to original HDF5 file with noisy realizations
    output_filepath : str
        Path to save progressive heatmaps HDF5 file

    Raises
    ------
    ValueError
        If a realization's shape differs from the first realization of the
        same config.
    """

    with h5py.File(input_filepath, 'r') as f_in, \
         h5py.File(output_filepath, 'w') as f_out:

        # Copy metadata
        metadata_out = f_out.create_group('metadata')
        for key, val in f_in['metadata'].attrs.items():
            metadata_out.attrs[key] = val

        # Copy axes if they exist
        if 'axes' in f_in:
            axes_out = f_out.create_group('axes')
            for axis_name in f_in['axes'].keys():
                axes_out.create_dataset(axis_name, data=f_in['axes'][axis_name][:])

        # Convert the stored counts to plain Python ints for range() and printing
        n_configs = int(f_in['metadata'].attrs['n_configs'])
        n_realizations = int(f_in['metadata'].attrs['n_realizations_per_config'])

        # Create groups for output
        params_out = f_out.create_group('parameter_configs')
        heatmaps_out = f_out.create_group('progressive_heatmaps')

        print(f"Processing {n_configs} configs with {n_realizations} realizations each...")
        print(f"Total progressive heatmaps: {n_configs * n_realizations}")

        # Process each configuration
        for config_idx in tqdm(range(n_configs), desc="Configs"):
            config_name = f'config_{config_idx:06d}'

            # Copy parameter configuration
            config_in = f_in['parameter_configs'][config_name]
            config_out = params_out.create_group(config_name)
            for key, val in config_in.attrs.items():
                config_out.attrs[key] = val

            data_in = f_in['simulated_data'][config_name]
            heatmap_group = heatmaps_out.create_group(config_name)

            # Stream the realizations: only the running sum and the current
            # image are held in memory, never the full stack.
            running_sum = None
            out_dtype = None
            for n in range(1, n_realizations + 1):
                # Realization keys are zero-based, heatmap keys are one-based
                realization = data_in[f'{n - 1:06d}'][:]

                if running_sum is None:
                    # Accumulate in at least float64 so that integer data cannot
                    # overflow and float32 data does not lose precision.
                    acc_dtype = np.result_type(realization.dtype, np.float64)
                    running_sum = np.zeros(realization.shape, dtype=acc_dtype)
                    # Keep the dtype the plain `sum / n` would have produced
                    if np.issubdtype(realization.dtype, np.inexact):
                        out_dtype = realization.dtype
                    else:
                        out_dtype = np.float64
                elif realization.shape != running_sum.shape:
                    # Without this check a broadcastable shape would be summed silently
                    raise ValueError(
                        f"{config_name}: realization {n - 1:06d} has shape "
                        f"{realization.shape}, expected {running_sum.shape}"
                    )

                running_sum += realization
                progressive_avg = (running_sum / n).astype(out_dtype, copy=False)

                # Save progressive heatmap
                heatmap_group.create_dataset(
                    f'heatmap_{n:06d}',
                    data=progressive_avg,
                    compression='gzip',
                    compression_opts=4
                )

        print(f"\n✓ Saved progressive heatmaps to: {output_filepath}")
        print(f"  Total configs: {n_configs}")
        print(f"  Heatmaps per config: {n_realizations}")
        print(f"  Total heatmaps: {n_configs * n_realizations}")

# Usage: read the raw realizations file and write the progressive heatmaps file
create_progressive_heatmaps(
    './simulated_data/ml_training_data.h5',
    './simulated_data/ml_training_heatmaps.h5',
)

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np

def plot_config_heatmaps(filepath, config_idx=0, cmap='viridis'):
    """
    Plot first, middle, and last progressive heatmaps for a specific config.

    The three panels share one color range (the min/max over all three
    heatmaps) so they can be compared directly. After the figure is shown,
    the range, mean and standard deviation of each heatmap are printed.

    Parameters
    ----------
    filepath : str
        Path to heatmap HDF5 file
    config_idx : int
        Configuration index to visualize
    cmap : str
        Colormap for heatmaps
    """

    config_name = f'config_{config_idx:06d}'

    # Read everything needed up front so the file is closed before plotting
    with h5py.File(filepath, 'r') as f:
        params = f['parameter_configs'][config_name]
        heatmap_group = f['progressive_heatmaps'][config_name]

        n_heatmaps = len(heatmap_group.keys())
        # Heatmap keys are one-based; with a single heatmap `n // 2` would be 0
        # and point at a non-existent 'heatmap_000000', so clamp to 1.
        mid_idx = max(1, n_heatmaps // 2)

        # Load first, middle, last heatmaps
        heatmap_first = heatmap_group['heatmap_000001'][:]
        heatmap_mid = heatmap_group[f'heatmap_{mid_idx:06d}'][:]
        heatmap_last = heatmap_group[f'heatmap_{n_heatmaps:06d}'][:]

        # Format parameter values for the title: floats get 3 decimals,
        # bytes are decoded, everything else uses its default string form.
        param_strs = []
        for k, v in params.attrs.items():
            if isinstance(v, (float, np.floating)):
                param_strs.append(f'{k}={v:.3f}')
            elif isinstance(v, bytes):
                param_strs.append(f"{k}={v.decode('utf-8')}")
            else:
                param_strs.append(f'{k}={v}')
        param_str = ', '.join(param_strs)

    # (number of averaged images, heatmap, panel title) for each panel
    panels = [
        (1, heatmap_first, 'N = 1 image\n(Single noisy realization)'),
        (mid_idx, heatmap_mid, f'N = {mid_idx} images\n(Progressive average)'),
        (n_heatmaps, heatmap_last, f'N = {n_heatmaps} images\n(Full average)'),
    ]

    # Determine common colorbar range
    vmin = min(heatmap.min() for _, heatmap, _ in panels)
    vmax = max(heatmap.max() for _, heatmap, _ in panels)

    # Plot (18 x 5 inches, the same size plot_bin_heatmaps uses)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (_, heatmap, title) in zip(axes, panels):
        im = ax.imshow(heatmap, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Energy Index')
        ax.set_ylabel('Time Index')
        plt.colorbar(im, ax=ax)

    fig.suptitle(f'Config {config_idx}: {param_str}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"\nConfig {config_idx} Statistics:")
    print(f"Parameters: {param_str}")
    for count, heatmap, _ in panels:
        print(f"\nHeatmap N={count}:")
        print(f"  Range: [{heatmap.min():.2e}, {heatmap.max():.2e}]")
        print(f"  Mean: {heatmap.mean():.2e}, Std: {heatmap.std():.2e}")

# Usage - visualize the first config (change config_idx to inspect others)
plot_config_heatmaps(
    filepath='./simulated_data/ml_training_heatmaps.h5',
    config_idx=0,
)

In [None]:
# Same visualization for config 10
plot_config_heatmaps(
    filepath='./simulated_data/ml_training_heatmaps.h5',
    config_idx=10,
)

In [None]:
import h5py
import numpy as np
from tqdm import tqdm
import os

def extract_heatmaps_subset(h5_filepath, output_dir='shot_averages', max_heatmaps=30, max_configs=None):
    """
    Extract first N progressive heatmaps from HDF5 and save in NPY format.

    Creates (inside ``output_dir``, which is created if missing):
    - combined_all_bin_averages.npy: dict with keys (bin, n_shots) -> heatmap
    - shots_per_bin.npy: dict with keys bin -> max shots

    Both files are dicts saved with ``np.save(..., allow_pickle=True)``, so
    they must be read back with ``np.load(path, allow_pickle=True).item()``.
    The config index is used as the bin number.

    Parameters
    ----------
    h5_filepath : str
        Path to the heatmaps HDF5 file
    output_dir : str
        Directory to save the NPY files
    max_heatmaps : int
        Maximum number of heatmaps to extract per config (default: 30).
        Capped at the file's ``n_realizations_per_config``; negative values
        are treated as 0.
    max_configs : int, optional
        Maximum number of configs to process (default: None = all configs).
        Capped at the file's ``n_configs``; 0 or a negative value processes
        no configs.

    Returns
    -------
    tuple of (dict, dict)
        ``(combined_all_bin_averages, shots_per_bin)`` exactly as saved.
    """

    os.makedirs(output_dir, exist_ok=True)

    # Initialize dictionaries
    combined_all_bin_averages = {}
    shots_per_bin = {}

    with h5py.File(h5_filepath, 'r') as f:
        # Plain ints keep NumPy scalar types out of the pickled dictionaries
        n_configs = int(f['metadata'].attrs['n_configs'])
        n_realizations_total = int(f['metadata'].attrs['n_realizations_per_config'])

        # Use subset of heatmaps
        n_heatmaps = max(0, min(max_heatmaps, n_realizations_total))

        # Use subset of configs. Only None means "all": a plain truthiness test
        # would make max_configs=0 process every config.
        if max_configs is None:
            n_configs_to_process = n_configs
        else:
            n_configs_to_process = max(0, min(max_configs, n_configs))

        print(f"Extracting first {n_heatmaps} of {n_realizations_total} heatmaps")
        print(f"Processing {n_configs_to_process} of {n_configs} configs...")
        print(f"Total heatmaps to extract: {n_configs_to_process * n_heatmaps}")

        # Process each configuration (bin)
        for config_idx in tqdm(range(n_configs_to_process), desc="Extracting heatmaps"):
            config_name = f'config_{config_idx:06d}'
            heatmap_group = f['progressive_heatmaps'][config_name]

            # Use config_idx as the bin number
            bin_num = config_idx

            # Record max shots for this bin
            shots_per_bin[bin_num] = n_heatmaps

            # The stored heatmaps are already progressive averages, so they are
            # copied as-is under the key (bin, n_shots).
            for n in range(1, n_heatmaps + 1):
                combined_all_bin_averages[(bin_num, n)] = heatmap_group[f'heatmap_{n:06d}'][:]

    # Save to NPY files
    combined_path = os.path.join(output_dir, 'combined_all_bin_averages.npy')
    shots_path = os.path.join(output_dir, 'shots_per_bin.npy')

    np.save(combined_path, combined_all_bin_averages, allow_pickle=True)
    np.save(shots_path, shots_per_bin, allow_pickle=True)

    print(f"\n✅ Saved files:")
    print(f"   {combined_path}")
    print(f"   {shots_path}")
    print(f"\nSummary:")
    print(f"   Total bins (configs): {len(shots_per_bin)}")
    print(f"   Shots per bin: {n_heatmaps}")
    print(f"   Total heatmaps: {len(combined_all_bin_averages)}")
    if combined_all_bin_averages:
        first_heatmap = next(iter(combined_all_bin_averages.values()))
        print(f"   Heatmap shape: {first_heatmap.shape}")
    else:
        # Nothing was extracted, so there is no heatmap to take a shape from
        print("   Heatmap shape: n/a (no heatmaps extracted)")

    return combined_all_bin_averages, shots_per_bin

# Usage - extract the first 30 heatmaps from up to 70 configs
combined_dict, shots_dict = extract_heatmaps_subset(
    './simulated_data/ml_training_heatmaps.h5',
    output_dir='./simulated_data/shot_averages',
    max_heatmaps=30,
    max_configs=70,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load the combined dictionary of (bin, n_shots) -> heatmap
combined_averages = np.load(
    './simulated_data/shot_averages/combined_all_bin_averages.npy',
    allow_pickle=True,
).item()

# Load shots_per_bin to know the maximum shots per bin
shots_per_bin = np.load(
    './simulated_data/shot_averages/shots_per_bin.npy',
    allow_pickle=True,
).item()

# The first stored entry defines the shape expected of every image
first_key = list(combined_averages)[0]
single_image_shape = combined_averages[first_key].shape

# Sorted keys give a stable order: by bin number first, then shot number
all_keys = sorted(combined_averages)
total_entries = len(all_keys)

# Row position -> (bin, n_shots), so each row can be traced back to its key
index_map = dict(enumerate(all_keys))

# Pack every heatmap into a single float32 array, one image per row
all_data = np.zeros((total_entries,) + tuple(single_image_shape), dtype=np.float32)
for i, key in index_map.items():
    all_data[i] = combined_averages[key]

print(f"Created array with shape {all_data.shape}")

# Helper: look up the array index for a given (bin, n_shots) pair
def find_index(bin_num, n_shots, index_map):
    """
    Return the first index in ``index_map`` whose value equals (bin_num, n_shots).

    Parameters
    ----------
    bin_num : int
        Bin number to look for
    n_shots : int
        Shot count to look for
    index_map : dict
        Dictionary with keys index -> (bin, n_shots)

    Returns
    -------
    The matching key of ``index_map``, or None when no entry matches.
    """
    matches = (
        idx
        for idx, (b, n) in index_map.items()
        if b == bin_num and n == n_shots
    )
    return next(matches, None)

# Example usage: show the heatmap stored for one (bin, n_shots) pair
bin_num, n_shots = 5, 1
idx = find_index(bin_num, n_shots, index_map)

# find_index returns None when the pair is absent; nothing is drawn then
if idx is not None:
    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(all_data[idx], aspect='auto', origin='lower', cmap='viridis')
    fig.colorbar(im, ax=ax, label='Intensity')
    ax.set_title(f'Bin {bin_num}, {n_shots} shots (index {idx})')
    fig.tight_layout()
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load the data: each .npy file holds one pickled dict, which .item() unwraps
combined = np.load(
    './simulated_data/shot_averages/combined_all_bin_averages.npy',
    allow_pickle=True,
).item()
shots_per_bin = np.load(
    './simulated_data/shot_averages/shots_per_bin.npy',
    allow_pickle=True,
).item()

def plot_bin_heatmaps(combined_dict, shots_per_bin_dict, bin_num=20, cmap='viridis'):
    """
    Plot first, middle, and last progressive heatmaps for a specific bin.

    The three panels share one color range (the min/max over all three
    heatmaps). After the figure is shown, the range, mean and standard
    deviation of each heatmap are printed. If ``bin_num`` is not a key of
    ``shots_per_bin_dict`` a message is printed and nothing is plotted.

    Parameters
    ----------
    combined_dict : dict
        Dictionary with keys (bin, n_shots) -> heatmap
    shots_per_bin_dict : dict
        Dictionary with keys bin -> max shots
    bin_num : int
        Bin number to visualize
    cmap : str
        Colormap for heatmaps
    """

    # Check if bin exists
    if bin_num not in shots_per_bin_dict:
        print(f"❌ Bin {bin_num} not found!")
        return

    max_shots = shots_per_bin_dict[bin_num]
    # Shot counts are looked up from 1 upwards; for a bin with a single shot
    # `max_shots // 2` would be 0 and the (bin, 0) lookup would fail, so clamp to 1.
    mid_shots = max(1, max_shots // 2)

    # (shot count, heatmap, panel title) for the three panels
    panels = [
        (1, combined_dict[(bin_num, 1)],
         f'Bin {bin_num}, N = 1 shot\n(Single noisy realization)'),
        (mid_shots, combined_dict[(bin_num, mid_shots)],
         f'Bin {bin_num}, N = {mid_shots} shots\n(Progressive average)'),
        (max_shots, combined_dict[(bin_num, max_shots)],
         f'Bin {bin_num}, N = {max_shots} shots\n(Full average)'),
    ]

    # Determine common colorbar range
    vmin = min(heatmap.min() for _, heatmap, _ in panels)
    vmax = max(heatmap.max() for _, heatmap, _ in panels)

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (_, heatmap, title) in zip(axes, panels):
        im = ax.imshow(heatmap, aspect='auto', origin='lower',
                       cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Energy Index')
        ax.set_ylabel('Time Index')
        plt.colorbar(im, ax=ax, label='Intensity')

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"\nBin {bin_num} Statistics:")
    print(f"Total shots available: {max_shots}")
    for count, heatmap, _ in panels:
        print(f"\nHeatmap N={count}:")
        print(f"  Range: [{heatmap.min():.2e}, {heatmap.max():.2e}]")
        print(f"  Mean: {heatmap.mean():.2e}, Std: {heatmap.std():.2e}")

# Usage - visualize bin 20
plot_bin_heatmaps(
    combined_dict=combined,
    shots_per_bin_dict=shots_per_bin,
    bin_num=20,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load the (bin, n_shots) -> heatmap dictionary
data_path = "./simulated_data/shot_averages/combined_all_bin_averages.npy"
print(f"Loading data from {data_path}")
combined_dict = np.load(data_path, allow_pickle=True).item()

# Group the entries by bin: bin -> list of (shot_num, image)
bin_dict = {}
for (bin_num, shot_num), img in combined_dict.items():
    bin_dict.setdefault(bin_num, []).append((shot_num, img))

# For every bin, in ascending bin order, keep the image with the highest shot number
last_images, bin_labels = [], []
for bin_num in sorted(bin_dict):
    shots = sorted(bin_dict[bin_num], key=lambda entry: entry[0])
    last_shot_num, last_img = shots[-1]
    last_images.append(last_img)
    bin_labels.append(f"Bin {bin_num}\nShot {last_shot_num}")

print(f"Found {len(last_images)} bins")

# Plot: 7 images per row, with as many rows as the number of bins requires
n_cols = 7
# Ceiling division so no bin is silently dropped; at least one row so that
# plt.subplots always receives a valid grid.
n_rows = max(1, -(-len(last_images) // n_cols))
# 2.8 inches of height per row gives the 20 x 28 inch figure for 10 rows.
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 2.8 * n_rows), squeeze=False)
axes = axes.flatten()

for idx, ax in enumerate(axes):
    if idx < len(last_images):
        # Plot image
        im = ax.imshow(last_images[idx], cmap='viridis', aspect='auto')
        ax.set_title(bin_labels[idx], fontsize=9)
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    else:
        # Unused cell at the end of the last row
        ax.axis('off')

plt.tight_layout()
plt.show()