# Moiré Lattice Reconstruction - ML Pipeline

This notebook implements a complete ML framework for reconstructing grayscale images as a superposition of multiple moiré lattices.

## Project Overview
- **Input**: Grayscale images
- **Output**: Array of Bravais lattices with parameters (base vectors, hole size)
- **Approach**: Both algorithmic (FFT-based) and learned (CNN) methods

In [1]:
# Environment Setup and Imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import fft
from scipy.ndimage import maximum_filter
import cv2
import h5py
import os
import time
from datetime import datetime

# Ask TensorFlow to grow GPU memory on demand rather than reserving it up front.
# Iterating an empty device list is a no-op, so no separate "any GPUs?" guard is needed.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Report the problem and carry on with the rest of the setup.
    print(f"GPU configuration error: {e}")

# Environment summary: library versions followed by one line per detected GPU.
print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"GPUs available: {len(gpus)}")
for gpu in gpus:
    print(f"  {gpu}")

2025-08-05 01:45:43.206810: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-05 01:45:43.215766: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754358343.226408   12241 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754358343.229633   12241 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754358343.237807   12241 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

TensorFlow version: 2.19.0
NumPy version: 2.1.3
GPUs available: 1
  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [2]:
# CUDA memory-management configuration via environment variables.
# NOTE(review): TensorFlow was already imported (and GPUs listed) in the previous
# cell; confirm these variables are still read late enough to take effect.
import os

for key, value in (
    # Disable GPU memory pre-allocation
    ('TF_FORCE_GPU_ALLOW_GROWTH', 'true'),
    # Disable cuDNN autotune to prevent memory spikes
    ('TF_CUDNN_USE_AUTOTUNE', '0'),
    # Optional: set CUDA device order
    ('CUDA_DEVICE_ORDER', 'PCI_BUS_ID'),
):
    os.environ[key] = value

# Optional: Limit to specific GPU if multiple are available
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

print("‚úÖ CUDA memory management configured")
print("Environment variables set:")
# TF_GPU_ALLOCATOR is reported here although this cell never assigns it, so it
# shows as "not set" unless it was exported outside the notebook.
for key in ['TF_FORCE_GPU_ALLOW_GROWTH', 'TF_CUDNN_USE_AUTOTUNE', 'TF_GPU_ALLOCATOR']:
    print(f"  {key} = {os.environ.get(key, 'not set')}")

‚úÖ CUDA memory management configured
Environment variables set:
  TF_FORCE_GPU_ALLOW_GROWTH = true
  TF_CUDNN_USE_AUTOTUNE = 0
  TF_GPU_ALLOCATOR = not set


## 1. Differentiable Renderer

Implements a hole-based renderer that creates moiré patterns by placing dark holes at the points of each lattice.

In [3]:
@tf.function
def hole_based_renderer(lattice_params, image_size=128):
    """
    Render white images with dark "holes" at the points of one or more square lattices.

    Each lattice is drawn as the product of two cosine gratings along a rotated
    pair of axes; a soft (tanh) threshold turns the product's maxima into dark
    spots. Lattices are combined multiplicatively, so the holes of every active
    lattice appear in the final image.

    Args:
        lattice_params: tensor of shape (batch_size, num_lattices * 5)
                       where each lattice has (theta, spacing, phase_x, phase_y, hole_radius).
                       A lattice whose spacing is <= 0.5 is treated as an empty
                       slot and leaves the image unchanged.
        image_size: side length of the square output image (default 128)

    Returns:
        rendered images of shape (batch_size, image_size, image_size, 1),
        with values clipped to [0, 1]

    NOTE(review): hole_radius is read from the parameters but never used below;
    the hole size is fixed by the 0.9 threshold and the 10.0 sharpness constant.
    Confirm whether that is intended.
    """
    batch_size = tf.shape(lattice_params)[0]
    num_lattices = tf.shape(lattice_params)[1] // 5
    
    # View the flat parameter vector as (batch, lattice, 5)
    params = tf.reshape(lattice_params, (batch_size, num_lattices, 5))
    
    # Create coordinate grid.
    # linspace includes both endpoints, so the image_size samples span
    # [0, image_size] and the step is image_size / (image_size - 1), not 1.
    x = tf.linspace(0., float(image_size), image_size)
    y = tf.linspace(0., float(image_size), image_size)
    X, Y = tf.meshgrid(x, y)
    
    # Initialize output - start with white background
    images = tf.ones((batch_size, image_size, image_size, 1))
    
    # Process each lattice; `images` is carried across iterations
    for i in range(num_lattices):
        theta = params[:, i, 0]      # rotation angle
        spacing = params[:, i, 1]    # lattice spacing
        phase_x = params[:, i, 2]    # phase offset in x
        phase_y = params[:, i, 3]    # phase offset in y
        hole_radius = params[:, i, 4]  # hole radius (currently unused, see docstring)
        
        # Skip if spacing is 0 or too small (empty lattice slot).
        # mask is 1.0 for active lattices and 0.0 otherwise, shaped to broadcast
        # over (batch, height, width, channel).
        mask = tf.cast(spacing > 0.5, tf.float32)[:, None, None, None]
        
        # Clamp spacing to avoid division issues
        spacing = tf.maximum(spacing, 0.5)
        
        # Per-sample rotation terms, shaped for broadcasting against the grid
        cos_theta = tf.cos(theta)[:, None, None, None]
        sin_theta = tf.sin(theta)[:, None, None, None]
        spacing_expanded = spacing[:, None, None, None]
        
        # Rotated coordinates
        X_rot = X[None, :, :, None] * cos_theta + Y[None, :, :, None] * sin_theta
        Y_rot = -X[None, :, :, None] * sin_theta + Y[None, :, :, None] * cos_theta
        
        # Per-sample phase offsets, shaped for broadcasting
        phase_x_expanded = phase_x[:, None, None, None]
        phase_y_expanded = phase_y[:, None, None, None]
        
        # Second lower bound on the divisor; spacing was already clamped to
        # >= 0.5 above, so this bound never changes the value.
        safe_spacing = tf.maximum(spacing_expanded, 1e-6)
        
        # Cosine gratings along the two rotated axes, each with period `spacing`
        pattern_x = tf.cos(2 * np.pi * X_rot / safe_spacing + phase_x_expanded)
        pattern_y = tf.cos(2 * np.pi * Y_rot / safe_spacing + phase_y_expanded)
        
        # Combine patterns - value in [0, 1], equal to 1 only where both
        # gratings peak (the lattice points)
        lattice_indicator = (pattern_x + 1) * (pattern_y + 1) / 4
        
        # Create holes: soft threshold at 0.9, so the pattern is near 1 (white)
        # away from lattice points and drops towards 0 (dark) at them
        hole_pattern = 1.0 - 0.5 * (tf.nn.tanh(10.0 * (lattice_indicator - 0.9)) + 1.0)
        
        # Ensure hole_pattern is in valid range
        hole_pattern = tf.clip_by_value(hole_pattern, 0.0, 1.0)
        
        # Apply mask: active lattices multiply the image by hole_pattern,
        # empty slots leave it untouched
        images = images * (1 - mask) + images * hole_pattern * mask
    
    # Final clipping to ensure valid range
    images = tf.clip_by_value(images, 0.0, 1.0)
    
    # Replace any NaN pixel with mid-gray (0.5)
    images = tf.where(tf.math.is_nan(images), tf.ones_like(images) * 0.5, images)
    
    return images

## 2. Training Data Generation

In [4]:
def generate_random_lattice_params(num_lattices_range=(2, 8), image_size=128):
    """
    Draw random parameters for a random number of lattices.

    Spacing is limited to [image_size // 8, image_size // 3] so that several
    periods of every lattice are visible in the image. Randomness comes from
    the global NumPy generator (np.random); seed it for reproducible output.

    Args:
        num_lattices_range: inclusive (min, max) number of active lattices
        image_size: side length of the target image in pixels

    Returns:
        lattice_params: array of shape (num_lattices_range[1], 5) with one row
                       (theta, spacing, phase_x, phase_y, hole_radius) per
                       lattice; rows beyond the drawn number of lattices are
                       all zeros (empty slots).
    """
    min_count, max_count = num_lattices_range[0], num_lattices_range[1]
    num_lattices = np.random.randint(min_count, max_count + 1)

    # The spacing bounds depend only on image_size, so compute them once.
    min_spacing = image_size // 8   # at most ~8 periods across the image
    max_spacing = image_size // 3   # at least ~3 periods across the image

    params = []
    for _ in range(num_lattices):
        # Rotation angle in [0, 2π)
        theta = np.random.uniform(0, 2 * np.pi)

        spacing = np.random.uniform(min_spacing, max_spacing)

        # Phase offsets in [0, 2π) randomise where the lattice starts
        phase_x = np.random.uniform(0, 2 * np.pi)
        phase_y = np.random.uniform(0, 2 * np.pi)

        # Hole radius: at most a quarter of the spacing (so holes cannot merge)
        # and at most 6 px. The preferred lower bound of 0.8 px is lowered to
        # the cap when the cap is smaller; otherwise the bounds would be
        # inverted and radii above spacing / 4 would be drawn for small images.
        radius_cap = min(spacing / 4, 6.0)
        radius_floor = min(0.8, radius_cap)
        hole_radius = np.random.uniform(radius_floor, radius_cap)

        params.append([theta, spacing, phase_x, phase_y, hole_radius])

    # Pad with all-zero rows (no contribution) up to the maximum lattice count
    params.extend([0, 0, 0, 0, 0] for _ in range(max_count - len(params)))

    return np.array(params)

def check_existing_dataset(file_path, required_samples, image_size, max_lattices):
    """
    Decide whether an HDF5 dataset on disk can be reused.

    The file is accepted when it exists, contains the 'images', 'parameters'
    and 'num_lattices' datasets, carries 'image_size' / 'max_lattices'
    attributes equal to the requested values, and holds at least
    `required_samples` images. Any error raised while inspecting the file is
    printed and treated as "not usable".

    Returns:
        bool: True if dataset is valid and sufficient
    """
    if not os.path.exists(file_path):
        return False

    try:
        with h5py.File(file_path, 'r') as h5:
            # All three datasets must be present
            absent = [name for name in ('images', 'parameters', 'num_lattices')
                      if name not in h5]
            if absent:
                print(f"Dataset missing required keys. Found: {list(h5.keys())}")
                return False

            # Stored geometry must match the request
            size_matches = h5.attrs.get('image_size', 0) == image_size
            lattices_match = h5.attrs.get('max_lattices', 0) == max_lattices
            if not (size_matches and lattices_match):
                print(f"Dataset dimensions don't match. "
                      f"Found: {h5.attrs.get('image_size')}x{h5.attrs.get('image_size')}, "
                      f"max_lattices={h5.attrs.get('max_lattices')}")
                return False

            # There must be enough samples
            actual_samples = h5['images'].shape[0]
            if actual_samples < required_samples:
                print(f"Dataset has {actual_samples} samples, need {required_samples}")
                return False

            print(f"‚úÖ Found valid existing dataset with {actual_samples} samples")
            return True

    except Exception as e:
        print(f"Error checking existing dataset: {e}")
        return False

