# Moir√© Lattice Reconstruction - Environment Setup

This notebook sets up the environment and tests the ML framework for reconstructing grayscale images as multiple moir√© lattices.

## Project Overview
- **Input**: Grayscale images
- **Output**: Array of Bravais lattices with parameters (base vectors, hole size)
- **Approach**: Both algorithmic (FFT-based) and learned (CNN) methods

In [None]:
# Test the ML environment
import subprocess
import sys

# Run the environment test script
result = subprocess.run([sys.executable, 'test_ml_environment.py'], 
                       capture_output=True, text=True)

print(result.stdout)
if result.stderr:
    print("STDERR:")
    print(result.stderr)

## Basic Imports and Setup

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import fft
import cv2

print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")

# Check GPU availability
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")
for gpu in gpus:
    print(f"  {gpu}")

## 1. Algorithmic Approach: FFT-based Peak Extraction

This implements the baseline approach described in the instructions:
1. Compute 2D FFT of target image
2. Find magnitude peaks in frequency domain
3. Map peaks to lattice parameters
4. Render synthetic image
5. Iterative refinement

In [None]:
def create_test_moire_image(size=64, num_lattices=20):
    """Create a synthetic test image with multiple moir√© patterns"""
    x = np.linspace(0, 2*np.pi, size)
    y = np.linspace(0, 2*np.pi, size)
    X, Y = np.meshgrid(x, y)
    
    image = np.zeros((size, size))
    
    # Add multiple lattice patterns
    angles = np.linspace(0, np.pi, num_lattices, endpoint=False)
    frequencies = [2, 3, 4]
    
    for i, (angle, freq) in enumerate(zip(angles, frequencies)):
        # Rotate coordinate system
        X_rot = X * np.cos(angle) - Y * np.sin(angle)
        Y_rot = X * np.sin(angle) + Y * np.cos(angle)
        
        # Add lattice pattern
        amplitude = 0.3 / (i + 1)  # Decreasing amplitude
        pattern = amplitude * np.sin(freq * X_rot) * np.cos(freq * Y_rot)
        image += pattern
    
    # Normalize to [0, 1]
    image = (image - image.min()) / (image.max() - image.min())
    return image

# Create test image
test_image = create_test_moire_image()

# Display
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(test_image, cmap='gray')
plt.title('Test Moir√© Image')
plt.colorbar()

# Compute and display FFT
fft_result = np.fft.fft2(test_image)
fft_magnitude = np.abs(np.fft.fftshift(fft_result))

plt.subplot(1, 2, 2)
plt.imshow(np.log(fft_magnitude + 1), cmap='viridis')
plt.title('FFT Magnitude (log scale)')
plt.colorbar()

plt.tight_layout()
plt.show()

print(f"Test image shape: {test_image.shape}")
print(f"FFT result shape: {fft_result.shape}")

## 2. Neural Network Approach: CNN Encoder

This implements the learned approach:
- CNN encoder: grayscale image ‚Üí feature vector
- Regressor head: outputs lattice parameters (Œ∏, |k|, œÜ, A)
- Differentiable renderer for training

In [None]:
def create_lattice_cnn(input_shape=(64, 64, 1), num_lattices=8):
    """Create CNN for image-to-lattice parameter regression"""
    
    inputs = tf.keras.Input(shape=input_shape)
    
    # CNN Encoder
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    
    x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    
    # Dense layers
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    # Output: num_lattices √ó 4 parameters (theta, |k|, phi, amplitude)
    outputs = tf.keras.layers.Dense(num_lattices * 4, activation='linear')(x)
    
    model = tf.keras.Model(inputs, outputs)
    return model

# Create model
model = create_lattice_cnn()
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Model summary
model.summary()

# Test with dummy data
dummy_input = np.random.random((1, 64, 64, 1))
output = model.predict(dummy_input, verbose=0)

