In [None]:
# Speed Improvement: Vectorizing Per-Channel Transforms
# This notebook demonstrates how to optimize your dataset's __getitem__ method

import torch
import numpy as np
import time

## Current Implementation (Slow)

Your current code in `ters_image_to_image_sh.py` applies transforms **per-channel** in a Python loop:

In [None]:
# ‚ùå SLOW: Current implementation (per-channel loop)
def slow_normalize_per_channel(filtered_spectrums):
    """
    This is what your current code does - applies transform to each channel separately
    """
    # Create list of tensors per channel
    selected_images = [
        torch.from_numpy(filtered_spectrums[:, :, i]).float() 
        for i in range(filtered_spectrums.shape[2])
    ]
    
    # Apply normalize transform to each channel in a loop
    normalized = []
    for image in selected_images:
        x_mean = image.mean()
        x_std = image.std()
        if x_std == 0:
            normalized.append(image - x_mean)
        else:
            normalized.append((image - x_mean) / x_std)
    
    # Apply MinimumToZero to each
    result = [img - torch.min(img) for img in normalized]
    
    return torch.stack(result, dim=0)

## Optimized Implementation (Fast)

Vectorize by operating on the **entire tensor at once** using PyTorch's broadcasting:

In [None]:
# ‚úÖ FAST: Vectorized implementation
def fast_normalize_vectorized(filtered_spectrums):
    """
    Vectorized version - operates on all channels at once using broadcasting
    """
    # Convert to tensor in one operation: (H, W, C) -> (C, H, W)
    images = torch.from_numpy(filtered_spectrums).float().permute(2, 0, 1)
    
    # Compute mean and std per channel: shape (C, 1, 1) for broadcasting
    mean = images.mean(dim=(1, 2), keepdim=True)  # Per-channel mean
    std = images.std(dim=(1, 2), keepdim=True)    # Per-channel std
    
    # Normalize all channels at once (handle zero std with where)
    std = torch.where(std == 0, torch.ones_like(std), std)  # Avoid division by zero
    normalized = (images - mean) / std
    
    # MinimumToZero: per-channel minimum
    channel_min = normalized.amin(dim=(1, 2), keepdim=True)  # Per-channel min
    result = normalized - channel_min
    
    return result

## Benchmark: Compare Speed

In [None]:
# Create dummy data similar to your filtered_spectrums
# Shape: (64, 64, num_channels) - matching your dataset
num_channels = 400  # Your typical channel count
filtered_spectrums = np.random.randn(64, 64, num_channels).astype(np.float32)

# Benchmark slow version
n_iterations = 100
start = time.time()
for _ in range(n_iterations):
    result_slow = slow_normalize_per_channel(filtered_spectrums)
slow_time = time.time() - start

# Benchmark fast version
start = time.time()
for _ in range(n_iterations):
    result_fast = fast_normalize_vectorized(filtered_spectrums)
fast_time = time.time() - start

print(f"Slow (per-channel loop): {slow_time:.4f}s for {n_iterations} iterations")
print(f"Fast (vectorized):       {fast_time:.4f}s for {n_iterations} iterations")
print(f"Speedup: {slow_time/fast_time:.1f}x faster")
print(f"\nResults match: {torch.allclose(result_slow, result_fast, atol=1e-5)}")

## How to Apply This to Your Dataset

Replace this section in `src/datasets/ters_image_to_image_sh.py` `__getitem__` method:

In [None]:
# ‚ùå BEFORE (in your __getitem__):
"""
selected_images = [torch.from_numpy(filtered_spectrums[:,:, i]).float() 
                   for i in range(filtered_spectrums.shape[2])]

if self.t_image:
    selected_images = [self.t_image(image) for image in selected_images]

selected_images = torch.stack(selected_images, dim=0)
"""

# ‚úÖ AFTER (replace with this):
"""
# Convert all at once: (H, W, C) -> (C, H, W)
selected_images = torch.from_numpy(filtered_spectrums).float().permute(2, 0, 1)

# Apply vectorized normalization
if self.t_image:
    # Per-channel normalization (vectorized)
    mean = selected_images.mean(dim=(1, 2), keepdim=True)
    std = selected_images.std(dim=(1, 2), keepdim=True)
    std = torch.where(std == 0, torch.ones_like(std), std)
    selected_images = (selected_images - mean) / std
    
    # Per-channel MinimumToZero
    channel_min = selected_images.amin(dim=(1, 2), keepdim=True)
    selected_images = selected_images - channel_min
"""
print("See the code above for before/after comparison")

---

# üî¥ High Priority: Pre-compute and Cache Expensive Operations