def create_training_dataset(num_samples=10000, image_size=128, max_lattices=8, 
                          save_path='moire_training_data.h5', batch_size=100):
    """
    Create a training dataset and save it to an HDF5 file, reusing a valid
    existing file when there is one.

    The file holds three datasets: 'images' (num_samples, image_size,
    image_size, 1), 'parameters' (num_samples, max_lattices * 5) and
    'num_lattices' (num_samples,), plus descriptive attributes.

    The data is written to a temporary '<save_path>.partial' file and only
    renamed to `save_path` once generation has finished. Because the datasets
    are allocated at full size up front, writing in place would leave a
    zero-filled file that passes check_existing_dataset if the run is
    interrupted.

    Args:
        num_samples: number of images to generate
        image_size: side length of the square images
        max_lattices: number of lattice slots per sample
        save_path: destination file; if an invalid file at this path cannot be
                   removed, '.h5' is replaced by '_new.h5'
        batch_size: number of samples generated between writes

    Returns:
        str: path of the dataset file that is valid after the call (this may
        be the alternative '_new.h5' name).
    """
    print(f"Checking for existing dataset at {save_path}...")
    
    # Check if valid dataset already exists
    if check_existing_dataset(save_path, num_samples, image_size, max_lattices):
        print("Using existing dataset!")
        return save_path
    
    print(f"Creating new training dataset with {num_samples} samples...")
    print(f"Image size: {image_size}x{image_size}, Max lattices: {max_lattices}")
    print(f"Lattice spacing range: {image_size//8} to {image_size//3} pixels")
    
    # Remove existing file if it exists but is invalid
    if os.path.exists(save_path):
        try:
            os.remove(save_path)
            print(f"Removed invalid existing file: {save_path}")
        except Exception as e:
            print(f"Warning: Could not remove existing file: {e}")
            # Try with a different filename
            save_path = save_path.replace('.h5', '_new.h5')
            print(f"Using alternative filename: {save_path}")
    
    # Generate into a scratch file so save_path never holds a half-written dataset
    partial_path = save_path + '.partial'
    total_batches = (num_samples - 1) // batch_size + 1
    
    try:
        with h5py.File(partial_path, 'w') as f:
            # Create datasets
            images_ds = f.create_dataset('images', (num_samples, image_size, image_size, 1), 
                                       dtype=np.float32, compression='gzip')
            params_ds = f.create_dataset('parameters', (num_samples, max_lattices * 5), 
                                       dtype=np.float32, compression='gzip')
            num_lattices_ds = f.create_dataset('num_lattices', (num_samples,), 
                                             dtype=np.int32, compression='gzip')
            
            # Add metadata
            f.attrs['image_size'] = image_size
            f.attrs['max_lattices'] = max_lattices
            f.attrs['num_samples'] = num_samples
            f.attrs['created_date'] = datetime.now().isoformat()
            f.attrs['parameter_format'] = 'theta, spacing, phase_x, phase_y, hole_radius'
            f.attrs['spacing_range'] = f'{image_size//8} to {image_size//3} pixels'
            
            # Generate data in batches
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                current_batch_size = batch_end - batch_start
                
                print(f"Generating batch {batch_start//batch_size + 1}/{total_batches}...")
                
                batch_images = []
                batch_params = []
                batch_num_lattices = []
                
                for _ in range(current_batch_size):
                    # Random parameters, zero-padded to max_lattices rows
                    lattice_params = generate_random_lattice_params((2, max_lattices), image_size)
                    num_actual_lattices = np.sum(lattice_params[:, 1] > 0)  # Count non-zero spacings
                    
                    # Flatten to the (max_lattices * 5,) layout the renderer expects
                    flat_params = lattice_params.flatten()
                    
                    # Render a batch of one image with the TensorFlow renderer
                    tf_params = tf.constant([flat_params], dtype=tf.float32)
                    rendered = hole_based_renderer(tf_params, image_size=image_size)
                    image = rendered[0].numpy()
                    
                    # Add Gaussian noise with a random level in [0, 0.03)
                    noise_level = np.random.uniform(0, 0.03)
                    image += np.random.normal(0, noise_level, image.shape)
                    image = np.clip(image, 0, 1)
                    
                    batch_images.append(image)
                    batch_params.append(flat_params)
                    batch_num_lattices.append(num_actual_lattices)
                
                # Save batch to HDF5
                images_ds[batch_start:batch_end] = np.array(batch_images)
                params_ds[batch_start:batch_end] = np.array(batch_params)
                num_lattices_ds[batch_start:batch_end] = np.array(batch_num_lattices)
                
                # Clear TF graph to prevent memory buildup
                tf.keras.backend.clear_session()
        
        # Generation finished and the file is closed: publish it under its final name
        os.replace(partial_path, save_path)
        
        print(f"Dataset saved to {save_path}")
        print(f"File size: {os.path.getsize(save_path) / (1024**2):.1f} MB")
        
    except Exception as e:
        print(f"Error creating dataset: {e}")
        raise
    finally:
        # After a successful rename the scratch file no longer exists; after a
        # failure or interruption it is removed so it cannot be mistaken for data.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError as cleanup_error:
                print(f"Warning: Could not remove partial file {partial_path}: {cleanup_error}")
    
    return save_path

## 3. Data Loading Infrastructure

In [5]:
class MoireDataLoader:
    """Data loader for moiré lattice training data stored in an HDF5 file."""
    
    def __init__(self, hdf5_path):
        """
        Locate the dataset file and read its metadata.

        Args:
            hdf5_path: path of the HDF5 dataset. If it does not exist, the same
                       path with '.h5' replaced by '_new.h5' is tried.

        Raises:
            FileNotFoundError: if neither path exists
        """
        self.hdf5_path = hdf5_path
        self.file = None
        
        # Check if the file exists, if not try alternative naming
        if not os.path.exists(hdf5_path):
            alt_path = hdf5_path.replace('.h5', '_new.h5')
            if os.path.exists(alt_path):
                print(f"Using alternative dataset path: {alt_path}")
                self.hdf5_path = alt_path
            else:
                raise FileNotFoundError(f"Dataset not found at {hdf5_path} or {alt_path}")
        
        self._load_metadata()
    
    def _load_metadata(self):
        """Read image size, lattice count, sample count and parameter format from the file attributes."""
        with h5py.File(self.hdf5_path, 'r') as f:
            self.image_size = f.attrs['image_size']
            self.max_lattices = f.attrs['max_lattices']
            self.num_samples = f.attrs['num_samples']
            self.parameter_format = f.attrs['parameter_format']
            print(f"Dataset: {self.num_samples} samples, {self.image_size}x{self.image_size}, max {self.max_lattices} lattices")
            print(f"Dataset path: {self.hdf5_path}")
    
    def _split_indices(self, shuffle, validation_split):
        """
        Partition the sample indices into disjoint training and validation sets.

        The split is computed once, so a sample can never be in both sets.

        Returns:
            (train_indices, val_indices): two index arrays; the validation part
            holds int(num_samples * validation_split) indices.
        """
        val_size = int(self.num_samples * validation_split)
        train_size = self.num_samples - val_size
        
        indices = np.arange(self.num_samples)
        if shuffle:
            np.random.shuffle(indices)
        
        return indices[:train_size], indices[train_size:]
    
    def _dataset_from_indices(self, indices, reshuffle):
        """
        Build an unbatched tf.data.Dataset that yields (image, params) for `indices`.

        Args:
            indices: sample indices this dataset may read
            reshuffle: if True, the visiting order is re-randomised every time
                       the dataset is iterated; membership never changes
        """
        def data_generator():
            order = np.array(indices)  # private copy; shuffling must not touch the split
            if reshuffle:
                np.random.shuffle(order)
            with h5py.File(self.hdf5_path, 'r') as f:
                images = f['images']
                parameters = f['parameters']
                for idx in order:
                    yield images[idx], parameters[idx]
        
        return tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(self.image_size, self.image_size, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(self.max_lattices * 5,), dtype=tf.float32)
            )
        )
    
    def create_tf_dataset(self, batch_size=32, shuffle=True, validation_split=0.2):
        """
        Create TensorFlow datasets for training and validation.

        The indices are split once into disjoint train / validation sets. A
        take()/skip() split over a generator that reshuffles on every pass
        would instead select a different, overlapping subset each time it is
        iterated and leak training samples into validation.

        Args:
            batch_size: number of samples per batch
            shuffle: if True, the split is random and the training order is
                     reshuffled on every pass; if False, the first samples are
                     used for training and the last ones for validation, in
                     file order
            validation_split: fraction of samples reserved for validation

        Returns:
            (train_dataset, val_dataset): batched, prefetching datasets of
            (image, parameters) pairs
        """
        train_indices, val_indices = self._split_indices(shuffle, validation_split)
        
        train_dataset = self._dataset_from_indices(train_indices, reshuffle=shuffle)
        val_dataset = self._dataset_from_indices(val_indices, reshuffle=False)
        
        train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
        
        return train_dataset, val_dataset

## 4. Neural Network Model

In [6]:
def create_improved_lattice_cnn(input_shape=(128, 128, 1), max_lattices=8, param_per_lattice=5):
    """
    Create a CNN optimized for periodic pattern detection

    Key design principles for moiré lattice detection:
    1. Multi-scale processing: three parallel branches with 3x3, 5x5 and 7x7 kernels
    2. Wide receptive field: parallel dilated convolutions (rates 2, 4, 8)
    3. Frequency-aware features: large kernels late in the network
    4. Global context: two global-average-pooled paths are concatenated

    Args:
        input_shape: shape of one input image (height, width, channels)
        max_lattices: number of per-lattice prediction heads
        param_per_lattice: number of outputs per head (must be >= 2)

    Returns:
        tf.keras.Model mapping an image to a vector of length
        max_lattices * param_per_lattice. Each head contributes
        [gate, gate * raw_1, ..., gate * raw_(param_per_lattice - 1)], where
        gate is a sigmoid in (0, 1).

    Raises:
        ValueError: if param_per_lattice < 2

    NOTE(review): the dataset stores each lattice as (theta, spacing, phase_x,
    phase_y, hole_radius), whereas slot 0 of every head here is the existence
    gate. No conversion between the two layouts is visible in this function;
    confirm how they are reconciled before comparing outputs with targets.
    """
    if param_per_lattice < 2:
        raise ValueError("param_per_lattice must be at least 2 (one gate plus one parameter)")
    
    layers = tf.keras.layers
    num_gated = param_per_lattice - 1  # outputs per head besides the gate itself
    
    inputs = tf.keras.Input(shape=input_shape)
    
    # === MULTI-SCALE FEATURE EXTRACTION ===
    # (spatial sizes in the comments below assume the default 128x128 input)
    # Branch 1: Fine-scale features (small periods)
    fine_branch = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    fine_branch = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(fine_branch)
    fine_pool = layers.MaxPooling2D((2, 2))(fine_branch)  # 64x64
    
    # Branch 2: Medium-scale features with dilated convolutions
    medium_branch = layers.Conv2D(32, (5, 5), activation='relu', padding='same')(inputs)
    medium_branch = layers.Conv2D(32, (5, 5), dilation_rate=2, activation='relu', padding='same')(medium_branch)
    medium_pool = layers.MaxPooling2D((2, 2))(medium_branch)  # 64x64
    
    # Branch 3: Large-scale features with large kernels for long-range patterns
    large_branch = layers.Conv2D(32, (7, 7), activation='relu', padding='same')(inputs)
    large_branch = layers.Conv2D(32, (7, 7), dilation_rate=4, activation='relu', padding='same')(large_branch)
    large_pool = layers.MaxPooling2D((2, 2))(large_branch)  # 64x64
    
    # Combine multi-scale features
    combined = layers.Concatenate()([fine_pool, medium_pool, large_pool])  # 96 channels
    
    # === DILATED CONVOLUTION PYRAMID FOR GLOBAL CONTEXT ===
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(combined)
    x = layers.BatchNormalization()(x)
    
    # Three parallel 3x3 convolutions on the same input. A 3x3 kernel with
    # dilation d spans 2*d + 1 pixels: 5x5, 9x9 and 17x17 respectively.
    dilated_1 = layers.Conv2D(64, (3, 3), dilation_rate=2, activation='relu', padding='same')(x)
    dilated_2 = layers.Conv2D(64, (3, 3), dilation_rate=4, activation='relu', padding='same')(x)
    dilated_3 = layers.Conv2D(64, (3, 3), dilation_rate=8, activation='relu', padding='same')(x)
    
    # Combine dilated features
    dilated_combined = layers.Concatenate()([dilated_1, dilated_2, dilated_3])  # 192 channels
    x = layers.Conv2D(128, (1, 1), activation='relu')(dilated_combined)  # Compress
    x = layers.MaxPooling2D((2, 2))(x)  # 32x32
    
    # === GLOBAL CONTEXT BLOCK ===
    # Global Average Pooling path for global context
    global_context = layers.GlobalAveragePooling2D()(x)
    global_context = layers.Dense(256, activation='relu')(global_context)
    global_context = layers.Dropout(0.3)(global_context)
    
    # Local feature path
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)  # 16x16
    
    # === FREQUENCY-AWARE PROCESSING ===
    # Large kernel convolutions to detect periodic patterns
    freq_features = layers.Conv2D(128, (7, 7), activation='relu', padding='same')(x)
    freq_features = layers.Conv2D(128, (5, 5), activation='relu', padding='same')(freq_features)
    
    # Global pooling to aggregate all spatial information
    spatial_features = layers.GlobalAveragePooling2D()(freq_features)
    
    # Combine global context with local features
    combined_features = layers.Concatenate()([global_context, spatial_features])
    
    # === DENSE PROCESSING FOR PARAMETER REGRESSION ===
    x = layers.Dense(512, activation='relu')(combined_features)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    
    # === LATTICE-SPECIFIC HEADS ===
    # Each lattice gets its own prediction head to handle variable numbers
    lattice_outputs = []
    for i in range(max_lattices):
        # Lattice existence gate (decides if this lattice should be active)
        gate = layers.Dense(64, activation='relu', name=f'lattice_{i}_gate')(x)
        gate = layers.Dense(1, activation='sigmoid', name=f'lattice_{i}_existence')(gate)
        
        # Raw (ungated) parameter prediction
        lattice_features = layers.Dense(64, activation='relu', name=f'lattice_{i}_features')(x)
        lattice_params = layers.Dense(num_gated, activation='linear', 
                                      name=f'lattice_{i}_raw_params')(lattice_features)
        
        # Repeat the gate once per raw parameter so both Multiply inputs have
        # the same width. The repeat count follows param_per_lattice; a fixed
        # count of 4 only matches the default of 5 outputs per head.
        if num_gated > 1:
            gate_tiled = layers.Concatenate()([gate] * num_gated)
        else:
            gate_tiled = gate  # a single parameter needs no tiling
        
        # Scale every raw parameter by the existence gate
        gated_params = layers.Multiply(name=f'lattice_{i}_gated')([lattice_params, gate_tiled])
        
        # Head output: the gate itself followed by the gated parameters
        full_params = layers.Concatenate(name=f'lattice_{i}_params')([gate, gated_params])
        
        lattice_outputs.append(full_params)
    
    # Concatenate all lattice parameters
    outputs = layers.Concatenate(name='all_lattice_params')(lattice_outputs)
    
    model = tf.keras.Model(inputs, outputs, name='PeriodicPatternCNN')
    return model