print(f"\nModel test:")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Output represents: {output.shape[1]//4} lattices √ó 4 parameters each")

## 3. Differentiable Renderer

This implements a differentiable renderer that can synthesize images from lattice parameters, enabling end-to-end training.

In [None]:
@tf.function
def differentiable_renderer(lattice_params, image_size=64):
    """
    Render an image from lattice parameters
    
    Args:
        lattice_params: tensor of shape (batch_size, num_lattices * 4)
                       where each lattice has (theta, |k|, phi, amplitude)
        image_size: size of output image
    
    Returns:
        rendered images of shape (batch_size, image_size, image_size, 1)
    """
    batch_size = tf.shape(lattice_params)[0]
    num_lattices = tf.shape(lattice_params)[1] // 4
    
    # Reshape parameters
    params = tf.reshape(lattice_params, (batch_size, num_lattices, 4))
    
    # Create coordinate grid
    x = tf.linspace(0., 2*np.pi, image_size)
    y = tf.linspace(0., 2*np.pi, image_size)
    X, Y = tf.meshgrid(x, y)
    coords = tf.stack([X, Y], axis=-1)  # (image_size, image_size, 2)
    
    # Initialize output
    images = tf.zeros((batch_size, image_size, image_size, 1))
    
    # Add each lattice contribution
    for i in range(num_lattices):
        theta = params[:, i, 0]  # (batch_size,)
        k_mag = params[:, i, 1]  # (batch_size,)
        phi = params[:, i, 2]    # (batch_size,)
        amplitude = params[:, i, 3]  # (batch_size,)
        
        # Compute k vector
        kx = k_mag * tf.cos(theta)  # (batch_size,)
        ky = k_mag * tf.sin(theta)  # (batch_size,)
        
        # Broadcast for computation
        kx = kx[:, None, None, None]  # (batch_size, 1, 1, 1)
        ky = ky[:, None, None, None]  # (batch_size, 1, 1, 1)
        phi = phi[:, None, None, None]  # (batch_size, 1, 1, 1)
        amplitude = amplitude[:, None, None, None]  # (batch_size, 1, 1, 1)
        
        # Compute phase
        phase = kx * X[None, :, :, None] + ky * Y[None, :, :, None] + phi
        
        # Add lattice contribution
        contribution = amplitude * tf.cos(phase)
        images += contribution
    
    return images

# Test the renderer
test_params = tf.constant([[1.0, 2.0, 0.0, 0.5,  # First lattice: theta=1, |k|=2, phi=0, amp=0.5
                           0.5, 3.0, np.pi/4, 0.3]])  # Second lattice

rendered = differentiable_renderer(test_params, image_size=64)

plt.figure(figsize=(6, 5))
plt.imshow(rendered[0, :, :, 0], cmap='gray')
plt.title('Rendered Image from Lattice Parameters')
plt.colorbar()
plt.show()

print(f"Rendered image shape: {rendered.shape}")
print(f"Parameters used: {test_params.numpy()}")

## Next Steps

1. **Generate Training Data**: Create synthetic datasets with known lattice parameters
2. **Train the Model**: Use the differentiable renderer for end-to-end training
3. **Implement FFT Baseline**: For comparison and initialization
4. **Evaluation**: Compare reconstruction quality and parameter recovery accuracy
5. **Real Image Testing**: Test on actual grayscale images

The environment is now ready for moir√© lattice reconstruction research!

In [None]:
import h5py
import os
from datetime import datetime

def lattice_points_in_region(a1, a2, region_size):
    """
    Generate lattice points within a square region
    
    Args:
        a1, a2: base vectors (2D)
        region_size: size of the square region
    
    Returns:
        array of lattice points (N, 2)
    """
    # Determine how many lattice cells we need in each direction
    max_n = int(np.ceil(region_size / min(np.linalg.norm(a1), np.linalg.norm(a2)))) + 2
    
    points = []
    for n in range(-max_n, max_n + 1):
        for m in range(-max_n, max_n + 1):
            point = n * a1 + m * a2
            # Keep points within the region (with some margin)
            if (-0.1 <= point[0] <= region_size + 0.1 and 
                -0.1 <= point[1] <= region_size + 0.1):
                points.append(point)
    
    return np.array(points)

@tf.function
def hole_based_renderer(lattice_params, image_size=64):
    """
    Render an image with holes at lattice points (numerically stable version)
    
    Args:
        lattice_params: tensor of shape (batch_size, num_lattices * 5)
                       where each lattice has (theta, spacing, phase_x, phase_y, hole_radius)
        image_size: size of output image
    
    Returns:
        rendered images of shape (batch_size, image_size, image_size, 1)
    """
    batch_size = tf.shape(lattice_params)[0]
    num_lattices = tf.shape(lattice_params)[1] // 5
    
    # Reshape parameters
    params = tf.reshape(lattice_params, (batch_size, num_lattices, 5))
    
    # Create coordinate grid
    x = tf.linspace(0., float(image_size), image_size)
    y = tf.linspace(0., float(image_size), image_size)
    X, Y = tf.meshgrid(x, y)
    
    # Initialize output - start with white background
    images = tf.ones((batch_size, image_size, image_size, 1))
    
    # Process each lattice
    for i in range(num_lattices):
        theta = params[:, i, 0]      # rotation angle
        spacing = params[:, i, 1]    # lattice spacing
        phase_x = params[:, i, 2]    # phase offset in x
        phase_y = params[:, i, 3]    # phase offset in y
        hole_radius = params[:, i, 4]  # hole radius
        
        # Skip if spacing is 0 or too small (empty lattice slot)
        mask = tf.cast(spacing > 0.1, tf.float32)[:, None, None, None]
        
        # Clamp spacing to avoid division issues
        spacing = tf.maximum(spacing, 0.1)
        
        # Create lattice pattern using cosine waves
        cos_theta = tf.cos(theta)[:, None, None, None]
        sin_theta = tf.sin(theta)[:, None, None, None]
        spacing_expanded = spacing[:, None, None, None]
        
        # Rotated coordinates
        X_rot = X[None, :, :, None] * cos_theta + Y[None, :, :, None] * sin_theta
        Y_rot = -X[None, :, :, None] * sin_theta + Y[None, :, :, None] * cos_theta
        
        # Create lattice pattern with clamped division
        phase_x_expanded = phase_x[:, None, None, None]
        phase_y_expanded = phase_y[:, None, None, None]
        
        # Add small epsilon to avoid division by zero
        safe_spacing = tf.maximum(spacing_expanded, 1e-6)
        
        pattern_x = tf.cos(2 * np.pi * X_rot / safe_spacing + phase_x_expanded)
        pattern_y = tf.cos(2 * np.pi * Y_rot / safe_spacing + phase_y_expanded)
        
        # Combine patterns - high values at lattice points
        lattice_indicator = (pattern_x + 1) * (pattern_y + 1) / 4
        
        # Create holes: dark spots where lattice_indicator is high
        # Use tanh instead of sigmoid for better numerical stability
        hole_pattern = 1.0 - 0.5 * (tf.nn.tanh(10.0 * (lattice_indicator - 0.9)) + 1.0)
        
        # Ensure hole_pattern is in valid range
        hole_pattern = tf.clip_by_value(hole_pattern, 0.0, 1.0)
        
        # Apply mask (only for non-zero lattices)
        images = images * (1 - mask) + images * hole_pattern * mask
    
    # Final clipping to ensure valid range
    images = tf.clip_by_value(images, 0.0, 1.0)
    
    # Check for NaN and replace with default value if any
    images = tf.where(tf.math.is_nan(images), tf.ones_like(images) * 0.5, images)
    
    return images

# Test the hole-based renderer
print("Testing hole-based renderer...")
test_params_holes = tf.constant([[
    0.0, 8.0, 0.0, 0.0, 2.0,      # Lattice 1: theta=0, spacing=8, no phase, hole_radius=2
    np.pi/4, 10.0, 0.0, 0.0, 1.5   # Lattice 2: theta=45¬∞, spacing=10, hole_radius=1.5
]], dtype=tf.float32)

rendered_holes = hole_based_renderer(test_params_holes, image_size=64)

# Debug information
print(f"Rendered image shape: {rendered_holes.shape}")
print(f"Rendered image range: [{tf.reduce_min(rendered_holes):.3f}, {tf.reduce_max(rendered_holes):.3f}]")
print(f"Rendered image mean: {tf.reduce_mean(rendered_holes):.3f}")
print(f"Number of NaN values: {tf.reduce_sum(tf.cast(tf.math.is_nan(rendered_holes), tf.int32))}")

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(rendered_holes[0, :, :, 0], cmap='gray', vmin=0, vmax=1)
plt.title('Hole-based Renderer Output')
plt.colorbar()

# Compare with original continuous renderer for reference
test_params_continuous = tf.constant([[0.0, 1.0, 0.0, 0.5,  # Convert to old format
                                     np.pi/4, 0.8, 0.0, 0.3]])
rendered_continuous = differentiable_renderer(test_params_continuous, image_size=64)

plt.subplot(1, 3, 2)
plt.imshow(rendered_continuous[0, :, :, 0], cmap='gray')
plt.title('Original Continuous Renderer')
plt.colorbar()

# Create a simple test pattern to verify renderer works
simple_test = tf.constant([[0.0, 16.0, 0.0, 0.0, 4.0,  # Simple grid
                           0.0, 0.0, 0.0, 0.0, 0.0]], dtype=tf.float32)  # Empty lattice
simple_rendered = hole_based_renderer(simple_test, image_size=64)

plt.subplot(1, 3, 3)
plt.imshow(simple_rendered[0, :, :, 0], cmap='gray', vmin=0, vmax=1)
plt.title('Simple Test Pattern')
plt.colorbar()

plt.tight_layout()
plt.show()

print(f"Hole-based renderer output range: [{tf.reduce_min(rendered_holes):.3f}, {tf.reduce_max(rendered_holes):.3f}]")

In [None]:
## 5. Data Loading and Model Training Infrastructure

class MoireDataLoader:
    """Data loader for moir√© lattice training data"""
    
    def __init__(self, hdf5_path):
        self.hdf5_path = hdf5_path
        self.file = None
        self._load_metadata()
    
    def _load_metadata(self):
        """Load dataset metadata"""
        with h5py.File(self.hdf5_path, 'r') as f:
            self.image_size = f.attrs['image_size']
            self.max_lattices = f.attrs['max_lattices']
            self.num_samples = f.attrs['num_samples']
            self.parameter_format = f.attrs['parameter_format']
            print(f"Dataset: {self.num_samples} samples, {self.image_size}x{self.image_size}, max {self.max_lattices} lattices")
    
    def create_tf_dataset(self, batch_size=32, shuffle=True, validation_split=0.2):
        """Create TensorFlow datasets for training and validation"""
        
        def data_generator():
            with h5py.File(self.hdf5_path, 'r') as f:
                indices = np.arange(self.num_samples)
                if shuffle:
                    np.random.shuffle(indices)
                
                for idx in indices:
                    image = f['images'][idx]
                    params = f['parameters'][idx]
                    yield image, params
        
        # Create dataset
        dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(self.image_size, self.image_size, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(self.max_lattices * 5,), dtype=tf.float32)
            )
        )
        
        # Split into train/validation
        val_size = int(self.num_samples * validation_split)
        train_size = self.num_samples - val_size
        
        train_dataset = dataset.take(train_size).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        val_dataset = dataset.skip(train_size).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        
        return train_dataset, val_dataset