The two biggest CPU bottlenecks in your `__getitem__` are:
1. **`uniform_channels()`** - Bins spectrums into frequency channels (called every sample load)
2. **`molecule_circular_image()`** - Generates target masks from atom positions (called every sample load)

**Solution**: Pre-compute these once and save to disk. This trades disk space for massive CPU savings during training.

In [None]:
# Imports for pre-computation script
import os
import sys
import glob
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

# Add project root to path
sys.path.insert(0, '/home/sethih1/masque_new/ters_gen')

from src.utils.xyz_to_label import molecule_circular_image

## Step 1: Define Pre-computation Functions

These functions compute `uniform_channels` and `molecule_circular_image` for a single sample and save to disk.

In [None]:
# Dictionary to convert atomic number to atomic symbols
atomic_symbols = {
    1: "H", 2: "He", 3: "Li", 4: "Be", 5: "B",
    6: "C", 7: "N", 8: "O", 9: "F", 10: "Ne",
}

def uniform_channels(spectrums, frequencies, num_channels=400):
    """
    Bin spectrums into uniform frequency channels.
    This is the expensive function we want to pre-compute.
    """
    max_freq = 4000
    step = max_freq // num_channels
    grid_size = spectrums.shape[1]
    channels = np.zeros((grid_size, grid_size, num_channels), dtype=np.float32)

    count = 0
    for i in range(1, max_freq, step):
        indices = (frequencies > i) & (frequencies < i + step)
        selected_spectrums = spectrums[:, :, indices]
        if np.all(selected_spectrums == 0) or selected_spectrums.size == 0:
            count += 1
            continue
        channels[:, :, count] = np.mean(selected_spectrums, axis=2)
        count += 1

    return channels


def precompute_single_sample(npz_path, output_dir, num_channels=400, 
                              frequency_range=(0, 4000), sg_ch=True, circle_radius=5):
    """
    Pre-compute uniform_channels and molecule_circular_image for a single .npz file.
    Saves result to output_dir with same filename.
    """
    filename = os.path.splitext(os.path.basename(npz_path))[0]
    output_path = os.path.join(output_dir, f"{filename}.npz")
    
    # Skip if already computed
    if os.path.exists(output_path):
        return f"Skipped {filename} (already exists)"
    
    try:
        # Load original data
        with np.load(npz_path) as data:
            atom_pos = data['atom_pos']
            atomic_numbers = data['atomic_numbers']
            frequencies = data['frequencies']
            spectrums = data['spectrums']
        
        # Filter by frequency range
        mask = (frequencies >= frequency_range[0]) & (frequencies <= frequency_range[1])
        filtered_frequencies = frequencies[mask]
        filtered_spectrums = spectrums[:, :, mask]
        
        # 1. Pre-compute uniform_channels (EXPENSIVE!)
        channels = uniform_channels(filtered_spectrums, filtered_frequencies, num_channels=num_channels)
        
        # 2. Pre-compute molecule_circular_image (EXPENSIVE!)
        # Build xyz string for molecule_circular_image
        t = list(zip(atomic_numbers, atom_pos))
        text = f"{len(t)}\nComment\n"
        for atom, pos in t:
            pos_str = "\t".join(f"{coord:.6f}" for coord in pos)
            text += atomic_symbols[atom] + "\t" + pos_str + "\n"
        
        target_image = molecule_circular_image(text, flag=sg_ch, circle_radius=circle_radius)
        
        # Save pre-computed data
        np.savez_compressed(
            output_path,
            channels=channels.astype(np.float32),           # Pre-computed uniform channels
            target_image=target_image.astype(np.float32),   # Pre-computed target mask
            # Keep original data if needed for other purposes
            atom_pos=atom_pos,
            atomic_numbers=atomic_numbers,
        )
        
        return f"Processed {filename}"
    
    except Exception as e:
        return f"Error processing {filename}: {e}"

## Step 2: Batch Pre-computation with Parallel Processing

Run this once to pre-compute all samples. Uses multiprocessing for speed.