# Enhanced loss function for the new architecture
@tf.function
def enhanced_moire_loss(y_true, y_pred, image_size=128, alpha=1.0, beta=0.1, gamma=0.05):
    """
    Composite loss: parameter MSE plus three weighted auxiliary terms.

    total = param_mse
            + alpha * image_mse(render(y_pred), render(y_true))
            + beta  * mean(y_pred[:, ::5])
            + gamma * mean(y_pred ** 2)

    Args:
        y_true: true parameters, shape (batch, num_lattices * 5)
        y_pred: predicted parameters, same shape as y_true
        image_size: side length used when rendering both parameter sets
        alpha: reconstruction loss weight
        beta: sparsity regularization weight
        gamma: parameter smoothness weight

    Returns:
        scalar loss tensor
    """
    # Direct regression error on the parameter vectors
    param_residual = y_true - y_pred
    param_loss = tf.reduce_mean(tf.square(param_residual))

    # Image-space error: render both parameter sets and compare pixel-wise
    rendered_pred = hole_based_renderer(y_pred, image_size=image_size)
    rendered_true = hole_based_renderer(y_true, image_size=image_size)
    image_residual = rendered_pred - rendered_true
    reconstruction_loss = tf.reduce_mean(tf.square(image_residual))

    # Sparsity: penalise the mean of every 5th entry starting at index 0,
    # i.e. the first slot of each 5-wide lattice group
    sparsity_loss = tf.reduce_mean(y_pred[:, ::5])

    # Smoothness: L2 penalty that discourages extreme predicted values
    smoothness_loss = tf.reduce_mean(tf.square(y_pred))

    # Accumulate the terms in a fixed left-to-right order
    total_loss = param_loss
    total_loss = total_loss + alpha * reconstruction_loss
    total_loss = total_loss + beta * sparsity_loss
    total_loss = total_loss + gamma * smoothness_loss
    return total_loss