def create_improved_lattice_cnn(input_shape=(64, 64, 1), max_lattices=8, param_per_lattice=5):
    """
    Create an improved CNN for image-to-lattice parameter regression
    
    Args:
        input_shape: input image shape
        max_lattices: maximum number of lattices to predict
        param_per_lattice: parameters per lattice (theta, spacing, phase_x, phase_y, hole_radius)
    """
    inputs = tf.keras.Input(shape=input_shape)
    
    # CNN Encoder with residual connections
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x1 = tf.keras.layers.MaxPooling2D((2, 2))(x)  # 32x32
    
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x1)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x2 = tf.keras.layers.MaxPooling2D((2, 2))(x)  # 16x16
    
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x2)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x3 = tf.keras.layers.MaxPooling2D((2, 2))(x)  # 8x8
    
    x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x3)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    
    # Dense layers with attention-like mechanism
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    
    # Lattice-specific heads (one for each lattice)
    lattice_outputs = []
    for i in range(max_lattices):
        # Each lattice gets its own small MLP
        lattice_features = tf.keras.layers.Dense(64, activation='relu', name=f'lattice_{i}_features')(x)
        lattice_params = tf.keras.layers.Dense(param_per_lattice, activation='linear', 
                                             name=f'lattice_{i}_params')(lattice_features)
        lattice_outputs.append(lattice_params)
    
    # Concatenate all lattice parameters
    outputs = tf.keras.layers.Concatenate()(lattice_outputs)
    
    model = tf.keras.Model(inputs, outputs)
    return model