In [None]:
def precompute_dataset(input_dir, output_dir, num_channels=400, 
                       frequency_range=(0, 4000), sg_ch=True, circle_radius=5,
                       n_workers=8):
    """
    Pre-compute all samples in a directory using parallel processing.
    
    Args:
        input_dir: Directory with original .npz files
        output_dir: Directory to save pre-computed .npz files
        num_channels: Number of frequency channels (must match training config!)
        frequency_range: Tuple of (min_freq, max_freq)
        sg_ch: Single channel target (True) or multi-channel (False)
        circle_radius: Radius for circular masks
        n_workers: Number of parallel workers
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Get all npz files
    npz_files = sorted(glob.glob(os.path.join(input_dir, '*.npz')))
    print(f"Found {len(npz_files)} files to process")
    print(f"Output directory: {output_dir}")
    print(f"Settings: num_channels={num_channels}, sg_ch={sg_ch}, circle_radius={circle_radius}")
    
    # Process in parallel
    results = []
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {
            executor.submit(
                precompute_single_sample, 
                npz_path, 
                output_dir,
                num_channels,
                frequency_range,
                sg_ch,
                circle_radius
            ): npz_path 
            for npz_path in npz_files
        }
        
        for future in tqdm(as_completed(futures), total=len(npz_files), desc="Pre-computing"):
            result = future.result()
            results.append(result)
    
    # Summary
    processed = sum(1 for r in results if r.startswith("Processed"))
    skipped = sum(1 for r in results if r.startswith("Skipped"))
    errors = sum(1 for r in results if r.startswith("Error"))
    
    print(f"\n‚úÖ Done! Processed: {processed}, Skipped: {skipped}, Errors: {errors}")

## Step 3: Run Pre-computation

‚ö†Ô∏è **Run this ONCE before training!** Adjust paths to match your data directories.

In [None]:
# ============================================================================
# CONFIGURATION - Adjust these paths to match your setup!
# ============================================================================

# Original data directories (from your config)
TRAIN_INPUT_DIR = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_npz_1.0/train"
VAL_INPUT_DIR = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_npz_1.0/val"

# Output directories for pre-computed data (new location)
TRAIN_OUTPUT_DIR = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_npz_1.0_precomputed/train"
VAL_OUTPUT_DIR = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_npz_1.0_precomputed/val"

# Settings (MUST match your training config!)
NUM_CHANNELS = 400      # or 100, depending on your in_channels config
SG_CH = True            # Single channel target (config.model.out_channels == 1)
CIRCLE_RADIUS = 5       # From config.data.circle_radius
N_WORKERS = 16          # Adjust based on your CPU cores

In [None]:
# Pre-compute training set
print("=" * 60)
print("Pre-computing TRAINING set...")
print("=" * 60)
precompute_dataset(
    input_dir=TRAIN_INPUT_DIR,
    output_dir=TRAIN_OUTPUT_DIR,
    num_channels=NUM_CHANNELS,
    sg_ch=SG_CH,
    circle_radius=CIRCLE_RADIUS,
    n_workers=N_WORKERS
)

In [None]:
# Pre-compute validation set
print("=" * 60)
print("Pre-computing VALIDATION set...")
print("=" * 60)
precompute_dataset(
    input_dir=VAL_INPUT_DIR,
    output_dir=VAL_OUTPUT_DIR,
    num_channels=NUM_CHANNELS,
    sg_ch=SG_CH,
    circle_radius=CIRCLE_RADIUS,
    n_workers=N_WORKERS
)

## Step 4: New Fast Dataset Class

This dataset loads pre-computed data - no more expensive `uniform_channels()` or `molecule_circular_image()` calls!

In [None]:
import torch
from torch.utils.data import Dataset

class Ters_dataset_precomputed(Dataset):
    """
    Fast dataset that loads pre-computed channels and target images.
    
    Use this instead of Ters_dataset_filtered_skip after running pre-computation!
    """
    
    def __init__(self, precomputed_dir, t_image=None, train_aug=False):
        """
        Args:
            precomputed_dir: Directory with pre-computed .npz files
            t_image: Transform to apply to images (use NormalizeVectorized!)
            train_aug: Whether to apply augmentation
        """
        super().__init__()
        self.precomputed_dir = precomputed_dir
        self.t_image = t_image
        self.train_aug = train_aug
        
        # Get list of pre-computed files
        self.files = sorted(glob.glob(os.path.join(precomputed_dir, '*.npz')))
        self.length = len(self.files)
        
        # For augmentation (if needed)
        if train_aug:
            from src.transforms import AugmentTransform
            self.aug_image = AugmentTransform(gauss_std_range=(0.01, 0.1))
        
        print(f"Loaded {self.length} pre-computed samples from {precomputed_dir}")
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        npz_path = self.files[idx]
        
        # Load pre-computed data (FAST! Just disk read, no computation)
        with np.load(npz_path) as data:
            channels = data['channels']        # Already computed uniform_channels!
            target_image = data['target_image']  # Already computed molecule_circular_image!
        
        # Convert to tensors: (H, W, C) -> (C, H, W)
        selected_images = torch.from_numpy(channels).float().permute(2, 0, 1).contiguous()
        target_image = torch.from_numpy(target_image).float()
        
        # Apply transforms (vectorized!)
        if self.t_image:
            selected_images = self.t_image(selected_images)
        
        # Apply augmentation
        if self.train_aug:
            selected_images, target_image = self.aug_image(img=selected_images, mask=target_image)
        
        # Dummy frequencies (for compatibility)
        selected_frequencies = torch.zeros(1)
        
        return selected_images, selected_frequencies, target_image

## Step 5: Benchmark - Compare Loading Speed

Test the speedup between original and pre-computed datasets.

In [None]:
import time
import torchvision.transforms as transforms
from src.transforms import NormalizeVectorized, MinimumToZeroVectorized
from src.datasets.ters_image_to_image_sh import Ters_dataset_filtered_skip

# Test configuration
NUM_SAMPLES_TO_TEST = 50  # Number of samples to benchmark

# Create transforms
transform = transforms.Compose([NormalizeVectorized(), MinimumToZeroVectorized()])

# ============================================================================
# Benchmark ORIGINAL dataset (slow)
# ============================================================================
print("Loading ORIGINAL dataset...")
original_ds = Ters_dataset_filtered_skip(
    filename=TRAIN_INPUT_DIR,
    frequency_range=[0, 4000],
    num_channels=NUM_CHANNELS,
    sg_ch=SG_CH,
    circle_radius=CIRCLE_RADIUS,
    t_image=transform,
    train_aug=False
)

print(f"Benchmarking {NUM_SAMPLES_TO_TEST} samples from ORIGINAL dataset...")
start = time.time()
for i in range(min(NUM_SAMPLES_TO_TEST, len(original_ds))):
    _ = original_ds[i]
original_time = time.time() - start
print(f"Original dataset: {original_time:.2f}s for {NUM_SAMPLES_TO_TEST} samples")
print(f"  ‚Üí {original_time/NUM_SAMPLES_TO_TEST*1000:.1f} ms per sample")

In [None]:
# ============================================================================
# Benchmark PRE-COMPUTED dataset (fast)
# ============================================================================
print("\nLoading PRE-COMPUTED dataset...")
precomputed_ds = Ters_dataset_precomputed(
    precomputed_dir=TRAIN_OUTPUT_DIR,
    t_image=transform,
    train_aug=False
)

print(f"Benchmarking {NUM_SAMPLES_TO_TEST} samples from PRE-COMPUTED dataset...")
start = time.time()
for i in range(min(NUM_SAMPLES_TO_TEST, len(precomputed_ds))):
    _ = precomputed_ds[i]
precomputed_time = time.time() - start
print(f"Pre-computed dataset: {precomputed_time:.2f}s for {NUM_SAMPLES_TO_TEST} samples")
print(f"  ‚Üí {precomputed_time/NUM_SAMPLES_TO_TEST*1000:.1f} ms per sample")

# ============================================================================
# Summary
# ============================================================================
print("\n" + "=" * 60)
print("SPEEDUP SUMMARY")
print("=" * 60)
print(f"Original:     {original_time/NUM_SAMPLES_TO_TEST*1000:.1f} ms/sample")
print(f"Pre-computed: {precomputed_time/NUM_SAMPLES_TO_TEST*1000:.1f} ms/sample")
print(f"Speedup:      {original_time/precomputed_time:.1f}x faster! üöÄ")
print("=" * 60)

## Step 6: How to Use in hyperopt.py

After pre-computation, update your `hyperopt.py` to use the fast dataset:

In [None]:
# Example code to add to hyperopt.py:
"""
# ============================================================================
# IN hyperopt.py - Replace the dataset creation with:
# ============================================================================