In [7]:
def fft_peak_extraction(image, num_peaks=8, min_distance=5):
    """
    Extract the strongest peaks of the 2D FFT magnitude spectrum.

    Args:
        image: 2D array (height, width); an array with a trailing singleton
               channel axis, (height, width, 1), is also accepted. The image
               does not have to be square.
        num_peaks: maximum number of peaks to return
        min_distance: side length of the local-maximum window (at least 5 is
                      always used)

    Returns:
        array of shape (n, 4), n <= num_peaks, ordered by decreasing peak
        magnitude. Each row is [theta, k_magnitude, phi, amplitude]:
        direction of the wave vector in radians, its length in radians per
        pixel, the FFT phase at the peak, and the peak magnitude divided by
        the largest magnitude of the whole spectrum (DC included).
        The array is empty, shape (0, 4), when the image has no spectral
        content outside the DC region (e.g. a constant image).

    Raises:
        ValueError: if the image is not 2D after dropping a singleton channel

    Note:
        For a real-valued image the spectrum is conjugate-symmetric, so every
        frequency appears as a +k / -k pair of peaks; both are returned.
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        # fft2 works on the last two axes; with the channel axis left in place
        # it would transform (width, channel) instead of (height, width).
        image = image[..., 0]
    if image.ndim != 2:
        raise ValueError(f"expected a 2D image, got an array of shape {image.shape}")
    
    # Centered spectrum: magnitude for peak finding, phase for the offsets
    fft_shifted = np.fft.fftshift(np.fft.fft2(image))
    magnitude = np.abs(fft_shifted)
    phase = np.angle(fft_shifted)
    
    # After fftshift the DC term sits at (height // 2, width // 2); each axis
    # needs its own center so that non-square images are handled correctly.
    height, width = magnitude.shape
    center_y, center_x = height // 2, width // 2
    
    # Blank a 7x7 window around DC so the image mean cannot win as a "peak"
    magnitude_copy = magnitude.copy()
    magnitude_copy[max(center_y - 3, 0):center_y + 4,
                   max(center_x - 3, 0):center_x + 4] = 0
    
    # Nothing left outside DC (or the input contains NaN): no peaks to report.
    # This also guards the amplitude normalisation below against 0 / 0.
    strongest = np.max(magnitude_copy)
    if not strongest > 0:
        return np.empty((0, 4))
    
    # A pixel is a peak when it equals the maximum of its neighborhood and
    # exceeds 15% of the strongest non-DC magnitude. The strongest pixel always
    # qualifies, so at least one peak is found from here on.
    neighborhood_size = max(5, min_distance)
    local_maxima = maximum_filter(magnitude_copy, size=neighborhood_size)
    peak_mask = (magnitude_copy == local_maxima) & (magnitude_copy > strongest * 0.15)
    peak_rows, peak_cols = np.nonzero(peak_mask)
    
    # Strongest first; the stable sort keeps scan order among equal magnitudes
    order = np.argsort(-magnitude_copy[peak_rows, peak_cols], kind='stable')[:num_peaks]
    
    max_magnitude = np.max(magnitude)
    lattice_params = []
    for y, x in zip(peak_rows[order], peak_cols[order]):
        # Convert to k-space coordinates (relative to center), radians per pixel
        ky = (y - center_y) * 2 * np.pi / height
        kx = (x - center_x) * 2 * np.pi / width
        
        k_magnitude = np.sqrt(kx**2 + ky**2)
        if k_magnitude > 0:
            theta = np.arctan2(ky, kx)
            amplitude = magnitude[y, x] / max_magnitude
            phi = phase[y, x]
            
            lattice_params.append([theta, k_magnitude, phi, amplitude])
    
    return np.array(lattice_params, dtype=float).reshape(-1, 4)

def fft_baseline_reconstruction(image, num_lattices=8):
    """
    Baseline FFT-based reconstruction method.

    Converts the strongest FFT peaks of `image` into lattice parameters in the
    renderer's format.

    Args:
        image: input image; its first two axes are taken as (height, width)
        num_lattices: number of lattice slots in the result

    Returns:
        float array of shape (num_lattices, 5) with rows
        (theta, spacing, phase_x, phase_y, hole_radius); unused slots are all
        zeros. phase_y is always 0.
    """
    # Extract FFT peaks
    fft_params = fft_peak_extraction(image, num_peaks=num_lattices)
    
    if len(fft_params) == 0:
        return np.zeros((num_lattices, 5))
    
    # Spacing limits follow the generator's image_size // 8 .. image_size // 3
    # rule, derived from the actual image (16 .. 42 for a 128 px image) rather
    # than assuming a fixed size.
    image_size = min(np.shape(image)[:2])
    min_spacing = image_size // 8
    max_spacing = image_size // 3
    
    # Convert to our parameter format (theta, spacing, phase_x, phase_y, hole_radius)
    converted_params = []
    for theta, k_mag, phi, amplitude in fft_params:
        if k_mag > 0 and not np.isnan(k_mag) and not np.isnan(theta):
            # Wavelength of the peak, limited to the supported spacing range
            spacing = 2 * np.pi / k_mag
            spacing = np.clip(spacing, min_spacing, max_spacing)
            
            # Heuristic hole radius from the relative peak amplitude
            hole_radius = np.clip(amplitude * spacing * 0.15, 0.8, spacing / 4)
            
            # Shift the phase from [-π, π] into [0, 2π)
            phase_x = (phi + np.pi) % (2 * np.pi)
            phase_y = 0.0  # Set phase_y to 0 for simplicity
            
            converted_params.append([theta, spacing, phase_x, phase_y, hole_radius])
    
    # Pad to num_lattices with zeros (zero params = no contribution)
    while len(converted_params) < num_lattices:
        converted_params.append([0, 0, 0, 0, 0])
    
    return np.array(converted_params[:num_lattices], dtype=float)

In [8]:
def train_moire_model(model, train_dataset, val_dataset, epochs=50, save_path='best_moire_model.h5'):
    """
    Train the moiré lattice reconstruction model.

    Fits `model` with three callbacks that all watch 'val_loss': checkpoint the
    best model to `save_path`, halve the learning rate after 5 epochs without
    improvement (not below 1e-6), and stop after 10 epochs without improvement,
    restoring the best weights.

    NOTE(review): the model is not compiled here — presumably the caller has
    already called compile(); confirm.

    Args:
        model: Keras model to train
        train_dataset: training data passed to model.fit
        val_dataset: validation data passed to model.fit
        epochs: maximum number of epochs
        save_path: file the best model is written to

    Returns:
        the History object returned by model.fit
    """
    watched = 'val_loss'

    # Keep only the best model seen so far on disk
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        save_path, save_best_only=True, monitor=watched, verbose=1
    )
    # Halve the learning rate when validation loss plateaus
    lr_plateau_cb = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=watched, factor=0.5, patience=5, min_lr=1e-6, verbose=1
    )
    # Give up after a longer plateau and roll back to the best weights
    early_stop_cb = tf.keras.callbacks.EarlyStopping(
        monitor=watched, patience=10, restore_best_weights=True, verbose=1
    )

    return model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_cb, lr_plateau_cb, early_stop_cb],
        verbose=1,
    )

## 7. Evaluation Functions

In [9]:
def evaluate_reconstruction(model, test_images, test_params, num_samples=5):
    """
    Evaluate model reconstruction quality on the first few test samples.

    Predicts parameters for the samples, renders both the predicted and the
    true parameters, shows a 4-column figure per sample (original, predicted
    rendering, true rendering, absolute difference) and prints two MSE values.

    Args:
        model: model exposing predict(images) -> parameter array
        test_images: images of shape (n, size, size, 1)
        test_params: matching parameters of shape (n, num_lattices * 5)
        num_samples: number of samples to evaluate; capped at the number of
                     images available

    Returns:
        (mse_reconstruction, mse_parameters): MSE between the predicted
        rendering and the original images, and MSE between predicted and true
        parameters

    Raises:
        ValueError: if there is no sample to evaluate
    """
    # Never ask for more rows than exist; indexing past the end would fail
    # while plotting.
    num_samples = min(num_samples, int(test_images.shape[0]))
    if num_samples < 1:
        raise ValueError("evaluate_reconstruction needs at least one test image")
    
    images = test_images[:num_samples]
    params = test_params[:num_samples]
    image_size = test_images.shape[1]
    
    # Predict parameters
    pred_params = model.predict(images)
    
    # Render images from predicted and true parameters
    pred_images = hole_based_renderer(pred_params, image_size=image_size)
    true_images = hole_based_renderer(params, image_size=image_size)
    
    # Calculate metrics
    mse_reconstruction = tf.reduce_mean(tf.square(pred_images - images))
    mse_parameters = tf.reduce_mean(tf.square(pred_params - params))
    
    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples), squeeze=False)
    
    for i in range(num_samples):
        # Original image
        axes[i, 0].imshow(test_images[i, :, :, 0], cmap='gray')
        axes[i, 0].set_title(f'Original Image {i+1}')
        axes[i, 0].axis('off')
        
        # Predicted reconstruction
        axes[i, 1].imshow(pred_images[i, :, :, 0], cmap='gray')
        axes[i, 1].set_title('Predicted Reconstruction')
        axes[i, 1].axis('off')
        
        # True reconstruction (from true parameters)
        axes[i, 2].imshow(true_images[i, :, :, 0], cmap='gray')
        axes[i, 2].set_title('True Reconstruction')
        axes[i, 2].axis('off')
        
        # Difference between the predicted rendering and the original image
        diff = np.abs(pred_images[i, :, :, 0] - test_images[i, :, :, 0])
        axes[i, 3].imshow(diff, cmap='hot')
        axes[i, 3].set_title('Abs Difference')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Reconstruction MSE: {mse_reconstruction:.6f}")
    print(f"Parameter MSE: {mse_parameters:.6f}")
    
    return mse_reconstruction, mse_parameters

def plot_training_history(history):
    """Draw loss and MAE curves for a training run and print last-epoch values.

    Args:
        history: Object with a ``history`` mapping that holds the per-epoch
            series ``'loss'``, ``'val_loss'``, ``'mae'`` and ``'val_mae'``.
    """
    series = history.history
    fig, panel_axes = plt.subplots(1, 2, figsize=(15, 5))

    # (history key, short name used in legend/title, y-axis label)
    panel_specs = (
        ('loss', 'Loss', 'Loss'),
        ('mae', 'MAE', 'Mean Absolute Error'),
    )
    for axis, (metric, short_name, y_label) in zip(panel_axes, panel_specs):
        axis.plot(series[metric], label=f'Training {short_name}', linewidth=2)
        axis.plot(series['val_' + metric], label=f'Validation {short_name}', linewidth=2)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(f'Model {short_name} During Training')
        axis.legend()
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Collect every final value first, then print them in a fixed order.
    final_values = [
        ('Training Loss', series['loss'][-1]),
        ('Validation Loss', series['val_loss'][-1]),
        ('Training MAE', series['mae'][-1]),
        ('Validation MAE', series['val_mae'][-1]),
    ]
    for caption, value in final_values:
        print(f"Final {caption}: {value:.6f}")

## 8. Complete Pipeline Execution

This cell runs the complete training pipeline with proper CUDA memory management.

In [None]:
def run_complete_pipeline(num_samples=5000, epochs=30, batch_size=32, image_size=128):
    """
    Complete pipeline for moiré lattice reconstruction with memory management.

    Calls ``create_training_dataset``, builds train/validation datasets through
    ``MoireDataLoader``, creates and compiles the CNN, trains it with
    checkpointing, learning-rate reduction and early stopping, then plots the
    history and evaluates a few validation samples.

    If training fails with an out-of-memory error the pipeline is re-run with
    half the batch size (never below 4). Any other error propagates to the
    caller instead of triggering a retrain.

    Args:
        num_samples: Number of training samples to generate
        epochs: Number of training epochs
        batch_size: Batch size for training
        image_size: Size of images (default 128)

    Returns:
        Tuple ``(model, history, data_loader)``.

    Raises:
        tf.errors.ResourceExhaustedError: If training runs out of memory and
            the batch size cannot be reduced any further.
    """
    import gc
    import os

    print("üöÄ Starting Complete Moir√© Lattice Reconstruction Pipeline")
    print("=" * 60)

    # Clear any existing TF sessions and reset GPU
    tf.keras.backend.clear_session()

    # Disable cuDNN autotune which can cause memory issues.
    # NOTE(review): TensorFlow is already imported when this runs; confirm the
    # variable is still honoured when set this late.
    os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'

    # Configure GPU memory growth
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # Optionally limit memory usage
            # tf.config.set_logical_device_configuration(
            #     gpus[0],
            #     [tf.config.LogicalDeviceConfiguration(memory_limit=8192)]  # Limit to 8GB
            # )
        except RuntimeError as e:
            print(f"GPU configuration error: {e}")

    # Paths
    dataset_path = f'moire_training_data_{image_size}px.h5'
    model_save_path = f'moire_model_trained_{image_size}px.h5'

    # Step 1: Create training data
    print(f"\nüìä Creating training dataset ({image_size}x{image_size} images)...")
    create_training_dataset(
        num_samples=num_samples,
        image_size=image_size,
        max_lattices=8,
        save_path=dataset_path,
        batch_size=100
    )

    # Step 2: Load data
    print("\nüìÇ Loading datasets...")
    data_loader = MoireDataLoader(dataset_path)
    train_dataset, val_dataset = data_loader.create_tf_dataset(
        batch_size=batch_size,
        validation_split=0.2
    )

    # Step 3: Create and compile model with correct input shape
    print(f"\nüß† Creating model for {image_size}x{image_size} images...")
    print("Using optimized model architecture")

    model = create_improved_lattice_cnn(
        input_shape=(image_size, image_size, 1),
        max_lattices=8
    )

    # Use mixed precision for memory efficiency.
    # NOTE(review): the policy is set after the model has been built, so it
    # presumably only affects layers created later (e.g. on a retry or a second
    # run of this cell). Moving it above the model creation would change the
    # training numerics, so it is left in place - confirm the intended order.
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)

    # Compile with gradient clipping
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, clipnorm=1.0)
    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['mae']
    )

    # Step 4: Train model
    print("\nüéØ Training model...")
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            model_save_path,
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
            verbose=1
        ),
        # Garbage collection callback. Collect Python garbage only: resetting
        # the Keras session here would wipe global Keras state while fit() is
        # still running.
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: gc.collect()
        )
    ]

    try:
        try:
            history = model.fit(
                train_dataset,
                validation_data=val_dataset,
                epochs=epochs,
                callbacks=callbacks,
                verbose=1
            )
        except tf.errors.ResourceExhaustedError as e:
            # Only out-of-memory failures are worth retrying with a smaller
            # batch; everything else is a real error and must surface.
            print(f"\n‚ùå Error during training: {str(e)}")

            reduced_batch_size = max(4, batch_size // 2)  # Halve batch size
            if reduced_batch_size >= batch_size:
                # Already at the floor: retrying would repeat the same failure
                # forever, so give up and re-raise.
                raise

            print("\nüîß Attempting recovery with reduced batch size...")

            # Clear everything and try with reduced settings
            tf.keras.backend.clear_session()

            print("Retrying with reduced batch size...")
            return run_complete_pipeline(
                num_samples=num_samples,
                epochs=epochs,
                batch_size=reduced_batch_size,
                image_size=image_size
            )

        print("\n‚úÖ Training completed!")

        # Step 5: Plot training history
        plot_training_history(history)

        # Step 6: Evaluate on validation samples
        print("\nüìà Evaluating model...")
        val_images, val_params = next(iter(val_dataset))

        # Evaluate on a small subset
        evaluate_reconstruction(
            model,
            val_images,
            val_params,
            num_samples=min(3, len(val_images))
        )

        return model, history, data_loader
    finally:
        # Always clear session to free memory
        tf.keras.backend.clear_session()

# Run the full pipeline with conservative, memory-friendly settings.
# The returned model / history / data_loader are reused by the analysis cells below.
print("Configuring for CUDA memory management...")
model, history, data_loader = run_complete_pipeline(
    num_samples=5000,  # Number of training samples requested
    epochs=100,         # Upper bound only: the pipeline also installs EarlyStopping
    batch_size=8,      # Small batch size to limit GPU memory use
    image_size=128     # Train on 128x128 images
)

Configuring for CUDA memory management...
üöÄ Starting Complete Moir√© Lattice Reconstruction Pipeline

üìä Creating training dataset (128x128 images)...
Checking for existing dataset at moire_training_data_128px.h5...
‚úÖ Found valid existing dataset with 5000 samples
Using existing dataset!

üìÇ Loading datasets...
Dataset: 5000 samples, 128x128, max 8 lattices
Dataset path: moire_training_data_128px.h5

üß† Creating model for 128x128 images...
Using optimized model architecture


I0000 00:00:1754358345.119566   12241 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9502 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4070 Ti, pci bus id: 0000:01:00.0, compute capability: 8.9



üéØ Training model...
Epoch 1/100


I0000 00:00:1754358352.104549   12383 service.cc:152] XLA service 0x77b7c0002130 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1754358352.104586   12383 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4070 Ti, Compute Capability 8.9
2025-08-05 01:45:52.272197: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1754358353.919850   12383 cuda_dnn.cc:529] Loaded cuDNN version 90300


      1/Unknown [1m18s[0m 18s/step - loss: 99.2474 - mae: 4.5369


I0000 00:00:1754358363.477663   12383 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    708/Unknown [1m214s[0m 278ms/step - loss: 38.2286 - mae: 3.1487

## 9. Comprehensive Model Evaluation

Now that the model is trained, let's perform detailed evaluation and analysis of the results.

In [None]:
# Run dataset visualization now that we have the data_loader.
# visualize_dataset is defined in a later cell, so on a fresh top-to-bottom run
# it does not exist yet; skip with a hint instead of raising NameError.
if 'data_loader' in locals() and data_loader is not None:
    if 'visualize_dataset' in locals():
        print("\n" + "="*60)
        print("üìä DATASET ANALYSIS")
        print("="*60)
        visualize_dataset(data_loader)
    else:
        print("visualize_dataset is not defined yet; run the cell that defines it, then re-run this cell.")

## 9.1. Dataset Visualization and Analysis

Let's examine the training dataset to understand the diversity and quality of generated patterns.

In [None]:
def visualize_dataset(data_loader, num_examples=20):
    """
    Comprehensive visualization of the training dataset.

    Draws a grid of sample images, per-parameter histograms over active
    lattices, the distribution of active-lattice counts, a scatter of a simple
    complexity score, prints summary statistics and finally shows examples from
    the low and high ends of the complexity ranking.

    A lattice is treated as active when its spacing (the second value of its
    five-parameter group) is greater than 0.5.

    Args:
        data_loader: Object whose ``create_tf_dataset(batch_size, validation_split)``
            returns ``(train_dataset, val_dataset)``; each batch is an
            ``(images, params)`` pair.
        num_examples: Number of images in the sample grid, clamped to the size
            of the first batch (requested with ``batch_size=32``).
    """
    params_per_lattice = 5
    active_spacing = 0.5   # spacing above this marks a lattice as active
    grid_cols = 5
    max_batches = 5        # batches used for the statistics

    print("üé® Dataset Visualization and Analysis")
    print("=" * 50)

    # Load dataset
    train_dataset, val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)

    # Get a batch for analysis
    images, params = next(iter(train_dataset))

    # 1. Sample Images Grid (five columns, as many rows as num_examples needs)
    n_shown = min(num_examples, len(images))
    n_rows = max(1, -(-n_shown // grid_cols))  # ceiling division
    fig, axes = plt.subplots(n_rows, grid_cols, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    shown_counts = np.sum(params.numpy()[:n_shown, 1::params_per_lattice] > active_spacing, axis=1)
    for i in range(n_shown):
        axes[i].imshow(images[i, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[i].set_title(f'Sample {i+1}\n{shown_counts[i]} lattices', fontsize=10)
        axes[i].axis('off')
    for spare_ax in axes[n_shown:]:
        spare_ax.axis('off')  # hide cells of the grid that hold no image

    plt.suptitle('Dataset Sample Images', fontsize=16, y=0.95)
    plt.tight_layout()
    plt.show()

    # 2. Parameter Distribution Analysis: collect the first few batches
    param_batches = []
    image_batches = []
    for batch_idx, (batch_images, batch_params) in enumerate(train_dataset):
        if batch_idx >= max_batches:
            break
        param_batches.append(batch_params.numpy())
        image_batches.append(batch_images.numpy())

    all_params = np.concatenate(param_batches, axis=0)
    all_images = np.concatenate(image_batches, axis=0)

    # View the flat parameter vectors as (sample, lattice, parameter); the
    # lattice count comes from the data instead of a hard-coded 8.
    n_lattices = all_params.shape[1] // params_per_lattice
    grouped = all_params[:, :n_lattices * params_per_lattice].reshape(
        len(all_params), n_lattices, params_per_lattice)
    active = grouped[:, :, 1] > active_spacing
    lattice_counts = active.sum(axis=1)

    # 3. Parameter Distribution Plots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    param_names = ['Theta (rad)', 'Spacing (px)', 'Phase X (rad)', 'Phase Y (rad)', 'Hole Radius (px)']

    for param_idx, param_name in enumerate(param_names):
        ax = axes[param_idx // 3, param_idx % 3]

        # Values of this parameter over all active lattices (the transposes
        # gather them lattice by lattice).
        param_values = grouped[:, :, param_idx].T[active.T]

        if len(param_values) > 0:
            ax.hist(param_values, bins=50, alpha=0.7, edgecolor='black', linewidth=0.5)
            ax.set_xlabel(param_name)
            ax.set_ylabel('Frequency')
            ax.set_title(f'{param_name} Distribution\n({len(param_values)} active values)')
            ax.grid(True, alpha=0.3)

            # Add statistics
            mean_val = np.mean(param_values)
            std_val = np.std(param_values)
            ax.axvline(mean_val, color='red', linestyle='--', label=f'Mean: {mean_val:.2f}')
            ax.axvline(mean_val + std_val, color='orange', linestyle=':', alpha=0.7, label=f'¬±1œÉ: {std_val:.2f}')
            ax.axvline(mean_val - std_val, color='orange', linestyle=':', alpha=0.7)
            ax.legend(fontsize=8)

    # Remove unused subplot
    axes[1, 2].remove()

    plt.tight_layout()
    plt.show()

    # 4. Lattice Count Distribution
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Histogram of lattice counts
    ax = axes[0]
    unique_counts, count_frequencies = np.unique(lattice_counts, return_counts=True)
    ax.bar(unique_counts, count_frequencies, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Number of Active Lattices')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Lattice Counts in Dataset')
    ax.grid(True, alpha=0.3)

    # Add percentage labels
    total_samples = len(lattice_counts)
    for count, freq in zip(unique_counts, count_frequencies):
        percentage = (freq / total_samples) * 100
        ax.text(count, freq + max(count_frequencies) * 0.01, f'{percentage:.1f}%',
                ha='center', va='bottom', fontsize=10)

    # Complexity analysis: active-lattice count weighted by the variance of
    # the parameters of those lattices.
    ax = axes[1]
    complexity_scores = []
    for sample_groups, sample_active in zip(grouped, active):
        active_params = sample_groups[sample_active]
        if active_params.size > 0:
            complexity = sample_active.sum() * (1 + np.var(active_params) * 0.1)
        else:
            complexity = 0
        complexity_scores.append(complexity)

    ax.scatter(lattice_counts, complexity_scores, alpha=0.6, s=20)
    ax.set_xlabel('Number of Active Lattices')
    ax.set_ylabel('Complexity Score')
    ax.set_title('Pattern Complexity vs Lattice Count')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # 5. Image Statistics
    print(f"\nüìä Dataset Statistics:")
    print(f"  Total samples analyzed: {len(all_images)}")
    print(f"  Image shape: {all_images.shape[1:3]}")
    print(f"  Pixel value range: [{np.min(all_images):.3f}, {np.max(all_images):.3f}]")
    print(f"  Mean pixel value: {np.mean(all_images):.3f}")
    print(f"  Pixel std: {np.std(all_images):.3f}")
    print(f"\n  Lattice count distribution:")
    for count, freq in zip(unique_counts, count_frequencies):
        percentage = (freq / total_samples) * 100
        print(f"    {count} lattices: {freq:4d} samples ({percentage:5.1f}%)")

    print(f"\n  Average lattices per sample: {np.mean(lattice_counts):.2f}")
    print(f"  Max lattices in sample: {np.max(lattice_counts)}")

    # 6. Sample complexity examples
    print(f"\nüéØ Showing examples by complexity:")
    complexity_indices = np.argsort(complexity_scores)
    n_ranked = len(complexity_indices)

    fig, axes = plt.subplots(2, 5, figsize=(15, 8))

    # Simple patterns: five picks spread over the lower half of the ranking
    for i in range(5):
        idx = complexity_indices[i * n_ranked // 10]
        axes[0, i].imshow(all_images[idx, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[0, i].set_title(f'Simple\n{lattice_counts[idx]} lattices\nScore: {complexity_scores[idx]:.1f}')
        axes[0, i].axis('off')

    # Complex patterns: five picks counted back from the top of the ranking
    for i in range(5):
        idx = complexity_indices[-(i+1) * n_ranked // 10]
        axes[1, i].imshow(all_images[idx, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[1, i].set_title(f'Complex\n{lattice_counts[idx]} lattices\nScore: {complexity_scores[idx]:.1f}')
        axes[1, i].axis('off')

    plt.suptitle('Dataset Complexity Examples', fontsize=16)
    plt.tight_layout()
    plt.show()

# Visualize immediately when the pipeline has already produced a data_loader;
# otherwise only announce that the visualization is deferred.
if 'data_loader' not in locals() or data_loader is None:
    print("üìä Dataset visualization will run after training pipeline creates data_loader")
else:
    visualize_dataset(data_loader)

In [None]:
def estimate_memory_usage(batch_size, image_size, max_lattices=8, verbose=True):
    """
    Estimate GPU memory usage for different training configurations.

    This is a rough, closed-form estimate that scales linearly with
    ``batch_size``; it does not inspect the actual model.

    Args:
        batch_size: Number of images per training step.
        image_size: Height/width of the square single-channel input images.
        max_lattices: Number of lattices in the output (5 values each).
        verbose: When True (default) print the per-component breakdown and a
            recommendation; when False only compute and return the total.

    Returns:
        Estimated total memory in GB (float).
    """
    bytes_per_value = 4  # float32 = 4 bytes

    # Input: batch_size x image_size x image_size x 1
    input_memory = batch_size * image_size * image_size * 1 * bytes_per_value

    # Convolutional layers (rough estimation): one full-resolution feature map
    # with an assumed maximum of 256 channels.
    conv_memory = batch_size * image_size * image_size * 256 * bytes_per_value

    # Dense layers
    dense_features = 512  # From the architecture
    dense_memory = batch_size * dense_features * bytes_per_value

    # Output: batch_size x (max_lattices * 5)
    output_memory = batch_size * max_lattices * 5 * bytes_per_value

    # Gradients and optimizer states are each modelled as twice the
    # conv + dense estimate above.
    gradient_memory = 2 * (conv_memory + dense_memory)
    optimizer_memory = 2 * (conv_memory + dense_memory)

    total_memory_bytes = (input_memory + conv_memory + dense_memory +
                         output_memory + gradient_memory + optimizer_memory)

    total_memory_gb = total_memory_bytes / (1024**3)

    if not verbose:
        return total_memory_gb

    print("üîç GPU Memory Usage Estimation")
    print("=" * 50)
    print(f"Estimated memory usage for batch_size={batch_size}, image_size={image_size}:")
    print(f"  Input tensors:     {input_memory / (1024**2):6.1f} MB")
    print(f"  Conv features:     {conv_memory / (1024**2):6.1f} MB")
    print(f"  Dense features:    {dense_memory / (1024**2):6.1f} MB")
    print(f"  Output tensors:    {output_memory / (1024**2):6.1f} MB")
    print(f"  Gradients:         {gradient_memory / (1024**2):6.1f} MB")
    print(f"  Optimizer states:  {optimizer_memory / (1024**2):6.1f} MB")
    print(f"  ‚îÄ" * 40)
    print(f"  TOTAL:            {total_memory_gb:6.2f} GB")

    # Safety recommendations
    if total_memory_gb < 4:
        print("  ‚úÖ Should fit comfortably on most GPUs")
    elif total_memory_gb < 8:
        print("  ‚ö° Should work on 8GB+ GPUs")
    elif total_memory_gb < 12:
        print("  ‚ö†Ô∏è  Needs 12GB+ GPU, might be tight")
    else:
        print("  ‚ùå Likely too large for most consumer GPUs")

    return total_memory_gb

def find_optimal_batch_size(target_memory_gb=8, image_size=128):
    """
    Find the largest batch size that fits in target memory.

    Tries batch sizes 1..32 in order and keeps the last one whose estimate
    stays within 80% of ``target_memory_gb``. Returns 1 when even a single
    sample exceeds that budget.

    Args:
        target_memory_gb: Available GPU memory in GB.
        image_size: Height/width of the square input images.

    Returns:
        The recommended batch size (int between 1 and 32).
    """
    print(f"\nüéØ Finding optimal batch size for {target_memory_gb}GB GPU:")
    print("=" * 50)

    budget_gb = target_memory_gb * 0.8  # Use 80% of available memory for safety

    optimal_batch = 1
    for batch_size in range(1, 33):  # Test up to batch size 32
        # Quiet estimate: printing the full table for every candidate floods
        # the cell output.
        if estimate_memory_usage(batch_size, image_size, verbose=False) > budget_gb:
            break
        optimal_batch = batch_size

    optimal_usage_gb = estimate_memory_usage(optimal_batch, image_size, verbose=False)
    print(f"\nüí° Recommended batch size: {optimal_batch}")
    print(f"   (Uses ~{optimal_usage_gb:.1f}GB of {target_memory_gb}GB available)")

    return optimal_batch

# Print the memory estimate for one hand-picked configuration.
# NOTE(review): the pipeline cell above is invoked with batch_size=8, not 6 - confirm which value is intended here.
print("Current configuration analysis:")
estimate_memory_usage(batch_size=6, image_size=128)

# Search batch sizes 1..32 for the largest one whose estimate fits 80% of the budget.
optimal_batch = find_optimal_batch_size(target_memory_gb=9, image_size=128)  # Conservative 9GB budget

In [None]:
def optimize_training_config(gpu_memory_gb=9, target_accuracy='high', image_size=128):
    """
    Recommend optimal training configuration based on available resources.

    Combines ``find_optimal_batch_size`` (capped at 8) with one of three
    dataset-size / epoch presets, prints the recommendation together with a
    rough time and memory estimate, and returns it as a dict.

    Args:
        gpu_memory_gb: Available GPU memory in GB
        target_accuracy: 'fast', 'balanced', or 'high'
        image_size: Height/width of the square training images (default 128)

    Returns:
        Dict with keys ``num_samples``, ``epochs``, ``batch_size``,
        ``image_size``, ``estimated_time_hours`` and ``memory_usage_gb``.

    Raises:
        KeyError: If ``target_accuracy`` is not one of the three presets.
    """
    print("‚öôÔ∏è Training Configuration Optimizer")
    print("=" * 50)

    # Find optimal batch size for available memory
    optimal_batch = find_optimal_batch_size(target_memory_gb=gpu_memory_gb, image_size=image_size)

    # Configuration recommendations based on target
    configs = {
        'fast': {
            'num_samples': 4000,
            'epochs': 15,
            'description': 'Quick training for testing and prototyping'
        },
        'balanced': {
            'num_samples': 8000,
            'epochs': 25,
            'description': 'Good balance of quality and training time'
        },
        'high': {
            'num_samples': 15000,
            'epochs': 40,
            'description': 'High quality results, longer training time'
        }
    }

    config = configs[target_accuracy]
    recommended_batch = min(optimal_batch, 8)  # Cap at 8 for stability

    print(f"\nüéØ Recommended Configuration ({target_accuracy} accuracy):")
    print(f"  GPU Memory: {gpu_memory_gb}GB")
    print(f"  Batch Size: {recommended_batch}")
    print(f"  Dataset Size: {config['num_samples']:,} samples")
    print(f"  Epochs: {config['epochs']}")
    print(f"  Description: {config['description']}")

    # Calculate training time estimate
    samples_per_epoch = config['num_samples'] * 0.8  # 80% for training
    steps_per_epoch = int(samples_per_epoch / recommended_batch)
    total_steps = steps_per_epoch * config['epochs']

    # Rough timing estimate from an assumed fixed cost per training step
    seconds_per_step = 0.15
    total_time_hours = (total_steps * seconds_per_step) / 3600

    print(f"\n‚è±Ô∏è Training Time Estimate:")
    print(f"  Steps per epoch: {steps_per_epoch}")
    print(f"  Total steps: {total_steps:,}")
    print(f"  Estimated time: {total_time_hours:.1f} hours")

    # Memory usage check
    memory_usage = estimate_memory_usage(recommended_batch, image_size)
    memory_utilization = (memory_usage / gpu_memory_gb) * 100

    print(f"\nüíæ Memory Analysis:")
    print(f"  Estimated usage: {memory_usage:.1f}GB ({memory_utilization:.0f}% of available)")

    if memory_utilization < 60:
        print("  ‚úÖ Conservative - could potentially use larger batch size")
    elif memory_utilization < 80:
        print("  ‚ö° Optimal - good memory utilization")
    else:
        print("  ‚ö†Ô∏è  High - might need to reduce batch size if errors occur")

    # Generate the code
    print(f"\nüíª Recommended Code:")
    print("```python")
    print("model, history, data_loader = run_complete_pipeline(")
    print(f"    num_samples={config['num_samples']},")
    print(f"    epochs={config['epochs']},")
    print(f"    batch_size={recommended_batch},")
    print(f"    image_size={image_size}")
    print(")")
    print("```")

    return {
        'num_samples': config['num_samples'],
        'epochs': config['epochs'],
        'batch_size': recommended_batch,
        'image_size': image_size,
        'estimated_time_hours': total_time_hours,
        'memory_usage_gb': memory_usage
    }

# Test different configurations
print("üöÄ Configuration Analysis for Your Setup:")
print()

fast_config = optimize_training_config(gpu_memory_gb=9, target_accuracy='fast')
print("\n" + "="*60)
balanced_config = optimize_training_config(gpu_memory_gb=9, target_accuracy='balanced')
print("\n" + "="*60)
high_config = optimize_training_config(gpu_memory_gb=9, target_accuracy='high')

# Settings actually passed to run_complete_pipeline in the pipeline cell above;
# keep in sync by hand if that call changes.
current_config = {'num_samples': 5000, 'epochs': 100, 'batch_size': 8}

print(f"\nüé≤ Current Configuration (in the notebook):")
print(f"  num_samples={current_config['num_samples']}, "
      f"epochs={current_config['epochs']}, batch_size={current_config['batch_size']}")

# Compare against the presets instead of asserting a match that may be stale.
matching_presets = [
    preset_name
    for preset_name, preset in (('fast', fast_config),
                                ('balanced', balanced_config),
                                ('high', high_config))
    if all(preset[key] == value for key, value in current_config.items())
]
if matching_presets:
    print(f"  This matches the '{matching_presets[0]}' recommendation! ‚úÖ")
else:
    print("  This does not match any of the recommended presets.")

In [None]:
def _active_param_pairs(true_params, predictions, param_idx, spacing_threshold=0.5):
    """Return (true, predicted) values of one parameter over all active lattices.

    Parameter vectors are read as consecutive groups of five values per
    lattice. A lattice is active when its true spacing (index 1 of the group)
    exceeds ``spacing_threshold``.
    """
    n_lattices = true_params.shape[1] // 5
    width = n_lattices * 5
    true_grid = np.asarray(true_params)[:, :width].reshape(len(true_params), n_lattices, 5)
    pred_grid = np.asarray(predictions)[:, :width].reshape(len(predictions), n_lattices, 5)
    active = true_grid[:, :, 1] > spacing_threshold
    return true_grid[:, :, param_idx][active], pred_grid[:, :, param_idx][active]


def _r2_score(true_vals, pred_vals):
    """Coefficient of determination; NaN when the true values have no variance."""
    total_ss = np.sum((true_vals - np.mean(true_vals))**2)
    if total_ss == 0:
        return float('nan')
    return 1 - np.sum((true_vals - pred_vals)**2) / total_ss


def detailed_parameter_analysis(model, data_loader, num_samples=10):
    """
    Analyze model predictions vs ground truth parameters in detail.

    Takes the first validation batch, predicts parameters for up to
    ``num_samples`` images and, for each of the five parameter types, draws a
    true-vs-predicted scatter plot and prints MAE / RMSE / R² computed over
    active lattices only.

    Activity is decided by the true spacing (> 0.5), the same rule the other
    analysis cells use. Filtering on each parameter's own value would drop
    active lattices whose angle or phase happens to be zero or negative.

    Args:
        model: Object whose ``predict(images)`` returns one parameter vector
            per image.
        data_loader: Object whose ``create_tf_dataset(batch_size, validation_split)``
            returns ``(train_dataset, val_dataset)``.
        num_samples: Maximum number of validation samples to analyze.
    """
    print("üîç Detailed Parameter Analysis")
    print("=" * 50)

    # Get validation data
    val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)[1]
    val_images, val_params = next(iter(val_dataset))

    # Get predictions
    predictions = model.predict(val_images[:num_samples])
    true_params = val_params[:num_samples].numpy()

    # Analyze each parameter type
    param_names = ['theta', 'spacing', 'phase_x', 'phase_y', 'hole_radius']

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    # (name, mae, rmse, r2) per parameter with at least one active value;
    # filled while plotting and reused for the printed table.
    stats_rows = []

    for param_idx, param_name in enumerate(param_names):
        ax = axes[param_idx]
        true_vals, pred_vals = _active_param_pairs(true_params, predictions, param_idx)

        if len(true_vals) > 0:
            # Scatter plot
            ax.scatter(true_vals, pred_vals, alpha=0.6, s=20)

            # Perfect prediction line
            min_val = min(true_vals.min(), pred_vals.min())
            max_val = max(true_vals.max(), pred_vals.max())
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect prediction')

            # Calculate metrics
            errors = true_vals - pred_vals
            mae = np.mean(np.abs(errors))
            rmse = np.sqrt(np.mean(errors**2))
            r2 = _r2_score(true_vals, pred_vals)
            stats_rows.append((param_name, mae, rmse, r2))

            ax.set_xlabel(f'True {param_name}')
            ax.set_ylabel(f'Predicted {param_name}')
            ax.set_title(f'{param_name}\nMAE: {mae:.4f}, R¬≤: {r2:.4f}')
            ax.legend()
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, f'No active\n{param_name}',
                   transform=ax.transAxes, ha='center', va='center')
            ax.set_title(param_name)

    # Remove unused subplot
    axes[5].remove()

    plt.tight_layout()
    plt.show()

    # Print detailed statistics
    print("\nüìä Parameter Statistics:")
    for param_name, mae, rmse, r2 in stats_rows:
        print(f"  {param_name:12s}: MAE={mae:7.4f}, RMSE={rmse:7.4f}, R¬≤={r2:7.4f}")

# Parameter-level analysis needs the trained model produced by the pipeline
# cell; do nothing when it is missing or None.
if locals().get('model') is not None:
    detailed_parameter_analysis(model, data_loader, num_samples=20)

In [None]:
def visual_reconstruction_comparison(model, data_loader, num_examples=6):
    """
    Compare original images with reconstructed images from predicted parameters.

    Takes the first validation batch, predicts parameters for up to
    ``num_examples`` images, renders them with ``hole_based_renderer`` and
    shows original / reconstruction / absolute difference per example. Prints
    MAE and MSE over the shown examples, plus SSIM when scikit-image is
    installed.

    Args:
        model: Object whose ``predict(images)`` returns one parameter vector
            per image.
        data_loader: Object whose ``create_tf_dataset(batch_size, validation_split)``
            returns ``(train_dataset, val_dataset)``.
        num_examples: Maximum number of examples to show; clamped to the size
            of the validation batch.

    Raises:
        ValueError: If there is no example to show.
    """
    print("üé® Visual Reconstruction Comparison")
    print("=" * 50)

    # Get validation data
    val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)[1]
    val_images, _ = next(iter(val_dataset))

    num_examples = min(num_examples, len(val_images))
    if num_examples < 1:
        raise ValueError("visual_reconstruction_comparison requires at least one example")

    # Get predictions
    predictions = model.predict(val_images[:num_examples])

    # Render at the resolution of the input batch rather than a fixed 128 px.
    # Plain NumPy arrays keep the arithmetic and SSIM below independent of the
    # tensor type.
    originals = np.asarray(val_images[:num_examples])
    reconstructed = np.asarray(
        hole_based_renderer(predictions, image_size=val_images.shape[1])
    )

    # squeeze=False keeps `axes` two-dimensional when only one example is shown.
    fig, axes = plt.subplots(3, num_examples, figsize=(3*num_examples, 9), squeeze=False)

    for i in range(num_examples):
        # Original image
        axes[0, i].imshow(originals[i, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')

        # Reconstructed image
        axes[1, i].imshow(reconstructed[i, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[1, i].set_title(f'Reconstructed {i+1}')
        axes[1, i].axis('off')

        # Difference
        diff = np.abs(originals[i, :, :, 0] - reconstructed[i, :, :, 0])
        im = axes[2, i].imshow(diff, cmap='Reds', vmin=0, vmax=np.max(diff))
        axes[2, i].set_title(f'Difference {i+1}\nMAE: {np.mean(diff):.4f}')
        axes[2, i].axis('off')

        # Add colorbar for difference
        plt.colorbar(im, ax=axes[2, i], fraction=0.046, pad=0.04)

    plt.suptitle('Visual Reconstruction Quality Assessment', fontsize=16, y=0.95)
    plt.tight_layout()
    plt.show()

    # Calculate overall reconstruction metrics
    reconstruction_mae = np.mean(np.abs(originals - reconstructed))
    reconstruction_mse = np.mean((originals - reconstructed)**2)

    try:
        from skimage.metrics import structural_similarity as ssim
    except ImportError:
        ssim_line = "  SSIM: (requires scikit-image)"
    else:
        # Floating-point images need an explicit data_range. 1.0 assumes pixel
        # values in [0, 1], matching the vmin/vmax used for display above -
        # TODO confirm against the renderer's output range.
        ssim_scores = [
            ssim(originals[i, :, :, 0], reconstructed[i, :, :, 0], data_range=1.0)
            for i in range(num_examples)
        ]
        ssim_line = f"  SSIM: {np.mean(ssim_scores):.4f}"

    print(f"\nüìä Reconstruction Metrics:")
    print(f"  MAE: {reconstruction_mae:.6f}")
    print(f"  MSE: {reconstruction_mse:.6f}")
    print(ssim_line)

# The side-by-side comparison needs the trained model produced by the pipeline
# cell; do nothing when it is missing or None.
if locals().get('model') is not None:
    visual_reconstruction_comparison(model, data_loader, num_examples=8)

In [None]:
def lattice_count_analysis(model, data_loader, num_samples=100):
    """
    Analyze how well the model predicts the number of active lattices.

    Parameter rows are read in groups of five values per lattice. A lattice is
    counted as active when the value at index 1 of its group (the spacing)
    exceeds 0.5. Four diagnostic panels are plotted and summary metrics printed.

    Args:
        model: Object exposing ``predict(images, verbose=0)`` that returns one
            flat parameter row per image.
        data_loader: Object exposing
            ``create_tf_dataset(batch_size=..., validation_split=...)``; element
            ``[1]`` of the returned value is iterated as the validation set and
            must yield ``(images, params)`` batches.
        num_samples: Maximum number of validation samples to analyze.

    Returns:
        None. Results are plotted and printed.
    """
    print("🔢 Lattice Count and Sparsity Analysis")
    print("=" * 50)

    # Get validation data
    val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)[1]

    true_counts = []
    pred_counts = []
    existence_probs = []

    for val_images, val_params in val_dataset:
        remaining = num_samples - len(true_counts)
        if remaining <= 0:
            break

        # Get predictions for the whole batch, then keep only what is still needed
        predictions = np.asarray(model.predict(val_images, verbose=0))
        take = min(int(val_images.shape[0]), remaining)
        true_block = np.asarray(val_params)[:take]
        pred_block = predictions[:take]

        # Count active lattices (spacing > threshold); spacing is every 5th
        # param starting from index 1
        true_counts.extend(np.sum(true_block[:, 1::5] > 0.5, axis=1))
        pred_counts.extend(np.sum(pred_block[:, 1::5] > 0.5, axis=1))

        # Store the first param of each lattice (every 5th param starting from 0).
        # NOTE(review): this function treats index 0 as an existence gate, while
        # interactive_model_exploration labels index 0 of each group as theta --
        # confirm which interpretation matches the model output.
        existence_probs.extend(pred_block[:, ::5])

    if not true_counts:
        # Nothing to plot; np.max below would fail on empty arrays.
        print("No validation samples available for lattice count analysis")
        return

    true_counts = np.array(true_counts, dtype=int)
    pred_counts = np.array(pred_counts, dtype=int)
    existence_probs = np.array(existence_probs)

    # Create analysis plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Count comparison histogram
    ax = axes[0, 0]
    max_count = int(max(np.max(true_counts), np.max(pred_counts))) + 1
    bins = np.arange(-0.5, max_count + 0.5, 1)

    ax.hist(true_counts, bins=bins, alpha=0.6, label='True counts', density=True)
    ax.hist(pred_counts, bins=bins, alpha=0.6, label='Predicted counts', density=True)
    ax.set_xlabel('Number of Active Lattices')
    ax.set_ylabel('Density')
    ax.set_title('Distribution of Lattice Counts')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Count accuracy scatter plot
    ax = axes[0, 1]
    ax.scatter(true_counts, pred_counts, alpha=0.6)
    ax.plot([0, max_count], [0, max_count], 'r--', label='Perfect prediction')
    ax.set_xlabel('True Count')
    ax.set_ylabel('Predicted Count')
    ax.set_title('Count Prediction Accuracy')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Existence probability distribution
    ax = axes[1, 0]
    all_existence = existence_probs.flatten()
    ax.hist(all_existence, bins=50, alpha=0.7, density=True)
    ax.axvline(0.5, color='red', linestyle='--', label='Decision threshold (0.5)')
    ax.set_xlabel('Existence Probability')
    ax.set_ylabel('Density')
    ax.set_title('Distribution of Existence Probabilities')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Count confusion matrix, built over every count 0..max so that row/column
    # index i always means "count == i" and the tick labels are truthful.
    ax = axes[1, 1]
    count_labels = np.arange(max_count)
    cm = np.zeros((max_count, max_count), dtype=int)
    for true_count, pred_count in zip(true_counts, pred_counts):
        cm[true_count, pred_count] += 1
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    ax.set_xticks(count_labels)
    ax.set_yticks(count_labels)
    ax.set_xlabel('Predicted Count')
    ax.set_ylabel('True Count')
    ax.set_title('Count Prediction Confusion Matrix')

    # Add text annotations (white text on dark cells, black on light cells)
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                   ha="center", va="center",
                   color="white" if cm[i, j] > thresh else "black")

    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

    # Calculate metrics
    count_accuracy = np.mean(true_counts == pred_counts)
    count_mae = np.mean(np.abs(true_counts - pred_counts))

    print(f"\n📊 Lattice Count Metrics:")
    print(f"  Count Accuracy: {count_accuracy:.3f}")
    print(f"  Count MAE: {count_mae:.3f}")
    print(f"  True count range: {np.min(true_counts)} - {np.max(true_counts)}")
    print(f"  Pred count range: {np.min(pred_counts)} - {np.max(pred_counts)}")
    print(f"  Existence prob range: {np.min(all_existence):.3f} - {np.max(all_existence):.3f}")

# Lattice count analysis needs a trained model; skip silently when none exists
if globals().get('model') is not None:
    lattice_count_analysis(model, data_loader, num_samples=200)

In [None]:
def enhanced_training_analysis(history):
    """
    Enhanced analysis of training dynamics with detailed plots.

    Draws a 2x3 grid (log-scale loss, MAE, validation-minus-training gap,
    learning rate, per-epoch loss improvement, rolling loss std) and prints a
    textual summary including a simple overfitting verdict.

    Args:
        history: Object whose ``history`` attribute is a dict of per-epoch
            metric lists. ``'loss'`` and ``'val_loss'`` are required;
            ``'mae'``/``'val_mae'`` and ``'lr'``/``'learning_rate'`` are used
            when present. ``None`` prints a notice and returns.

    Returns:
        None. Results are plotted and printed.

    Note:
        The rolling-std panel uses the module-level name ``pd`` (pandas), which
        must be imported before this function is called.
    """
    if history is None:
        print("❌ No training history available")
        return

    print("📈 Enhanced Training Dynamics Analysis")
    print("=" * 50)

    hist = history.history
    has_mae = 'mae' in hist and 'val_mae' in hist
    # The learning rate is logged as 'lr' by older Keras and 'learning_rate' by Keras 3.
    lr_key = next((key for key in ('lr', 'learning_rate') if key in hist), None)

    # Create comprehensive training plots
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # 1. Loss curves with log scale
    ax = axes[0, 0]
    epochs = range(1, len(hist['loss']) + 1)
    ax.semilogy(epochs, hist['loss'], 'b-', label='Training Loss', linewidth=2)
    ax.semilogy(epochs, hist['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (log scale)')
    ax.set_title('Training and Validation Loss (Log Scale)')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. MAE curves (only when the metric was tracked during training)
    ax = axes[0, 1]
    if has_mae:
        ax.plot(epochs, hist['mae'], 'b-', label='Training MAE', linewidth=2)
        ax.plot(epochs, hist['val_mae'], 'r-', label='Validation MAE', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Mean Absolute Error')
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'MAE\nNot Tracked', transform=ax.transAxes,
                ha='center', va='center', fontsize=12)
    ax.set_title('Training and Validation MAE')
    ax.grid(True, alpha=0.3)

    # 3. Overfitting analysis
    ax = axes[0, 2]
    val_gap = np.array(hist['val_loss']) - np.array(hist['loss'])
    ax.plot(epochs, val_gap, 'g-', linewidth=2)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation - Training Loss')
    ax.set_title('Overfitting Indicator')
    ax.grid(True, alpha=0.3)

    # 4. Learning rate (if available)
    ax = axes[1, 0]
    if lr_key is not None:
        ax.semilogy(epochs, hist[lr_key], 'purple', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate (log scale)')
        ax.set_title('Learning Rate Schedule')
    else:
        ax.text(0.5, 0.5, 'Learning Rate\nNot Tracked', transform=ax.transAxes,
                ha='center', va='center', fontsize=12)
        ax.set_title('Learning Rate Schedule')
    ax.grid(True, alpha=0.3)

    # 5. Loss improvement rate
    ax = axes[1, 1]
    loss_diff = -np.diff(hist['loss'])  # Negative because we want improvement (decrease)
    val_loss_diff = -np.diff(hist['val_loss'])

    ax.plot(epochs[1:], loss_diff, 'b-', label='Training Improvement', linewidth=2)
    ax.plot(epochs[1:], val_loss_diff, 'r-', label='Validation Improvement', linewidth=2)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss Improvement')
    ax.set_title('Per-Epoch Loss Improvement')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 6. Training stability (rolling metrics)
    ax = axes[1, 2]
    window_size = min(5, len(epochs) // 4)
    if window_size >= 2:
        # Rolling standard deviation of loss; the first window_size-1 entries are NaN
        train_loss_std = pd.Series(hist['loss']).rolling(window=window_size).std()
        val_loss_std = pd.Series(hist['val_loss']).rolling(window=window_size).std()

        ax.plot(epochs, train_loss_std, 'b-', label='Training Loss Std', linewidth=2)
        ax.plot(epochs, val_loss_std, 'r-', label='Validation Loss Std', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Rolling Standard Deviation')
        ax.set_title(f'Training Stability (Window: {window_size})')
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'Not enough\nepochs for\nstability analysis',
                transform=ax.transAxes, ha='center', va='center')
        ax.set_title('Training Stability')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print detailed training statistics
    best_val_epoch = np.argmin(hist['val_loss']) + 1
    print(f"\n📊 Training Summary:")
    print(f"  Total epochs: {len(epochs)}")
    print(f"  Final train loss: {hist['loss'][-1]:.6f}")
    print(f"  Final val loss: {hist['val_loss'][-1]:.6f}")
    print(f"  Best val loss: {min(hist['val_loss']):.6f} (epoch {best_val_epoch})")
    if has_mae:
        best_mae_epoch = np.argmin(hist['val_mae']) + 1
        print(f"  Final train MAE: {hist['mae'][-1]:.6f}")
        print(f"  Final val MAE: {hist['val_mae'][-1]:.6f}")
        print(f"  Best val MAE: {min(hist['val_mae']):.6f} (epoch {best_mae_epoch})")

    # Overfitting analysis: gap between final validation and training loss
    final_gap = hist['val_loss'][-1] - hist['loss'][-1]
    print(f"  Final overfitting gap: {final_gap:.6f}")

    if final_gap > 0.1:
        print("  ⚠️  Warning: Significant overfitting detected")
    elif final_gap > 0.05:
        print("  ⚡ Mild overfitting detected")
    else:
        print("  ✅ Good generalization")

# Import pandas if not already imported
try:
    import pandas as pd
except ImportError:
    print("Installing pandas for enhanced analysis...")
    !pip install pandas

# Run the enhanced training analysis if we have history
if 'history' in locals() and history is not None:
    enhanced_training_analysis(history)

In [None]:
def interactive_model_exploration(model, image_size=128):
    """
    Explore model behavior on a fixed set of hand-written lattice patterns.

    Each test case is rendered with ``hole_based_renderer``, passed through the
    model, and re-rendered from the predicted parameters. One figure row per
    case shows the original, the reconstruction, the absolute difference, and
    a textual parameter comparison; a per-case error summary is printed after.

    Args:
        model: Object exposing ``predict(images, verbose=0)`` that returns one
            flat parameter row per image.
        image_size: Forwarded to ``hole_based_renderer`` as ``image_size``.

    Returns:
        None. Results are plotted and printed.
    """
    print("🎮 Interactive Model Exploration")
    print("=" * 50)

    # Test different pattern complexities. Every "params" entry holds one row of
    # 40 values, read as groups of five per lattice.
    test_cases = [
        # Simple cases
        {"name": "Single Lattice", "params": [[np.pi/4, 20, 0, 0, 3] + [0]*35]},
        {"name": "Two Lattices", "params": [[0, 25, 0, 0, 2, np.pi/2, 30, np.pi, np.pi/2, 2.5] + [0]*30]},

        # Complex cases
        {"name": "Dense Pattern", "params": [[0, 15, 0, 0, 1.5, np.pi/3, 18, np.pi/2, 0, 2, np.pi/6, 22, np.pi, np.pi/3, 1.8] + [0]*25]},

        # Edge cases
        {"name": "Large Spacing", "params": [[0, 60, 0, 0, 8] + [0]*35]},  # Very coarse pattern
        {"name": "Small Spacing", "params": [[0, 8, 0, 0, 1] + [0]*35]},   # Very fine pattern
        {"name": "Rotated", "params": [[np.pi/6, 25, 0, 0, 3, np.pi/3, 25, np.pi/4, np.pi/4, 3] + [0]*30]},  # Strong rotation
    ]

    fig, axes = plt.subplots(len(test_cases), 4, figsize=(20, 5*len(test_cases)))
    if len(test_cases) == 1:
        axes = axes.reshape(1, -1)

    param_names = ['θ', 'sp', 'φx', 'φy', 'hr']
    summary_rows = []  # (name, image MAE, image MSE, parameter MAE) per case

    for case_idx, test_case in enumerate(test_cases):
        # Flatten the nested list to a 1-D row: strided slices such as [1::5]
        # must walk over parameters, not over the single outer row.
        true_params = np.asarray(test_case["params"], dtype=np.float32).reshape(-1)

        # Generate test image from a (1, n_params) batch, the same layout the
        # renderer receives for model predictions below.
        params_tensor = tf.constant(true_params[np.newaxis, :], dtype=tf.float32)
        test_image = np.asarray(hole_based_renderer(params_tensor, image_size=image_size))

        # Get model prediction and render it back to an image
        prediction = np.asarray(model.predict(test_image, verbose=0))
        reconstructed = np.asarray(hole_based_renderer(prediction, image_size=image_size))
        pred_params = prediction[0]

        # Plot original
        axes[case_idx, 0].imshow(test_image[0, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[case_idx, 0].set_title(f'{test_case["name"]}\nOriginal')
        axes[case_idx, 0].axis('off')

        # Plot reconstruction
        axes[case_idx, 1].imshow(reconstructed[0, :, :, 0], cmap='gray', vmin=0, vmax=1)
        axes[case_idx, 1].set_title('Reconstruction')
        axes[case_idx, 1].axis('off')

        # Plot difference
        diff = np.abs(test_image[0, :, :, 0] - reconstructed[0, :, :, 0])
        im = axes[case_idx, 2].imshow(diff, cmap='Reds')
        axes[case_idx, 2].set_title(f'Difference\nMAE: {np.mean(diff):.4f}')
        axes[case_idx, 2].axis('off')
        plt.colorbar(im, ax=axes[case_idx, 2], fraction=0.046, pad=0.04)

        # Count active lattices (spacing, index 1 of each group, above 0.5)
        true_active = int(np.sum(true_params[1::5] > 0.5))
        pred_active = int(np.sum(pred_params[1::5] > 0.5))

        # Show parameter comparison for first few lattices
        axes[case_idx, 3].axis('off')
        param_text = f"Active Lattices:\nTrue: {true_active}, Pred: {pred_active}\n\n"

        for lat_idx in range(min(3, true_active)):  # Show first 3 active lattices
            param_text += f"Lattice {lat_idx+1}:\n"
            for p_idx, p_name in enumerate(param_names):
                true_val = true_params[lat_idx*5 + p_idx]
                pred_val = pred_params[lat_idx*5 + p_idx]
                param_text += f"  {p_name}: {true_val:.2f} → {pred_val:.2f}\n"
            param_text += "\n"

        axes[case_idx, 3].text(0.05, 0.95, param_text, transform=axes[case_idx, 3].transAxes,
                              verticalalignment='top', fontfamily='monospace', fontsize=8)
        axes[case_idx, 3].set_title('Parameter Comparison')

        # Record the summary metrics now so each case is rendered/predicted once
        summary_rows.append((
            test_case['name'],
            np.mean(np.abs(test_image - reconstructed)),
            np.mean((test_image - reconstructed)**2),
            np.mean(np.abs(true_params - pred_params)),
        ))

    plt.tight_layout()
    plt.show()

    # Statistical summary of test cases
    print("\n📊 Test Case Performance Summary:")
    for name, mae, mse, param_mae in summary_rows:
        print(f"  {name:15s}: Img MAE={mae:.4f}, Img MSE={mse:.6f}, Param MAE={param_mae:.4f}")

# Explore the hand-written test patterns only when a trained model is present
if globals().get('model') is not None:
    interactive_model_exploration(model, image_size=128)

In [None]:
def robustness_analysis(model, data_loader, image_size=128):
    """
    Test model robustness to noise and image degradation.

    Takes up to five images from the first validation batch, degrades each with
    Gaussian noise, salt-and-pepper noise and Gaussian blur at several levels,
    and measures how parameter error and reconstruction error (both against the
    clean originals) change. Plots the curves, a montage of example
    degradations and a text summary.

    Args:
        model: Object exposing ``predict(images, verbose=0)`` that returns one
            flat parameter row per image.
        data_loader: Object exposing
            ``create_tf_dataset(batch_size=..., validation_split=...)``; element
            ``[1]`` of the returned value is used as the validation set and
            must yield ``(images, params)`` batches.
        image_size: Forwarded to ``hole_based_renderer`` as ``image_size``.

    Returns:
        dict: ``{noise_name: {'levels', 'param_maes', 'recon_maes',
        'param_stds', 'recon_stds'}}`` with one list entry per noise level.
    """
    print("🛡️ Model Robustness and Noise Sensitivity Analysis")
    print("=" * 50)

    # Get some test images
    val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)[1]
    test_images, test_params = next(iter(val_dataset))

    # Select a few representative images. They are converted to NumPy up front
    # because the noise helpers need array methods and in-place assignment,
    # which tensor slices do not provide.
    num_test = min(5, int(test_images.shape[0]))
    original_images = np.asarray(test_images[:num_test]).astype(np.float32)
    original_params = np.asarray(test_params[:num_test])

    # Helper functions for noise
    def add_gaussian_noise(image, noise_level):
        return image + np.random.normal(0, noise_level, image.shape)

    def add_salt_pepper_noise(image, noise_level):
        noisy = np.array(image, copy=True)
        # Half of the corrupted pixels become salt (1), the other half pepper (0)
        num_each = int(noise_level * noisy.size * 0.5)
        for fill_value in (1, 0):
            # randint's upper bound is exclusive, so `dim` covers every valid
            # index and stays legal for size-1 axes (batch, channel).
            coords = tuple(np.random.randint(0, dim, num_each) for dim in noisy.shape)
            noisy[coords] = fill_value
        return np.clip(noisy, 0, 1)

    def apply_gaussian_blur(image, blur_level):
        from scipy.ndimage import gaussian_filter
        blurred = np.zeros_like(image)
        for i in range(image.shape[0]):
            blurred[i, :, :, 0] = gaussian_filter(image[i, :, :, 0], sigma=blur_level*2)
        return blurred

    def degradation_ratio(errors):
        # Error at the strongest level relative to the clean baseline; the
        # floor avoids inf/nan when the baseline error is exactly zero.
        return errors[-1] / max(errors[0], np.finfo(float).eps)

    # Define noise levels and types
    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.3]
    noise_types = [
        ("Gaussian", add_gaussian_noise),
        ("Salt & Pepper", add_salt_pepper_noise),
        ("Blur", apply_gaussian_blur),
    ]

    # Test robustness
    results = {}

    for noise_name, noise_func in noise_types:
        print(f"\n Testing {noise_name} noise...")
        results[noise_name] = {
            'levels': noise_levels,
            'param_maes': [],
            'recon_maes': [],
            'param_stds': [],
            'recon_stds': []
        }

        for noise_level in noise_levels:
            param_errors = []
            recon_errors = []

            for test_idx in range(num_test):
                clean_image = original_images[test_idx:test_idx+1]

                # Add noise to image (level 0 is the untouched baseline)
                if noise_level == 0:
                    noisy_image = clean_image
                else:
                    noisy_image = noise_func(clean_image, noise_level)
                    noisy_image = np.clip(noisy_image, 0, 1).astype(np.float32)

                # Get prediction
                pred_params = np.asarray(model.predict(noisy_image, verbose=0))

                # Render reconstruction
                reconstructed = np.asarray(hole_based_renderer(pred_params, image_size=image_size))

                # Calculate errors against the clean image and its true parameters
                param_error = np.mean(np.abs(original_params[test_idx] - pred_params[0]))
                recon_error = np.mean(np.abs(original_images[test_idx] - reconstructed[0]))

                param_errors.append(param_error)
                recon_errors.append(recon_error)

            # Store statistics
            results[noise_name]['param_maes'].append(np.mean(param_errors))
            results[noise_name]['recon_maes'].append(np.mean(recon_errors))
            results[noise_name]['param_stds'].append(np.std(param_errors))
            results[noise_name]['recon_stds'].append(np.std(recon_errors))

    # Plot results
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Parameter error vs noise level
    ax = axes[0, 0]
    for noise_name in results:
        ax.errorbar(results[noise_name]['levels'], results[noise_name]['param_maes'],
                   yerr=results[noise_name]['param_stds'], label=noise_name,
                   marker='o', linewidth=2, capsize=5)
    ax.set_xlabel('Noise Level')
    ax.set_ylabel('Parameter MAE')
    ax.set_title('Parameter Prediction Robustness')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Reconstruction error vs noise level
    ax = axes[0, 1]
    for noise_name in results:
        ax.errorbar(results[noise_name]['levels'], results[noise_name]['recon_maes'],
                   yerr=results[noise_name]['recon_stds'], label=noise_name,
                   marker='s', linewidth=2, capsize=5)
    ax.set_xlabel('Noise Level')
    ax.set_ylabel('Reconstruction MAE')
    ax.set_title('Image Reconstruction Robustness')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Example noisy images
    ax = axes[1, 0]
    example_img = original_images[0:1]

    # Collect noisy versions ordered [G, S&P, G, S&P, ...] by increasing level
    noise_examples = []
    for level in [0.0, 0.1, 0.2]:
        for noise_name, noise_func in noise_types[:2]:  # Just Gaussian and S&P
            if level == 0:
                noisy = example_img
            else:
                noisy = noise_func(example_img, level)
                noisy = np.clip(noisy, 0, 1)
            noise_examples.append(noisy[0, :, :, 0])

    # Create montage: even entries form the left column, odd entries the right
    montage = np.hstack([np.vstack(noise_examples[i::2]) for i in range(2)])
    ax.imshow(montage, cmap='gray')
    ax.set_title('Example Noise Effects\n(Left: Gaussian, Right: Salt&Pepper)')
    ax.axis('off')

    # Robustness summary
    ax = axes[1, 1]
    ax.axis('off')

    # Calculate robustness scores (error growth from clean to strongest level)
    summary_text = "🛡️ ROBUSTNESS SUMMARY\n\n"

    param_degradations = []
    for noise_name in results:
        param_degradation = degradation_ratio(results[noise_name]['param_maes'])
        recon_degradation = degradation_ratio(results[noise_name]['recon_maes'])
        param_degradations.append(param_degradation)

        summary_text += f"{noise_name}:\n"
        summary_text += f"  Param degradation: {param_degradation:.2f}x\n"
        summary_text += f"  Recon degradation: {recon_degradation:.2f}x\n\n"

    # Overall robustness assessment
    avg_param_degradation = np.mean(param_degradations)

    if avg_param_degradation < 2.0:
        summary_text += "✅ EXCELLENT robustness"
    elif avg_param_degradation < 3.0:
        summary_text += "⚡ GOOD robustness"
    elif avg_param_degradation < 5.0:
        summary_text += "⚠️  MODERATE robustness"
    else:
        summary_text += "❌ POOR robustness"

    ax.text(0.05, 0.95, summary_text, transform=ax.transAxes,
            verticalalignment='top', fontfamily='monospace', fontsize=10)

    plt.tight_layout()
    plt.show()

    return results

# Robustness sweep is skipped unless a trained model is available in this session
if globals().get('model') is not None:
    robustness_results = robustness_analysis(model, data_loader, image_size=128)

In [None]:
def comprehensive_evaluation_summary(model, history, data_loader, num_eval=20, image_size=128):
    """
    Generate a comprehensive summary of all model evaluations.

    Evaluates the model on up to ``num_eval`` samples from the first validation
    batch, then prints architecture size, training statistics (when ``history``
    is given), accuracy metrics, letter grades and recommendations.

    Args:
        model: Trained model exposing ``predict``, ``count_params`` and
            ``trainable_weights``. ``None`` prints a notice and returns.
        history: Object whose ``history`` attribute is a dict with ``'loss'``
            and ``'val_loss'`` lists, or ``None`` to skip training statistics.
        data_loader: Object exposing
            ``create_tf_dataset(batch_size=..., validation_split=...)``; element
            ``[1]`` of the returned value is used as the validation set.
            ``None`` prints a notice and returns.
        num_eval: Maximum number of validation samples to evaluate; clamped to
            the size of the first batch.
        image_size: Forwarded to ``hole_based_renderer`` as ``image_size``.

    Returns:
        None. Results are printed.
    """
    print("🎯 COMPREHENSIVE MODEL EVALUATION SUMMARY")
    print("=" * 60)

    if model is None:
        print("❌ No trained model available for evaluation")
        return

    if data_loader is None:
        print("❌ No data loader available for evaluation")
        return

    def grade_when_lower_is_better(value, cutoffs):
        # cutoffs are the exclusive upper bounds for A+, A, B, C; anything else is D
        for cutoff, letter in zip(cutoffs, ('A+', 'A', 'B', 'C')):
            if value < cutoff:
                return letter
        return 'D'

    def grade_when_higher_is_better(value, cutoffs):
        # cutoffs are the exclusive lower bounds for A+, A, B, C; anything else is D
        for cutoff, letter in zip(cutoffs, ('A+', 'A', 'B', 'C')):
            if value > cutoff:
                return letter
        return 'D'

    # Quick performance test
    val_dataset = data_loader.create_tf_dataset(batch_size=32, validation_split=0.2)[1]
    val_images, val_params = next(iter(val_dataset))

    # Never index past the batch: it may hold fewer than num_eval samples
    n_eval = min(num_eval, int(val_images.shape[0]))
    eval_images = val_images[:n_eval]
    true_params = np.asarray(val_params[:n_eval])

    # Generate predictions for summary
    predictions = np.asarray(model.predict(eval_images, verbose=0))
    reconstructed = np.asarray(hole_based_renderer(predictions, image_size=image_size))

    # Calculate key metrics
    param_mae = np.mean(np.abs(true_params - predictions))
    recon_mae = np.mean(np.abs(np.asarray(eval_images) - reconstructed))

    # Count accuracy: a lattice is active when its spacing (index 1 of each
    # group of five) exceeds 0.5
    true_counts = np.sum(true_params[:, 1::5] > 0.5, axis=1)
    pred_counts = np.sum(predictions[:, 1::5] > 0.5, axis=1)
    count_accuracy = np.mean(true_counts == pred_counts)

    # Model complexity
    total_params = model.count_params()
    trainable_params = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)

    print(f"\n🏗️  MODEL ARCHITECTURE:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    # Size estimate assumes 4 bytes per parameter
    print(f"   Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB")

    if history is not None:
        print(f"\n📈 TRAINING PERFORMANCE:")
        print(f"   Training epochs: {len(history.history['loss'])}")
        print(f"   Final training loss: {history.history['loss'][-1]:.6f}")
        print(f"   Final validation loss: {history.history['val_loss'][-1]:.6f}")
        print(f"   Best validation loss: {min(history.history['val_loss']):.6f}")
        print(f"   Overfitting gap: {history.history['val_loss'][-1] - history.history['loss'][-1]:.6f}")

    print(f"\n🎯 PREDICTION ACCURACY:")
    print(f"   Parameter MAE: {param_mae:.6f}")
    print(f"   Reconstruction MAE: {recon_mae:.6f}")
    print(f"   Lattice count accuracy: {count_accuracy:.3f}")

    # Performance grades
    grades = {
        'params': grade_when_lower_is_better(param_mae, (0.1, 0.2, 0.3, 0.5)),
        'recon': grade_when_lower_is_better(recon_mae, (0.05, 0.1, 0.15, 0.25)),
        'count': grade_when_higher_is_better(count_accuracy, (0.9, 0.8, 0.7, 0.6)),
    }

    print(f"\n🏆 PERFORMANCE GRADES:")
    print(f"   Parameter Prediction: {grades['params']}")
    print(f"   Image Reconstruction: {grades['recon']}")
    print(f"   Lattice Count: {grades['count']}")

    # Overall grade: average the grade points of the three categories
    grade_scores = {'A+': 4.3, 'A': 4.0, 'B': 3.0, 'C': 2.0, 'D': 1.0}
    avg_score = np.mean([grade_scores[g] for g in grades.values()])

    if avg_score >= 4.0:
        overall = 'A+'
        emoji = '🌟'
    elif avg_score >= 3.5:
        overall = 'A'
        emoji = '⭐'
    elif avg_score >= 2.5:
        overall = 'B'
        emoji = '👍'
    elif avg_score >= 2.0:
        overall = 'C'
        emoji = '⚡'
    else:
        overall = 'D'
        emoji = '⚠️'

    print(f"\n{emoji} OVERALL GRADE: {overall}")

    # Recommendations
    print(f"\n💡 RECOMMENDATIONS:")

    if param_mae > 0.3:
        print("   • Consider longer training or larger model for better parameter accuracy")

    if recon_mae > 0.15:
        print("   • Image reconstruction could be improved with more training data")

    if count_accuracy < 0.8:
        print("   • Lattice count prediction needs improvement - consider adjusting architecture")

    if history and len(history.history['loss']) < 20:
        print("   • Training may have been too short - consider more epochs")

    if history and (history.history['val_loss'][-1] - history.history['loss'][-1]) > 0.1:
        print("   • Model is overfitting - consider regularization or more data")

    if avg_score >= 3.5:
        print("   ✅ Model is performing well! Ready for deployment.")

    print("\n" + "=" * 60)
    print("🎉 EVALUATION COMPLETE!")
    print("=" * 60)

# Produce the final summary whenever a `model` name exists; a missing
# `history` or `data_loader` is handed over as None.
if 'model' in globals():
    comprehensive_evaluation_summary(
        model,
        globals().get('history'),
        globals().get('data_loader'),
    )