# Custom loss function that handles variable number of lattices
def moire_reconstruction_loss(y_true, y_pred, image_size=64, alpha=1.0, beta=0.1):
    """
    Custom loss combining parameter MSE and reconstruction error
    
    Args:
        y_true: true parameters (batch_size, max_lattices * 5)
        y_pred: predicted parameters (batch_size, max_lattices * 5)
        alpha: weight for reconstruction loss
        beta: weight for parameter regularization
    """
    # Parameter MSE loss
    param_loss = tf.reduce_mean(tf.square(y_true - y_pred))
    
    # Reconstruction loss (render images and compare)
    # Note: This makes training slower but more accurate
    pred_images = hole_based_renderer(y_pred, image_size=image_size)
    true_images = hole_based_renderer(y_true, image_size=image_size)
    reconstruction_loss = tf.reduce_mean(tf.square(pred_images - true_images))
    
    # Regularization: encourage sparse lattices (many zero parameters)
    regularization_loss = tf.reduce_mean(tf.abs(y_pred))
    
    return param_loss + alpha * reconstruction_loss + beta * regularization_loss

# Create the improved model
improved_model = create_improved_lattice_cnn(max_lattices=8)

# Custom optimizer with learning rate scheduling
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=1000,
    decay_rate=0.96,
    staircase=True
)

improved_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=lambda y_true, y_pred: moire_reconstruction_loss(y_true, y_pred, image_size=64),
    metrics=['mae']
)

improved_model.summary()

print("\nImproved model created with:")
print("- Residual connections and batch normalization")
print("- Lattice-specific output heads")
print("- Custom reconstruction loss")
print("- Learning rate scheduling")

In [None]:
## 6. Training Pipeline and Evaluation

def train_moire_model(model, train_dataset, val_dataset, epochs=50, save_path='best_moire_model.h5'):
    """Train the moir√© lattice reconstruction model"""
    
    # Callbacks
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            save_path, 
            save_best_only=True, 
            monitor='val_loss',
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        )
    ]
    
    # Train model
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )
    
    return history

def evaluate_reconstruction(model, test_images, test_params, num_samples=5):
    """Evaluate model reconstruction quality"""
    
    # Predict parameters
    pred_params = model.predict(test_images[:num_samples])
    
    # Render images from predicted and true parameters
    pred_images = hole_based_renderer(pred_params, image_size=test_images.shape[1])
    true_images = hole_based_renderer(test_params[:num_samples], image_size=test_images.shape[1])
    
    # Calculate metrics
    mse_reconstruction = tf.reduce_mean(tf.square(pred_images - test_images[:num_samples]))
    mse_parameters = tf.reduce_mean(tf.square(pred_params - test_params[:num_samples]))
    
    # Visualize results
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    
    for i in range(num_samples):
        # Original image
        axes[i, 0].imshow(test_images[i, :, :, 0], cmap='gray')
        axes[i, 0].set_title(f'Original Image {i+1}')
        axes[i, 0].axis('off')
        
        # Predicted reconstruction
        axes[i, 1].imshow(pred_images[i, :, :, 0], cmap='gray')
        axes[i, 1].set_title(f'Predicted Reconstruction')
        axes[i, 1].axis('off')
        
        # True reconstruction (from true parameters)
        axes[i, 2].imshow(true_images[i, :, :, 0], cmap='gray')
        axes[i, 2].set_title(f'True Reconstruction')
        axes[i, 2].axis('off')
        
        # Difference
        diff = np.abs(pred_images[i, :, :, 0] - test_images[i, :, :, 0])
        axes[i, 3].imshow(diff, cmap='hot')
        axes[i, 3].set_title(f'Abs Difference')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Reconstruction MSE: {mse_reconstruction:.6f}")
    print(f"Parameter MSE: {mse_parameters:.6f}")
    
    return mse_reconstruction, mse_parameters

## 7. FFT Baseline Implementation

def fft_peak_extraction(image, num_peaks=8, min_distance=3):
    """
    Extract peaks from 2D FFT for lattice parameter estimation
    
    Args:
        image: input grayscale image
        num_peaks: maximum number of peaks to extract
        min_distance: minimum distance between peaks
    
    Returns:
        peak_params: array of lattice parameters (theta, |k|, phi, amplitude)
    """
    # Compute 2D FFT
    fft_result = np.fft.fft2(image)
    fft_shifted = np.fft.fftshift(fft_result)
    magnitude = np.abs(fft_shifted)
    phase = np.angle(fft_shifted)
    
    # Find peaks in magnitude spectrum (exclude DC component)
    center = magnitude.shape[0] // 2
    magnitude_copy = magnitude.copy()
    magnitude_copy[center-2:center+3, center-2:center+3] = 0  # Remove DC and nearby
    
    # Simple peak detection using local maxima
    from scipy.ndimage import maximum_filter
    
    # Apply maximum filter to find local maxima
    neighborhood_size = max(3, min_distance)
    local_maxima = maximum_filter(magnitude_copy, size=neighborhood_size)
    
    # Find positions where original equals local maximum (these are peaks)
    peak_mask = (magnitude_copy == local_maxima) & (magnitude_copy > np.max(magnitude_copy) * 0.1)
    peak_coords = np.where(peak_mask)
    
    if len(peak_coords[0]) == 0:
        # Fallback: find top values manually
        flat_indices = np.argsort(magnitude_copy.flatten())[::-1]
        coords = np.unravel_index(flat_indices[:num_peaks*2], magnitude_copy.shape)
        peak_coords = (coords[0], coords[1])
    
    # Get peak coordinates as (y, x) pairs
    peaks = list(zip(peak_coords[0], peak_coords[1]))
    
    # Sort by magnitude and take top peaks
    peak_magnitudes = [magnitude_copy[y, x] for y, x in peaks]
    sorted_peaks = sorted(zip(peaks, peak_magnitudes), key=lambda x: x[1], reverse=True)
    
    # Extract top peaks
    top_peaks = [peak[0] for peak in sorted_peaks[:num_peaks]]
    
    lattice_params = []
    for peak in top_peaks:
        y, x = peak
        # Convert to k-space coordinates (relative to center)
        ky = (y - center) * 2 * np.pi / image.shape[0]
        kx = (x - center) * 2 * np.pi / image.shape[1]
        
        # Calculate parameters
        k_magnitude = np.sqrt(kx**2 + ky**2)
        if k_magnitude > 0:
            theta = np.arctan2(ky, kx)
            amplitude = magnitude[y, x] / np.max(magnitude)
            phi = phase[y, x]
            
            lattice_params.append([theta, k_magnitude, phi, amplitude])
    
    return np.array(lattice_params)