# Option 1: Import the pre-computed dataset class
from src.datasets.ters_precomputed import Ters_dataset_precomputed

# Option 2: Or define it inline (copy the class definition)

# Then in your objective() function, replace:

# ‚ùå BEFORE (slow):
train_ds = Ters_dataset_filtered_skip(
    filename=config.data.train_path,
    frequency_range=[0, 4000],
    num_channels=model_params["in_channels"],
    std_deviation_multiplier=2,
    sg_ch=(config.model.out_channels == 1),
    circle_radius=config.data.circle_radius,
    t_image=transform,
    train_aug=augmentation
)

# ‚úÖ AFTER (fast):
train_ds = Ters_dataset_precomputed(
    precomputed_dir=config.data.train_path_precomputed,  # New config field!
    t_image=transform,
    train_aug=augmentation
)

# Also update your config YAML:
# data:
#   train_path_precomputed: /path/to/precomputed/train
#   val_path_precomputed: /path/to/precomputed/val
"""
print("See the code above for integration instructions")

---

## üìä Expected Performance Gains

| Operation | Original | Pre-computed | Speedup |
|-----------|----------|--------------|---------|
| `uniform_channels()` | ~50-100ms/sample | 0ms (cached) | ‚àû |
| `molecule_circular_image()` | ~10-30ms/sample | 0ms (cached) | ‚àû |
| Disk I/O | ~5ms/sample | ~5ms/sample | 1x |
| Transforms | Variable | Variable | Same |
| **Total per sample** | **~70-140ms** | **~10-20ms** | **5-10x** |

### Trade-offs:
- ‚úÖ **Pros**: Massive CPU reduction during training, consistent load times
- ‚ö†Ô∏è **Cons**: Extra disk space (~2x original), one-time pre-computation cost

### Disk Space Estimate:
- Each pre-computed sample: ~1-2 MB (float32 channels + target)
- For 10,000 samples: ~10-20 GB additional storage

---

# üü° Medium Priority: HDF5 Format for Faster I/O

**Why HDF5 instead of individual .npz files?**

| Aspect | Individual .npz | Single HDF5 |
|--------|-----------------|-------------|
| File open/close | 1 per sample | 1 for entire dataset |
| Filesystem overhead | High (many small files) | Low (single file) |
| Memory mapping | Limited | Excellent |
| Random access | Slow | Fast |
| Storage efficiency | Moderate | Better compression |

HDF5 is the standard for large scientific datasets!

In [None]:
# Install h5py if needed
# !pip install h5py

import h5py
print(f"h5py version: {h5py.__version__}")

## HDF5 Step 1: Create HDF5 Dataset from Raw .npz Files

This creates a single HDF5 file containing all pre-computed samples.

In [None]:
def create_hdf5_dataset(input_dir, output_hdf5_path, num_channels=400,
                        frequency_range=(0, 4000), sg_ch=True, circle_radius=5,
                        compression='gzip', compression_opts=4):
    """
    Convert a directory of .npz files to a single HDF5 file with pre-computed data.
    
    Args:
        input_dir: Directory with original .npz files
        output_hdf5_path: Path to output .h5 file
        num_channels: Number of frequency channels
        frequency_range: Tuple of (min_freq, max_freq)
        sg_ch: Single channel target (True) or multi-channel (False)
        circle_radius: Radius for circular masks
        compression: HDF5 compression type ('gzip', 'lzf', or None)
        compression_opts: Compression level (1-9 for gzip)
    """
    # Get all npz files
    npz_files = sorted(glob.glob(os.path.join(input_dir, '*.npz')))
    n_samples = len(npz_files)
    
    print(f"Found {n_samples} files to process")
    print(f"Output: {output_hdf5_path}")
    print(f"Settings: num_channels={num_channels}, sg_ch={sg_ch}, circle_radius={circle_radius}")
    
    # Determine shapes from first sample
    with np.load(npz_files[0]) as data:
        spectrums = data['spectrums']
        grid_size = spectrums.shape[0]  # Usually 64
    
    target_channels = 1 if sg_ch else 4  # H, C, N, O
    target_size = 256  # From molecule_circular_image
    
    print(f"Channel shape: ({grid_size}, {grid_size}, {num_channels})")
    print(f"Target shape: ({target_channels}, {target_size}, {target_size})")
    
    # Create HDF5 file with pre-allocated datasets
    with h5py.File(output_hdf5_path, 'w') as hf:
        # Create datasets with chunking for efficient access
        channels_ds = hf.create_dataset(
            'channels',
            shape=(n_samples, grid_size, grid_size, num_channels),
            dtype=np.float32,
            chunks=(1, grid_size, grid_size, num_channels),  # 1 sample per chunk
            compression=compression,
            compression_opts=compression_opts
        )
        
        targets_ds = hf.create_dataset(
            'targets',
            shape=(n_samples, target_channels, target_size, target_size),
            dtype=np.float32,
            chunks=(1, target_channels, target_size, target_size),
            compression=compression,
            compression_opts=compression_opts
        )
        
        # Store filenames for reference
        filenames = [os.path.basename(f) for f in npz_files]
        dt = h5py.special_dtype(vlen=str)
        hf.create_dataset('filenames', data=filenames, dtype=dt)
        
        # Store metadata
        hf.attrs['num_channels'] = num_channels
        hf.attrs['sg_ch'] = sg_ch
        hf.attrs['circle_radius'] = circle_radius
        hf.attrs['grid_size'] = grid_size
        hf.attrs['target_size'] = target_size
        
        # Process each sample
        errors = []
        for i, npz_path in enumerate(tqdm(npz_files, desc="Creating HDF5")):
            try:
                # Load original data
                with np.load(npz_path) as data:
                    atom_pos = data['atom_pos']
                    atomic_numbers = data['atomic_numbers']
                    frequencies = data['frequencies']
                    spectrums = data['spectrums']
                
                # Filter by frequency range
                mask = (frequencies >= frequency_range[0]) & (frequencies <= frequency_range[1])
                filtered_frequencies = frequencies[mask]
                filtered_spectrums = spectrums[:, :, mask]
                
                # 1. Compute uniform_channels
                channels = uniform_channels(filtered_spectrums, filtered_frequencies, num_channels=num_channels)
                
                # 2. Compute molecule_circular_image
                t = list(zip(atomic_numbers, atom_pos))
                text = f"{len(t)}\nComment\n"
                for atom, pos in t:
                    pos_str = "\t".join(f"{coord:.6f}" for coord in pos)
                    text += atomic_symbols[atom] + "\t" + pos_str + "\n"
                
                target_image = molecule_circular_image(text, flag=sg_ch, circle_radius=circle_radius)
                
                # Store in HDF5
                channels_ds[i] = channels.astype(np.float32)
                targets_ds[i] = target_image.astype(np.float32)
                
            except Exception as e:
                errors.append(f"{npz_path}: {e}")
                # Fill with zeros on error
                channels_ds[i] = np.zeros((grid_size, grid_size, num_channels), dtype=np.float32)
                targets_ds[i] = np.zeros((target_channels, target_size, target_size), dtype=np.float32)
    
    # Summary
    file_size_gb = os.path.getsize(output_hdf5_path) / (1024**3)
    print(f"\n‚úÖ Done! Created {output_hdf5_path}")
    print(f"   File size: {file_size_gb:.2f} GB")
    print(f"   Samples: {n_samples}")
    if errors:
        print(f"   Errors: {len(errors)}")
        for e in errors[:5]:
            print(f"      {e}")

## HDF5 Step 2: Run HDF5 Creation

‚ö†Ô∏è **Run this ONCE to create HDF5 files for train and val sets.**

In [None]:
# ============================================================================
# HDF5 CONFIGURATION - Adjust paths!
# ============================================================================

# Output HDF5 files (single file per split!)
TRAIN_HDF5_PATH = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_1.0_train.h5"
VAL_HDF5_PATH = "/scratch/phys/sin/sethih1/Extended_TERS_data/planar_oct_2025/planar_again/planar_1.0_val.h5"

# Use same settings as before
# NUM_CHANNELS, SG_CH, CIRCLE_RADIUS already defined above

In [None]:
# Create TRAINING HDF5
print("=" * 60)
print("Creating TRAINING HDF5...")
print("=" * 60)
create_hdf5_dataset(
    input_dir=TRAIN_INPUT_DIR,
    output_hdf5_path=TRAIN_HDF5_PATH,
    num_channels=NUM_CHANNELS,
    sg_ch=SG_CH,
    circle_radius=CIRCLE_RADIUS,
    compression='gzip',
    compression_opts=4  # Balance between speed and size
)

In [None]:
# Create VALIDATION HDF5
print("=" * 60)
print("Creating VALIDATION HDF5...")
print("=" * 60)
create_hdf5_dataset(
    input_dir=VAL_INPUT_DIR,
    output_hdf5_path=VAL_HDF5_PATH,
    num_channels=NUM_CHANNELS,
    sg_ch=SG_CH,
    circle_radius=CIRCLE_RADIUS,
    compression='gzip',
    compression_opts=4
)

## HDF5 Step 3: Fast HDF5 Dataset Class

This is the fastest option - single file, memory-mapped access, no file open/close overhead.

In [None]:
class Ters_dataset_hdf5(Dataset):
    """
    Ultra-fast dataset using HDF5 format.
    
    Benefits:
    - Single file = no filesystem overhead
    - Memory-mapped access = OS handles caching
    - Pre-computed = no CPU bottleneck
    - Chunked storage = efficient random access
    """
    
    def __init__(self, hdf5_path, t_image=None, train_aug=False):
        """
        Args:
            hdf5_path: Path to HDF5 file
            t_image: Transform to apply to images (use NormalizeVectorized!)
            train_aug: Whether to apply augmentation
        """
        super().__init__()
        self.hdf5_path = hdf5_path
        self.t_image = t_image
        self.train_aug = train_aug
        
        # Open HDF5 file (kept open for fast access)
        self.hf = h5py.File(hdf5_path, 'r')
        self.channels = self.hf['channels']
        self.targets = self.hf['targets']
        self.length = self.channels.shape[0]
        
        # For augmentation
        if train_aug:
            from src.transforms import AugmentTransform
            self.aug_image = AugmentTransform(gauss_std_range=(0.01, 0.1))
        
        # Print metadata
        print(f"Loaded HDF5 dataset: {hdf5_path}")
        print(f"  Samples: {self.length}")
        print(f"  Channels shape: {self.channels.shape}")
        print(f"  Targets shape: {self.targets.shape}")
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        # Direct HDF5 access (FAST! Memory-mapped, no file open/close)
        channels = self.channels[idx]      # (H, W, C)
        target_image = self.targets[idx]   # (C, H, W)
        
        # Convert to tensors: (H, W, C) -> (C, H, W)
        selected_images = torch.from_numpy(channels).float().permute(2, 0, 1).contiguous()
        target_image = torch.from_numpy(target_image).float()
        
        # Apply transforms (vectorized!)
        if self.t_image:
            selected_images = self.t_image(selected_images)
        
        # Apply augmentation
        if self.train_aug:
            selected_images, target_image = self.aug_image(img=selected_images, mask=target_image)
        
        # Dummy frequencies (for compatibility)
        selected_frequencies = torch.zeros(1)
        
        return selected_images, selected_frequencies, target_image
    
    def close(self):
        """Close HDF5 file when done."""
        self.hf.close()
    
    def __del__(self):
        """Ensure file is closed on deletion."""
        try:
            self.hf.close()
        except:
            pass

## HDF5 Step 4: Benchmark All Three Methods

Compare: Original .npz ‚Üí Pre-computed .npz ‚Üí HDF5

In [None]:
# Benchmark configuration
NUM_SAMPLES_TO_TEST = 100
transform = transforms.Compose([NormalizeVectorized(), MinimumToZeroVectorized()])

results = {}

# ============================================================================
# 1. Benchmark ORIGINAL dataset (slowest)
# ============================================================================
print("=" * 60)
print("1. ORIGINAL .npz dataset")
print("=" * 60)
try:
    original_ds = Ters_dataset_filtered_skip(
        filename=TRAIN_INPUT_DIR,
        frequency_range=[0, 4000],
        num_channels=NUM_CHANNELS,
        sg_ch=SG_CH,
        circle_radius=CIRCLE_RADIUS,
        t_image=transform,
        train_aug=False
    )
    
    start = time.time()
    for i in range(min(NUM_SAMPLES_TO_TEST, len(original_ds))):
        _ = original_ds[i]
    results['original'] = time.time() - start
    print(f"Time: {results['original']:.2f}s ‚Üí {results['original']/NUM_SAMPLES_TO_TEST*1000:.1f} ms/sample")
except Exception as e:
    print(f"Skipped: {e}")
    results['original'] = None

In [None]:
# ============================================================================
# 2. Benchmark PRE-COMPUTED .npz dataset (faster)
# ============================================================================
print("\n" + "=" * 60)
print("2. PRE-COMPUTED .npz dataset")
print("=" * 60)
try:
    precomputed_ds = Ters_dataset_precomputed(
        precomputed_dir=TRAIN_OUTPUT_DIR,
        t_image=transform,
        train_aug=False
    )
    
    start = time.time()
    for i in range(min(NUM_SAMPLES_TO_TEST, len(precomputed_ds))):
        _ = precomputed_ds[i]
    results['precomputed'] = time.time() - start
    print(f"Time: {results['precomputed']:.2f}s ‚Üí {results['precomputed']/NUM_SAMPLES_TO_TEST*1000:.1f} ms/sample")
except Exception as e:
    print(f"Skipped (run pre-computation first): {e}")
    results['precomputed'] = None

In [None]:
# ============================================================================
# 3. Benchmark HDF5 dataset (fastest!)
# ============================================================================
print("\n" + "=" * 60)
print("3. HDF5 dataset")
print("=" * 60)
try:
    hdf5_ds = Ters_dataset_hdf5(
        hdf5_path=TRAIN_HDF5_PATH,
        t_image=transform,
        train_aug=False
    )
    
    start = time.time()
    for i in range(min(NUM_SAMPLES_TO_TEST, len(hdf5_ds))):
        _ = hdf5_ds[i]
    results['hdf5'] = time.time() - start
    print(f"Time: {results['hdf5']:.2f}s ‚Üí {results['hdf5']/NUM_SAMPLES_TO_TEST*1000:.1f} ms/sample")
    
    hdf5_ds.close()
except Exception as e:
    print(f"Skipped (create HDF5 first): {e}")
    results['hdf5'] = None

In [None]:
# ============================================================================
# SUMMARY
# ============================================================================
print("\n" + "=" * 60)
print("üìä BENCHMARK SUMMARY")
print("=" * 60)

baseline = results.get('original')
for name, t in results.items():
    if t is not None:
        ms_per_sample = t / NUM_SAMPLES_TO_TEST * 1000
        speedup = f"{baseline/t:.1f}x" if baseline else "N/A"
        print(f"{name:15s}: {ms_per_sample:6.1f} ms/sample  (speedup: {speedup})")
    else:
        print(f"{name:15s}: Not available")

print("=" * 60)
print("\nüèÜ Recommendation: Use HDF5 for production training!")

## HDF5 Step 5: Integration with hyperopt.py

Copy this dataset class to your project and update hyperopt.py:

In [None]:
# Save the HDF5 dataset class to your project
hdf5_dataset_code = '''
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset


class Ters_dataset_hdf5(Dataset):
    """
    Ultra-fast dataset using HDF5 format.
    
    Usage:
        train_ds = Ters_dataset_hdf5(
            hdf5_path="/path/to/train.h5",
            t_image=transform,
            train_aug=True
        )
    """
    
    def __init__(self, hdf5_path, t_image=None, train_aug=False):
        super().__init__()
        self.hdf5_path = hdf5_path
        self.t_image = t_image
        self.train_aug = train_aug
        
        # Open HDF5 file (kept open for fast access)
        self.hf = h5py.File(hdf5_path, 'r')
        self.channels = self.hf['channels']
        self.targets = self.hf['targets']
        self.length = self.channels.shape[0]
        
        # For augmentation
        if train_aug:
            from src.transforms import AugmentTransform
            self.aug_image = AugmentTransform(gauss_std_range=(0.01, 0.1))
        
        print(f"Loaded HDF5: {hdf5_path} ({self.length} samples)")
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        channels = self.channels[idx]
        target_image = self.targets[idx]
        
        selected_images = torch.from_numpy(channels).float().permute(2, 0, 1).contiguous()
        target_image = torch.from_numpy(target_image).float()
        
        if self.t_image:
            selected_images = self.t_image(selected_images)
        
        if self.train_aug:
            selected_images, target_image = self.aug_image(img=selected_images, mask=target_image)
        
        return selected_images, torch.zeros(1), target_image
    
    def close(self):
        self.hf.close()
    
    def __del__(self):
        try:
            self.hf.close()
        except:
            pass
'''

# Save to file
output_path = '/home/sethih1/masque_new/ters_gen/src/datasets/ters_hdf5.py'
with open(output_path, 'w') as f:
    f.write(hdf5_dataset_code)
print(f"‚úÖ Saved HDF5 dataset class to: {output_path}")

In [None]:
# Example hyperopt.py modifications:
print("""
# ============================================================================
# CHANGES FOR hyperopt.py
# ============================================================================