def fft_baseline_reconstruction(image, num_lattices=8):
    """
    Baseline FFT-based reconstruction method
    """
    # Debug: print image stats
    print(f"Input image stats: min={np.min(image):.3f}, max={np.max(image):.3f}, mean={np.mean(image):.3f}")
    
    # Extract FFT peaks
    fft_params = fft_peak_extraction(image, num_peaks=num_lattices)
    
    if len(fft_params) == 0:
        print("Warning: No FFT peaks found, returning zeros")
        return np.zeros((num_lattices, 5))
    
    # Debug: print FFT analysis
    fft_result = np.fft.fft2(image)
    fft_magnitude = np.abs(np.fft.fftshift(fft_result))
    print(f"FFT magnitude range: [{np.min(fft_magnitude):.1f}, {np.max(fft_magnitude):.1f}]")
    print(f"Number of FFT peaks found: {len(fft_params)}")
    
    # Convert to our parameter format (theta, spacing, phase_x, phase_y, hole_radius)
    converted_params = []
    for i, params in enumerate(fft_params):
        theta, k_mag, phi, amplitude = params
        if k_mag > 0 and not np.isnan(k_mag) and not np.isnan(theta):
            spacing = 2 * np.pi / k_mag
            # Clamp spacing to reasonable range
            spacing = np.clip(spacing, 4.0, 32.0)
            
            # Estimate hole radius from amplitude, make it reasonable
            hole_radius = np.clip(amplitude * spacing * 0.15, 0.5, spacing / 4)
            
            # Normalize phase to [0, 2œÄ]
            phase_x = (phi + np.pi) % (2 * np.pi)
            phase_y = 0.0  # Set phase_y to 0 for simplicity
            
            converted_params.append([theta, spacing, phase_x, phase_y, hole_radius])
            print(f"  Peak {i}: k_mag={k_mag:.3f} ‚Üí spacing={spacing:.1f}, amp={amplitude:.3f} ‚Üí radius={hole_radius:.2f}")
    
    # If no valid params found, create some default ones
    if len(converted_params) == 0:
        print("Warning: No valid FFT parameters, using defaults")
        converted_params = [[0, 8.0, 0, 0, 1.0]]  # Default lattice
    
    # Pad to max_lattices with zeros
    while len(converted_params) < num_lattices:
        converted_params.append([0, 0, 0, 0, 0])  # Zero params = no contribution
    
    return np.array(converted_params[:num_lattices])

# Test FFT baseline with better debugging
print("Testing FFT baseline with improved debugging...")
test_image_2d = test_image  # Use the test image from earlier

# Ensure image is properly normalized
if np.max(test_image_2d) > 1.0:
    test_image_2d = test_image_2d / np.max(test_image_2d)

fft_baseline_params = fft_baseline_reconstruction(test_image_2d, num_lattices=8)

print("\nFFT extracted parameters (after conversion):")
for i, params in enumerate(fft_baseline_params):
    if params[1] > 0:  # Only show non-zero lattices
        print(f"  Lattice {i}: Œ∏={params[0]:.2f}, spacing={params[1]:.1f}, "
              f"phases=({params[2]:.2f},{params[3]:.2f}), radius={params[4]:.2f}")

# Render FFT baseline reconstruction
fft_params_tf = tf.constant([fft_baseline_params.flatten()], dtype=tf.float32)
print(f"\nFFT params tensor shape: {fft_params_tf.shape}")
print(f"FFT params tensor: contains NaN? {tf.reduce_any(tf.math.is_nan(fft_params_tf))}")

fft_reconstruction = hole_based_renderer(fft_params_tf, image_size=64)
print(f"FFT reconstruction shape: {fft_reconstruction.shape}")
print(f"FFT reconstruction range: [{tf.reduce_min(fft_reconstruction):.3f}, {tf.reduce_max(fft_reconstruction):.3f}]")
print(f"FFT reconstruction contains NaN? {tf.reduce_any(tf.math.is_nan(fft_reconstruction))}")

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(test_image_2d, cmap='gray', vmin=0, vmax=1)
plt.title('Original Test Image')
plt.colorbar()

plt.subplot(1, 3, 2)
rendered_image = fft_reconstruction[0, :, :, 0].numpy()
if np.any(np.isnan(rendered_image)):
    print("Warning: Rendered image contains NaN values!")
    rendered_image = np.nan_to_num(rendered_image, nan=0.5)
plt.imshow(rendered_image, cmap='gray', vmin=0, vmax=1)
plt.title('FFT Baseline Reconstruction')
plt.colorbar()

plt.subplot(1, 3, 3)
error_image = np.abs(test_image_2d - rendered_image)
plt.imshow(error_image, cmap='hot', vmin=0, vmax=1)
plt.title('Reconstruction Error')
plt.colorbar()