# 1. Add import:
from src.datasets.ters_hdf5 import Ters_dataset_hdf5

# 2. Update config YAML:
'''
data:
  train_hdf5: /scratch/phys/sin/sethih1/.../planar_1.0_train.h5
  val_hdf5: /scratch/phys/sin/sethih1/.../planar_1.0_val.h5
'''

# 3. Replace dataset creation in objective():

# ‚ùå BEFORE:
train_ds = Ters_dataset_filtered_skip(
    filename=config.data.train_path,
    frequency_range=[0, 4000],
    num_channels=model_params["in_channels"],
    ...
)

# ‚úÖ AFTER:
train_ds = Ters_dataset_hdf5(
    hdf5_path=config.data.train_hdf5,
    t_image=transform,
    train_aug=augmentation
)

val_ds = Ters_dataset_hdf5(
    hdf5_path=config.data.val_hdf5,
    t_image=transform,
    train_aug=False
)
""")

---

## üìä Final Comparison: All Optimization Methods

| Method | ms/sample | Speedup | Disk Space | Setup Time |
|--------|-----------|---------|------------|------------|
| Original .npz | ~100-150 | 1x | Baseline | None |
| Pre-computed .npz | ~15-25 | 5-8x | ~2x | One-time |
| **HDF5** | **~5-10** | **10-20x** | **~1.5x** | **One-time** |

### HDF5 Advantages:
- ‚úÖ Single file management (easy to copy/move)
- ‚úÖ Memory-mapped I/O (OS handles caching)
- ‚úÖ No filesystem overhead (no open/close per sample)
- ‚úÖ Built-in compression (gzip, lzf)
- ‚úÖ Metadata storage (settings, filenames)
- ‚úÖ Industry standard for scientific data

### Recommended Workflow:
1. Run HDF5 creation cells **once**
2. Update `hyperopt.py` to use `Ters_dataset_hdf5`
3. Update config YAML with HDF5 paths
4. üöÄ Enjoy 10-20x faster data loading!