plt.tight_layout()
plt.show()

# Calculate MSE, handling NaN values
if np.any(np.isnan(rendered_image)):
    fft_mse = np.nan
else:
    fft_mse = np.mean((test_image_2d - rendered_image)**2)
print(f"\nFFT Baseline MSE: {fft_mse:.6f}")

In [None]:
## 8. Complete ML Pipeline Execution

# This cell provides the complete workflow to train and evaluate the model
# Uncomment and run sections as needed

def run_complete_pipeline():
    """
    Complete pipeline for moir√© lattice reconstruction
    """
    print("üîß Starting Complete Moir√© Lattice ML Pipeline")
    print("=" * 60)
    
    # Step 1: Generate training data
    print("\nüìä Step 1: Generating training dataset...")
    dataset_path = 'moire_training_data.h5'
    
    # Create training dataset (adjust num_samples based on your needs)
    # Start small for testing, then increase for full training
    create_training_dataset(
        num_samples=1000,  # Start with 1000 for testing
        image_size=64,
        max_lattices=8,
        save_path=dataset_path,
        batch_size=50
    )
    
    # Step 2: Load data
    print("\nüìÇ Step 2: Loading data...")
    data_loader = MoireDataLoader(dataset_path)
    train_dataset, val_dataset = data_loader.create_tf_dataset(batch_size=16, validation_split=0.2)
    
    # Step 3: Train model
    print("\nüß† Step 3: Training model...")
    model_path = 'best_moire_model.h5'
    
    # Train for a few epochs first to test
    history = train_moire_model(
        improved_model, 
        train_dataset, 
        val_dataset, 
        epochs=5,  # Start with 5 epochs for testing
        save_path=model_path
    )
    
    # Step 4: Evaluate model
    print("\nüìà Step 4: Evaluating model...")
    
    # Get test samples
    test_batch = next(iter(val_dataset))
    test_images, test_params = test_batch
    
    # Evaluate reconstruction quality
    mse_recon, mse_params = evaluate_reconstruction(
        improved_model, 
        test_images, 
        test_params, 
        num_samples=3
    )
    
    # Compare with FFT baseline
    print("\nüîç Step 5: Comparing with FFT baseline...")
    fft_errors = []
    for i in range(min(3, len(test_images))):
        fft_params = fft_baseline_reconstruction(test_images[i, :, :, 0].numpy())
        fft_params_tf = tf.constant([fft_params.flatten()], dtype=tf.float32)
        fft_recon = hole_based_renderer(fft_params_tf, image_size=64)
        
        # Check for NaN values
        if tf.reduce_any(tf.math.is_nan(fft_recon)):
            print(f"Warning: FFT reconstruction {i} contains NaN values")
            fft_error = np.nan
        else:
            fft_error = tf.reduce_mean(tf.square(fft_recon - test_images[i:i+1]))
            fft_errors.append(fft_error.numpy())
    
    # Filter out NaN values for averaging
    valid_fft_errors = [e for e in fft_errors if not np.isnan(e)]
    avg_fft_error = np.mean(valid_fft_errors) if valid_fft_errors else np.nan
    
    print(f"\nResults Summary:")
    print(f"CNN Model - Reconstruction MSE: {mse_recon:.6f}")
    print(f"CNN Model - Parameter MSE: {mse_params:.6f}")
    print(f"FFT Baseline - Average MSE: {avg_fft_error:.6f}")
    print(f"FFT Baseline - Valid samples: {len(valid_fft_errors)}/{len(fft_errors)}")
    
    return history, improved_model, data_loader

# Instructions for running the pipeline
print("üìã COMPLETE MOIRE LATTICE ML PIPELINE")
print("=" * 50)
print()
print("To run the complete pipeline, execute:")
print("    history, model, loader = run_complete_pipeline()")
print()
print("Pipeline includes:")
print("‚úÖ 1. Training data generation (1000 samples)")
print("‚úÖ 2. Data loading and preprocessing")
print("‚úÖ 3. Model training (5 epochs for testing)")
print("‚úÖ 4. Model evaluation and visualization")
print("‚úÖ 5. FFT baseline comparison")
print()
print("For production training:")
print("‚Ä¢ Increase num_samples to 10000+ in create_training_dataset()")
print("‚Ä¢ Increase epochs to 50+ in train_moire_model()")
print("‚Ä¢ Monitor training on validation set")
print()
print("Next steps after testing:")
print("üî¨ 1. Test on real grayscale images")
print("üîß 2. Fine-tune hyperparameters")
print("üìä 3. Add more evaluation metrics")
print("üöÄ 4. Deploy for inference")

# Create a smaller test run function for quick testing
def quick_test_pipeline():
    """Quick test with minimal data for debugging"""
    print("üß™ Quick Test Pipeline (100 samples, 2 epochs)")
    
    # Generate minimal test data
    create_training_dataset(num_samples=100, image_size=64, max_lattices=4, 
                          save_path='test_data.h5', batch_size=20)
    
    # Load and train
    loader = MoireDataLoader('test_data.h5')
    train_ds, val_ds = loader.create_tf_dataset(batch_size=8, validation_split=0.3)
    
    # Quick training
    test_model = create_improved_lattice_cnn(max_lattices=4)
    test_model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    
    history = test_model.fit(train_ds, validation_data=val_ds, epochs=2, verbose=1)
    
    print("‚úÖ Quick test completed successfully!")
    return test_model, loader

print("\nFor quick testing, run: quick_test_pipeline()")
quick_test_pipeline()

In [None]:
## 4. Training Data Generation Pipeline

def generate_random_lattice_params(num_lattices_range=(2, 8), image_size=64):
    """
    Generate random lattice parameters for training data
    
    Returns:
        lattice_params: array of shape (num_lattices, 5) with parameters
                       (theta, spacing, phase_x, phase_y, hole_radius)
    """
    num_lattices = np.random.randint(num_lattices_range[0], num_lattices_range[1] + 1)
    
    params = []
    for _ in range(num_lattices):
        # Rotation angle [0, 2œÄ)
        theta = np.random.uniform(0, 2 * np.pi)
        
        # Lattice spacing [4, image_size/2] to ensure visible patterns
        spacing = np.random.uniform(4, image_size // 2)
        
        # Phase offsets [0, 2œÄ)
        phase_x = np.random.uniform(0, 2 * np.pi)
        phase_y = np.random.uniform(0, 2 * np.pi)
        
        # Hole radius [0.5, spacing/3] to ensure holes don't overlap too much
        hole_radius = np.random.uniform(0.5, min(spacing / 3, 4.0))
        
        params.append([theta, spacing, phase_x, phase_y, hole_radius])
    
    # Pad with zeros if needed to reach maximum number of lattices
    max_lattices = num_lattices_range[1]
    while len(params) < max_lattices:
        params.append([0, 0, 0, 0, 0])  # Zero params = no contribution
    
    return np.array(params)

def create_training_dataset(num_samples=10000, image_size=64, max_lattices=8, 
                          save_path='moire_training_data.h5', batch_size=100):
    """
    Create a large training dataset and save to HDF5 file
    
    Args:
        num_samples: number of training examples
        image_size: size of generated images
        max_lattices: maximum number of lattices per image
        save_path: path to save HDF5 file
        batch_size: batch size for generation (to manage memory)
    """
    print(f"Creating training dataset with {num_samples} samples...")
    print(f"Image size: {image_size}x{image_size}, Max lattices: {max_lattices}")
    
    # Create HDF5 file
    with h5py.File(save_path, 'w') as f:
        # Create datasets
        images_ds = f.create_dataset('images', (num_samples, image_size, image_size, 1), 
                                   dtype=np.float32, compression='gzip')
        params_ds = f.create_dataset('parameters', (num_samples, max_lattices * 5), 
                                   dtype=np.float32, compression='gzip')
        num_lattices_ds = f.create_dataset('num_lattices', (num_samples,), 
                                         dtype=np.int32, compression='gzip')
        
        # Add metadata
        f.attrs['image_size'] = image_size
        f.attrs['max_lattices'] = max_lattices
        f.attrs['num_samples'] = num_samples
        f.attrs['created_date'] = datetime.now().isoformat()
        f.attrs['parameter_format'] = 'theta, spacing, phase_x, phase_y, hole_radius'
        
        # Generate data in batches
        for batch_start in range(0, num_samples, batch_size):
            batch_end = min(batch_start + batch_size, num_samples)
            current_batch_size = batch_end - batch_start
            
            print(f"Generating batch {batch_start//batch_size + 1}/{(num_samples-1)//batch_size + 1}...")
            
            batch_images = []
            batch_params = []
            batch_num_lattices = []
            
            for i in range(current_batch_size):
                # Generate random parameters
                lattice_params = generate_random_lattice_params((2, max_lattices), image_size)
                num_actual_lattices = np.sum(lattice_params[:, 1] > 0)  # Count non-zero spacings
                
                # Flatten parameters for model input
                flat_params = lattice_params.flatten()
                
                # Render image using TensorFlow (convert to TF format)
                tf_params = tf.constant([flat_params], dtype=tf.float32)
                rendered = hole_based_renderer(tf_params, image_size=image_size)
                image = rendered[0].numpy()
                
                # Add some noise to make training more robust
                noise_level = np.random.uniform(0, 0.05)
                image += np.random.normal(0, noise_level, image.shape)
                image = np.clip(image, 0, 1)
                
                batch_images.append(image)
                batch_params.append(flat_params)
                batch_num_lattices.append(num_actual_lattices)
            
            # Save batch to HDF5
            images_ds[batch_start:batch_end] = np.array(batch_images)
            params_ds[batch_start:batch_end] = np.array(batch_params)
            num_lattices_ds[batch_start:batch_end] = np.array(batch_num_lattices)
    
    print(f"Dataset saved to {save_path}")
    print(f"File size: {os.path.getsize(save_path) / (1024**2):.1f} MB")

# Test data generation with a small sample
print("Testing data generation...")
test_params = generate_random_lattice_params((2, 4), image_size=64)
print(f"Generated parameters shape: {test_params.shape}")
print(f"Sample parameters:")
for i, params in enumerate(test_params):
    if params[1] > 0:  # Only show non-zero lattices
        print(f"  Lattice {i}: Œ∏={params[0]:.2f}, spacing={params[1]:.1f}, "
              f"phases=({params[2]:.2f},{params[3]:.2f}), radius={params[4]:.1f}")

# Render test image
tf_test_params = tf.constant([test_params.flatten()], dtype=tf.float32)
test_rendered = hole_based_renderer(tf_test_params, image_size=64)

plt.figure(figsize=(8, 6))
plt.imshow(test_rendered[0, :, :, 0], cmap='gray')
plt.title('Test Generated Training Sample')
plt.colorbar()
plt.show()

print(f"Test image range: [{tf.reduce_min(test_rendered):.3f}, {tf.reduce_max(test_rendered):.3f}]")

In [None]:
## 8.5. Alternative: On-the-Fly Data Generation

class OnTheFlyMoireDataGenerator:
    """
    Generate moir√© lattice training data on-the-fly during training
    This saves disk space and allows infinite data variation
    """
    
    def __init__(self, image_size=64, max_lattices=8, num_lattices_range=(2, 6)):
        self.image_size = image_size
        self.max_lattices = max_lattices
        self.num_lattices_range = num_lattices_range
    
    def generate_batch(self, batch_size):
        """Generate a batch of training data"""
        batch_images = []
        batch_params = []
        
        for _ in range(batch_size):
            # Generate random parameters
            lattice_params = generate_random_lattice_params(
                self.num_lattices_range, self.image_size
            )
            flat_params = lattice_params.flatten()
            
            # Render image
            tf_params = tf.constant([flat_params], dtype=tf.float32)
            rendered = hole_based_renderer(tf_params, image_size=self.image_size)
            image = rendered[0].numpy()
            
            # Add noise for robustness
            noise_level = np.random.uniform(0, 0.05)
            image += np.random.normal(0, noise_level, image.shape)
            image = np.clip(image, 0, 1)
            
            batch_images.append(image)
            batch_params.append(flat_params)
        
        return np.array(batch_images), np.array(batch_params)
    
    def create_tf_dataset(self, batch_size=32, steps_per_epoch=100):
        """Create TensorFlow dataset with on-the-fly generation"""
        
        def data_generator():
            while True:  # Infinite generator
                images, params = self.generate_batch(batch_size)
                for i in range(batch_size):
                    yield images[i], params[i]
        
        dataset = tf.data.Dataset.from_generator(
            data_generator,
            output_signature=(
                tf.TensorSpec(shape=(self.image_size, self.image_size, 1), dtype=tf.float32),
                tf.TensorSpec(shape=(self.max_lattices * 5,), dtype=tf.float32)
            )
        )
        
        return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def run_onthefly_pipeline():
    """
    Alternative pipeline using on-the-fly data generation
    Faster to start, uses less disk space, infinite data variation
    """
    print("üöÄ On-the-Fly Moir√© Lattice ML Pipeline")
    print("=" * 50)
    
    # Create data generator
    data_gen = OnTheFlyMoireDataGenerator(image_size=64, max_lattices=8)
    
    # Create datasets
    print("üìä Creating on-the-fly datasets...")
    train_dataset = data_gen.create_tf_dataset(batch_size=16, steps_per_epoch=50)
    val_dataset = data_gen.create_tf_dataset(batch_size=16, steps_per_epoch=10)
    
    # Create model
    print("üß† Creating model...")
    model = create_improved_lattice_cnn(max_lattices=8)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mse',  # Simpler loss for faster training
        metrics=['mae']
    )
    
    # Train model
    print("üéØ Training model...")
    history = model.fit(
        train_dataset,
        steps_per_epoch=50,  # Number of batches per epoch
        validation_data=val_dataset,
        validation_steps=10,  # Number of validation batches
        epochs=10,
        verbose=1
    )
    
    # Test model
    print("üìà Testing model...")
    test_images, test_params = data_gen.generate_batch(3)
    
    # Predict
    pred_params = model.predict(test_images)
    
    # Visualize results
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    
    for i in range(3):
        # Original image
        axes[i, 0].imshow(test_images[i, :, :, 0], cmap='gray')
        axes[i, 0].set_title(f'Original {i+1}')
        axes[i, 0].axis('off')
        
        # Predicted reconstruction
        pred_tf = tf.constant([pred_params[i]], dtype=tf.float32)
        pred_recon = hole_based_renderer(pred_tf, image_size=64)
        axes[i, 1].imshow(pred_recon[0, :, :, 0], cmap='gray')
        axes[i, 1].set_title(f'Predicted {i+1}')
        axes[i, 1].axis('off')
        
        # True reconstruction
        true_tf = tf.constant([test_params[i]], dtype=tf.float32)
        true_recon = hole_based_renderer(true_tf, image_size=64)
        axes[i, 2].imshow(true_recon[0, :, :, 0], cmap='gray')
        axes[i, 2].set_title(f'True {i+1}')
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Calculate MSE
    pred_recons = hole_based_renderer(pred_params, image_size=64)
    mse = tf.reduce_mean(tf.square(pred_recons - test_images)).numpy()
    print(f"Test MSE: {mse:.6f}")
    
    return model, history, data_gen

print("\nüí° TWO APPROACHES AVAILABLE:")
print("1. Pre-generated Dataset: run_complete_pipeline()")
print("   ‚úÖ Consistent data, reproducible results")
print("   ‚ùå Requires disk space, limited variation")
print()
print("2. On-the-Fly Generation: run_onthefly_pipeline()")
print("   ‚úÖ No disk space, infinite variation, faster start")
print("   ‚ùå Slightly slower training, less reproducible")
print()
print("For quick testing, recommend: run_onthefly_pipeline()")

In [None]:
# Final Summary and Status
print("üéâ MOIR√â LATTICE RECONSTRUCTION ML PIPELINE - READY!")
print("=" * 60)
print()
print("‚úÖ Environment setup complete")
print("‚úÖ Hole-based differentiable renderer implemented")
print("‚úÖ Training data generation pipeline ready")
print("‚úÖ HDF5 storage system configured")
print("‚úÖ Improved CNN architecture with residual connections")
print("‚úÖ Custom reconstruction loss function")
print("‚úÖ Training pipeline with callbacks and monitoring")
print("‚úÖ Evaluation metrics and visualization")
print("‚úÖ FFT baseline for comparison")
print("‚úÖ Complete workflow automation")
print()
print("üöÄ Ready for moir√© lattice reconstruction research!")
print()
print("Next: Run the pipeline and analyze results!")
print("Execute: history, model, loader = run_complete_pipeline()")