Lab Notebook: Hyperparameter Tuning of Diffusion Models for Microstructure Generation


This notebook implements hyperparameter tuning for a diffusion model that generates synthetic microstructures.
The key components include:

Data Loading: Custom dataset loader for microstructure data from HDF5 files
Model Architecture: UNet with self-attention and time conditioning
Diffusion Process: Forward and reverse diffusion with DDIM sampling
Evaluation: Domain-specific metrics for microstructure characterization
Hyperparameter Tuning: Grid search and random search implementations

Background
Diffusion models have emerged as powerful generative models capable of producing high-quality, diverse samples. They work by gradually adding noise to data in a forward process, then learning to reverse this process to generate new samples. For microstructure generation, this approach is particularly promising as it can capture complex spatial correlations and phase distributions that are critical for material properties.
The implementation focuses on nickel-based superalloys (NBSA), which have a characteristic two-phase microstructure that significantly influences their mechanical properties.
Data Preparation
The dataset consists of 2D microstructure images stored in an HDF5 file format. The custom Micro2DKeyDataset class handles:

Loading specific microstructure classes from the HDF5 file
Normalizing pixel values to the [-1, 1] range (standard for diffusion models)
Optional data augmentation through random flips and affine transformations

Data augmentation is used to increase diversity in the training set and improve model generalization.
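The normalization step listed above can be sketched in NumPy as follows. This is a minimal illustration, assuming a single 2D image array; the helper name `to_diffusion_range` is hypothetical and does not appear in the notebook, which performs the same steps inside `Micro2DKeyDataset.__getitem__`.

```python
import numpy as np

def to_diffusion_range(img, eps=1e-8):
    """Min-max scale one image to [0, 1], then map it to [-1, 1]."""
    img = np.asarray(img, dtype=np.float32)
    # eps keeps the division finite for a constant image
    unit = (img - img.min()) / (img.max() - img.min() + eps)
    return np.clip(2.0 * unit - 1.0, -1.0, 1.0)

# Example: a tiny two-phase image stored as 8-bit pixel values
sample = np.array([[0, 255], [255, 0]], dtype=np.uint8)
scaled = to_diffusion_range(sample)
```

Because the scaling is per image, each image uses its own minimum and maximum; a constant image maps entirely to -1.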
Model Architecture
The model is built on a U-Net architecture, which is well-suited for image-to-image tasks:
Key Components:

Time Embedding: Sinusoidal position embeddings convert diffusion timesteps into high-dimensional vectors that condition the model.
Down/Up Blocks: Specialized convolutional blocks that:

Process spatial features at multiple resolutions
Incorporate time information at each layer
Use skip connections to preserve spatial details


Self-Attention: Applied after the third downsampling block (a 32x32 feature map for 256x256 inputs) rather than at the deepest 8x8 level, to capture long-range dependencies in the microstructure, which is crucial for maintaining phase connectivity and morphology.
Batch Normalization: Applied after each convolution and ReLU in the down/up blocks to improve training stability and convergence.
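The time embedding in the list above can be illustrated with a small NumPy sketch. It mirrors the frequency ladder used by the notebook's `SinusoidalPositionEmbeddings` module, but the function name `sinusoidal_embedding` and the NumPy port are assumptions made for illustration only.

```python
import numpy as np

def sinusoidal_embedding(timesteps, dim):
    """Map integer timesteps [batch] to sinusoidal features [batch, dim]."""
    half = dim // 2
    # Geometric frequency ladder running from 1 down to 1/10000
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / (half - 1))
    angles = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    # First half holds sines, second half holds cosines
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

emb = sinusoidal_embedding([0, 10, 999], dim=128)
```

Each timestep gets a distinct vector with bounded entries, which the U-Net then projects through a linear layer before adding it to the feature maps.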

Diffusion Process
The diffusion process is implemented in the DiffusionModel class:
Forward Process

Gradually adds Gaussian noise to images according to a predefined schedule
Uses a linear beta schedule (0.0001 to 0.02 over 1000 timesteps by default) that controls the noise level at each timestep
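As a rough sketch of the forward process, the noisy image at any timestep can be drawn in one step from the clean image using the cumulative product of (1 - beta). The NumPy version below assumes the default schedule values used in the notebook's `linear_beta_schedule`; the random generator and the 8x8 test image are placeholders.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)          # linear noise schedule
alphas_cumprod = np.cumprod(1.0 - betas)    # cumulative signal retention

def q_sample(x0, t, rng):
    """Draw x_t ~ q(x_t | x_0) directly, without iterating over steps."""
    noise = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise
    return x_t, noise

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(8, 8))    # placeholder "image" in [-1, 1]
x_early, _ = q_sample(x0, 10, rng)          # still close to x0
x_late, _ = q_sample(x0, T - 1, rng)        # almost pure Gaussian noise
```

The signal coefficient shrinks monotonically with t, so late timesteps carry almost no information about the original image.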

Reverse Process

The U-Net model is trained to predict the noise added at each timestep
During training, random timesteps are sampled to teach the model to denoise from any point in the process

DDIM Sampling

Implements Denoising Diffusion Implicit Models (DDIM) sampling for faster generation
Uses a stochasticity parameter (eta) that interpolates between deterministic sampling (eta = 0) and DDPM-like stochastic sampling (eta = 1)
Generates samples in fewer steps than the training schedule (150 sampling steps by default, versus 1000 training timesteps)
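A single DDIM update can be sketched as below. This is a NumPy illustration of the update rule, not the notebook's `ddim_sample` method; the function name `ddim_step` and the toy inputs are assumptions, and `a_t` / `a_next` stand for the cumulative alpha products at the current and next (less noisy) timestep.

```python
import numpy as np

def ddim_step(x_t, eps_pred, a_t, a_next, eta=0.0, rng=None):
    """One DDIM update from noise level a_t to a_next."""
    # Estimate the clean image implied by the predicted noise
    x0_pred = (x_t - np.sqrt(1.0 - a_t) * eps_pred) / np.sqrt(a_t)
    # eta scales the injected noise: 0 is deterministic, 1 matches DDPM variance
    sigma = eta * np.sqrt((1.0 - a_next) / (1.0 - a_t) * (1.0 - a_t / a_next))
    if eta > 0:
        rng = np.random.default_rng() if rng is None else rng
        noise = rng.standard_normal(x_t.shape)
    else:
        noise = 0.0
    return (np.sqrt(a_next) * x0_pred
            + np.sqrt(1.0 - a_next - sigma**2) * eps_pred
            + sigma * noise)

# Sanity check with a perfect noise prediction: the deterministic step lands
# on the same (x0, eps) pair expressed at the lower noise level.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(4, 4))
eps = rng.standard_normal((4, 4))
a_t, a_next = 0.5, 0.8
x_t = np.sqrt(a_t) * x0 + np.sqrt(1.0 - a_t) * eps
x_next = ddim_step(x_t, eps, a_t, a_next, eta=0.0)
```

In practice the noise prediction comes from the trained U-Net, so the update is only approximately consistent from step to step.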

Evaluation Metrics
Specialized evaluation metrics are implemented to assess the quality of generated microstructures:

Structural Similarity Index (SSIM): Measures visual similarity between generated and real microstructures
Two-Point Correlation Function (S2):

Quantifies spatial correlations in the microstructure
Implemented using FFT for computational efficiency
Computes a normalized discrepancy: the area between the original and generated S2 curves, as a percentage of the area under the original curve

Lineal Path Function:

Measures the probability of finding continuous line segments within a single phase
Important for capturing phase connectivity that affects material properties
Estimated by random sampling: 1000 start points per image, checked along horizontal and vertical directions
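The FFT-based two-point correlation mentioned in the list above can be sketched as follows. This is an illustrative NumPy version under periodic boundary conditions; the helper names `s2_fft` and `radial_average` are hypothetical, and the sketch divides by the pixel count so that the zero-distance value equals the phase fraction, whereas the notebook normalizes by the central value instead.

```python
import numpy as np

def s2_fft(binary):
    """Periodic two-point autocorrelation of a binary image via FFT."""
    img = np.asarray(binary, dtype=np.float64)
    ft = np.fft.fft2(img)
    # Wiener-Khinchin: the inverse FFT of |F|^2 is the circular autocorrelation
    autocorr = np.real(np.fft.ifft2(ft * np.conj(ft))) / img.size
    return np.fft.fftshift(autocorr)  # zero displacement moves to the centre

def radial_average(field):
    """Average a centred 2D field over rings of integer radius."""
    h, w = field.shape
    y, x = np.indices((h, w))
    r = np.sqrt((x - w // 2) ** 2 + (y - h // 2) ** 2).astype(int)
    return np.bincount(r.ravel(), weights=field.ravel()) / np.bincount(r.ravel())

rng = np.random.default_rng(0)
phase = (rng.random((64, 64)) < 0.3).astype(float)  # uncorrelated test image
profile = radial_average(s2_fft(phase))
```

For an uncorrelated image like this one, the profile starts at the phase fraction and drops to roughly its square at larger distances; real microstructures decay more gradually, which is the information the metric is meant to capture.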

These metrics provide a comprehensive assessment beyond visual similarity, focusing on structural characteristics relevant to material properties.
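The curve-based discrepancy used for both S2 and the lineal path function can be sketched as the area between two curves relative to the area under the reference curve. The helper names below are hypothetical, and the example curves are made-up numbers chosen only to exercise the function.

```python
import numpy as np

def trapezoid_area(y, r):
    """Trapezoidal-rule area under y sampled at positions r."""
    y = np.asarray(y, dtype=float)
    r = np.asarray(r, dtype=float)
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(r)))

def curve_discrepancy(reference, candidate, r):
    """Area between two curves as a percentage of the area under the reference."""
    ref = np.asarray(reference, dtype=float)
    cand = np.asarray(candidate, dtype=float)
    area_ref = trapezoid_area(ref, r)
    if area_ref <= 0:
        return 0.0  # same fallback the notebook uses for an empty reference
    return 100.0 * trapezoid_area(np.abs(ref - cand), r) / area_ref

r = np.arange(5)
ref = np.array([1.0, 0.8, 0.6, 0.4, 0.2])        # illustrative curve only
identical = curve_discrepancy(ref, ref, r)        # no difference
shifted = curve_discrepancy(ref, ref + 0.06, r)   # uniform offset
```

A value of 0 means the two curves coincide; larger values mean the generated microstructure departs further from the reference statistics.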
Hyperparameter Tuning
Two search approaches are implemented:
Grid Search

Systematically evaluates all combinations of hyperparameters
Provides comprehensive coverage of the parameter space
Computationally expensive for large parameter spaces
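The enumeration behind grid search can be sketched with `itertools.product`, which is also what the notebook's `simple_grid_search` uses. The parameter names and values below are illustrative placeholders, not the grid used in the experiments.

```python
import itertools

# Illustrative grid only; the real grid is supplied by the caller
param_grid = {
    "learning_rate": [1e-4, 5e-4],
    "batch_size": [8, 16],
    "timesteps": [500, 1000],
}

keys = list(param_grid)
# One dict per combination, in the order produced by itertools.product
configs = [dict(zip(keys, combo))
           for combo in itertools.product(*param_grid.values())]
```

The number of configurations is the product of the list lengths (2 x 2 x 2 = 8 here), which is why the cost grows quickly as parameters are added.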

Random Search

Samples random configurations from the parameter space
More efficient for high-dimensional spaces
Handles different parameter types:

Log-uniform sampling for learning rates
Uniform sampling for continuous parameters
Random selection from lists for categorical parameters
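The three sampling rules above can be sketched with the standard library. The search-space encoding used here (tuples tagged `"log_uniform"`, `"uniform"`, or `"choice"`) and the function name `sample_config` are assumptions made for illustration; the notebook's random search cell may represent the space differently.

```python
import math
import random

def sample_config(space, rng):
    """Draw one configuration from a tagged search space (hypothetical format)."""
    config = {}
    for name, spec in space.items():
        kind = spec[0]
        if kind == "log_uniform":
            # Uniform in log space, so each decade is equally likely
            low, high = spec[1], spec[2]
            config[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        elif kind == "uniform":
            config[name] = rng.uniform(spec[1], spec[2])
        elif kind == "choice":
            config[name] = rng.choice(spec[1])
        else:
            raise ValueError(f"Unknown parameter type: {kind}")
    return config

# Placeholder search space
space = {
    "learning_rate": ("log_uniform", 1e-5, 1e-3),
    "eta": ("uniform", 0.0, 1.0),
    "batch_size": ("choice", [8, 16, 32]),
}
rng = random.Random(42)
trials = [sample_config(space, rng) for _ in range(20)]
```

Log-uniform sampling matters for learning rates because plausible values span orders of magnitude; uniform sampling over the same range would place almost all draws in the top decade.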

Both methods include:

Early stopping to prevent overfitting
Automatic saving of models, configurations, and results
Comprehensive logging and visualization of results
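Early stopping on validation loss can be sketched as a small patience counter. The class below is a generic illustration; the class name, the `patience` and `min_delta` parameters, and the loss values are assumptions, and the notebook's training loop may implement the check differently.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Placeholder loss curve, not measured results
stopper = EarlyStopping(patience=2)
losses = [1.0, 0.8, 0.81, 0.79, 0.80, 0.80, 0.70]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

With a patience of 2, training halts after two consecutive epochs without improvement, so the final 0.70 in this placeholder curve is never reached; a larger patience trades compute for robustness against noisy validation losses.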

Experimental Results
For each hyperparameter configuration, the system:

Trains the model for a specified number of epochs (with early stopping)
Evaluates the model on validation data
Generates sample microstructures
Computes evaluation metrics
Creates visualizations of training loss curves and generated samples

Results are organized in a hierarchical directory structure with:

JSON files containing configuration details and metrics
PNG images of loss curves and generated samples
Summary visualizations showing parameter impact on performance
A ranking of the top parameter combinations
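Ranking configurations and writing a JSON summary can be sketched as below. The record layout, the file name `summary.json`, and the loss values are placeholders for illustration, not outputs of the notebook.

```python
import json
import os
import tempfile

# Placeholder records; real entries come from the training runs
results = [
    {"params": {"learning_rate": 1e-4, "batch_size": 8}, "val_loss": 0.05},
    {"params": {"learning_rate": 5e-4, "batch_size": 8}, "val_loss": 0.02},
    {"params": {"learning_rate": 1e-4, "batch_size": 16}, "val_loss": 0.03},
]

# Lower validation loss ranks higher
ranked = sorted(results, key=lambda record: record["val_loss"])
top_k = ranked[:2]

with tempfile.TemporaryDirectory() as save_dir:
    path = os.path.join(save_dir, "summary.json")  # hypothetical file name
    with open(path, "w") as f:
        json.dump({"top_configurations": top_k}, f, indent=4)
    with open(path) as f:
        reloaded = json.load(f)
```

Values that are NumPy scalars need converting to plain Python types before `json.dump`, which is the role of the notebook's `convert_to_serializable` helper.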

Visualization Tools
The implementation includes visualization functions for:

Displaying batches of microstructures in a grid
Plotting loss curves during training
Creating parameter impact plots (validation loss vs. parameter value)
Visualizing the correlation between different metrics

In [None]:
"""
Hyperparameter tuning for diffusion models.
"""
# Import the required libraries
import os
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm.auto import tqdm  # Progress bar utility
from einops import rearrange  # Tensor reshaping utility
from skimage.metrics import structural_similarity as ssim  # Image similarity metric
import torchvision.transforms as transforms
from datetime import datetime
import json
from collections import defaultdict
import itertools

In [None]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Dataset Class                          
class Micro2DKeyDataset(Dataset):
    """
    Dataset class for loading 2D microstructure data from HDF5 files.
    
    Args:
        file_path (str): Path to the HDF5 file
        key (str): Key to access specific microstructure class in the HDF5 file
        transform (callable, optional): Optional transform to be applied on a sample
    """
    def __init__(self, file_path, key, transform=None):
        self.file_path = file_path
        self.key = key
        self.transform = transform
        
        # Load data from HDF5 file
        with h5py.File(self.file_path, 'r') as f:
            self.data = f[key][key][:]
        self.data = self.data.astype(np.float32)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Get image and normalize to [0, 1] range
        img = self.data[idx]
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # Add small epsilon to avoid division by zero
        
        # Convert to PyTorch tensor and add channel dimension
        img = torch.tensor(img).unsqueeze(0)
        
        # Scale to [-1, 1] range for diffusion model
        img = 2 * img - 1
        img = torch.clamp(img, -1.0, 1.0)  # Ensure values are strictly within range
        
        # Apply transforms if specified
        if self.transform:
            img = self.transform(img)
            
        return img

In [None]:
# UNet Model Components                  
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal position embeddings for timestep encoding.
    
    Args:
        dim (int): Dimension of the embedding
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Convert timesteps to sinusoidal embeddings.
        
        Args:
            time (tensor): Tensor of timesteps [batch_size]
            
        Returns:
            tensor: Sinusoidal embeddings [batch_size, dim]
        """
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class Block(nn.Module):
    """
    Basic UNet block with timestep conditioning.
    
    Args:
        in_ch (int): Number of input channels
        out_ch (int): Number of output channels
        time_emb_dim (int): Dimension of time embedding
        up (bool): Whether this is an upsampling block
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.up = up
        
        if up:
            # For upsampling blocks
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)  # Transposed conv for upsampling
        else:
            # For downsampling blocks
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)  # Stride 2 for downsampling
            
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        
    def forward(self, x, t):
        """
        Forward pass through the block.
        
        Args:
            x (tensor): Input tensor [batch_size, in_ch, height, width]
            t (tensor): Time embedding [batch_size, time_emb_dim]
            
        Returns:
            tensor: Output tensor after transformation
        """
        # First Conv + ReLU + BatchNorm (normalization is applied after the activation)
        h = self.bnorm1(self.relu(self.conv1(x)))
        
        # Time embedding conditioning
        time_emb = self.relu(self.time_mlp(t))
        
        # Extend time embedding dimensions to match spatial dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]  # Add two dimensions
        
        # Add time channel to feature map
        h = h + time_emb
        
        # Second Conv + ReLU + BatchNorm
        h = self.bnorm2(self.relu(self.conv2(h)))
        
        # Down or Upsample using transform layer
        return self.transform(h)

class SelfAttention(nn.Module):
    """
    Self-attention module for capturing long-range dependencies.
    
    Args:
        channels (int): Number of input channels
    """
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels        
        self.mha = nn.MultiheadAttention(channels, 1, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """
        Forward pass for self-attention module.
        
        Args:
            x (tensor): Input tensor [batch_size, channels, height, width]
            
        Returns:
            tensor: Self-attention output with same shape
        """
        # Store original spatial dimensions
        size = x.shape[-2:]
        
        # Rearrange tensor for attention operation:
        # [batch, channels, height, width] -> [batch, height*width, channels]
        x = rearrange(x, 'b c h w -> b (h w) c')
        
        # Layer normalization
        x_ln = self.ln(x)
        
        # Multi-head self-attention
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        
        # Residual connection
        attention_value = attention_value + x
        
        # Feed-forward network with residual connection
        attention_value = self.ff_self(attention_value) + attention_value
        
        # Rearrange back to original spatial dimensions:
        # [batch, height*width, channels] -> [batch, channels, height, width]
        return rearrange(attention_value, 'b (h w) c -> b c h w', h=size[0], w=size[1])


class UNet(nn.Module):
    """
    UNet architecture for diffusion model with time conditioning and attention mechanism.
    
    Args:
        in_channels (int): Number of input channels (default: 1 for grayscale)
        out_channels (int): Number of output channels (default: 1)
        time_emb_dim (int): Dimension of time embedding (default: 128)
    """
    def __init__(self, in_channels=1, out_channels=1, time_emb_dim=128):
        super().__init__()
        
        # For debugging
        self.debug = False
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        
        # Initial projection
        self.conv0 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        
        # Downsample blocks
        self.down1 = Block(128, 128, time_emb_dim)   # 256x256 -> 128x128
        self.down2 = Block(128, 256, time_emb_dim)   # 128x128 -> 64x64
        self.down3 = Block(256, 256, time_emb_dim)   # 64x64 -> 32x32
        self.down4 = Block(256, 512, time_emb_dim)   # 32x32 -> 16x16
        self.down5 = Block(512, 512, time_emb_dim)   # 16x16 -> 8x8
        
        # Attention layer, applied after down3 (256 channels, 32x32 map for 256x256 inputs)
        self.attention = SelfAttention(256)
        
        # Upsampling blocks with skip connections
        self.up1 = Block(512, 512, time_emb_dim, up=True)                # 8x8 -> 16x16
        self.up2 = Block(512 + 512, 256, time_emb_dim, up=True)          # 16x16 -> 32x32
        self.up3 = Block(256 + 256, 256, time_emb_dim, up=True)          # 32x32 -> 64x64
        self.up4 = Block(256 + 256, 128, time_emb_dim, up=True)          # 64x64 -> 128x128
        self.up5 = Block(128 + 128, 128, time_emb_dim, up=True)          # 128x128 -> 256x256
        
        # Final output projection
        self.output = nn.Conv2d(128, out_channels, kernel_size=1)
        
    def forward(self, x, timestep):
        """
        Forward pass through UNet.
        
        Args:
            x (tensor): Input noisy image [batch_size, in_channels, height, width]
            timestep (tensor): Diffusion timesteps [batch_size]
            
        Returns:
            tensor: Predicted noise [batch_size, out_channels, height, width]
        """
        # Get time embedding
        t = self.time_mlp(timestep)
        
        # Initial conv
        x0 = self.conv0(x)  # 256x256
        if self.debug:
            print(f"x0 (initial conv): {x0.shape}")
        
        # Downsample path (encoder)
        x1 = self.down1(x0, t)  # 128x128
        if self.debug:
            print(f"x1 (down1): {x1.shape}")
            
        x2 = self.down2(x1, t)  # 64x64
        if self.debug:
            print(f"x2 (down2): {x2.shape}")
            
        x3 = self.down3(x2, t)  # 32x32
        x3 = self.attention(x3)  # Apply attention at 32x32 (the 8x8 bottleneck is x5)
        if self.debug:
            print(f"x3 (down3 + attention): {x3.shape}")
            
        x4 = self.down4(x3, t)  # 16x16
        if self.debug:
            print(f"x4 (down4): {x4.shape}")
            
        x5 = self.down5(x4, t)  # 8x8 (bottleneck)
        if self.debug:
            print(f"x5 (down5): {x5.shape}")
        
        # Upsample path (decoder) with skip connections
        u1 = self.up1(x5, t)  # 16x16
        if self.debug:
            print(f"u1 (up1): {u1.shape}")
            print(f"x4 for concat: {x4.shape}")
            
        # Concatenate with skip connections
        u2 = self.up2(torch.cat([u1, x4], dim=1), t)  # 32x32
        if self.debug:
            print(f"u2 (up2): {u2.shape}")
            print(f"x3 for concat: {x3.shape}")
            
        u3 = self.up3(torch.cat([u2, x3], dim=1), t)  # 64x64
        if self.debug:
            print(f"u3 (up3): {u3.shape}")
            print(f"x2 for concat: {x2.shape}")
            
        u4 = self.up4(torch.cat([u3, x2], dim=1), t)  # 128x128
        if self.debug:
            print(f"u4 (up4): {u4.shape}")
            print(f"x1 for concat: {x1.shape}")
            
        u5 = self.up5(torch.cat([u4, x1], dim=1), t)  # 256x256
        if self.debug:
            print(f"u5 (up5): {u5.shape}")
        
        # Final projection to output channels
        output = self.output(u5)
        if self.debug:
            print(f"output: {output.shape}")
            
        return output

In [None]:
# Diffusion Model Implementation         
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """
    Linear beta schedule for diffusion process.
    
    Args:
        timesteps (int): Number of diffusion timesteps
        start (float): Starting beta value
        end (float): Ending beta value
        
    Returns:
        tensor: Beta values schedule
    """
    return torch.linspace(start, end, timesteps)

class DiffusionModel:
    """
    Diffusion model implementation with forward and reverse processes.
    
    Args:
        timesteps (int): Number of diffusion steps (default: 1000)
        beta_schedule (str): Type of beta schedule (default: 'linear'); currently unused, a linear schedule is always applied
        device (str): Device to run the model on (default: 'cuda')
    """
    def __init__(self, timesteps=1000, beta_schedule='linear', device='cuda'):
        self.timesteps = timesteps
        self.device = device
        
        # Define noise schedule
        self.betas = linear_beta_schedule(timesteps).to(device)
        
        # Define alphas (1 - beta)
        self.alphas = 1. - self.betas
        
        # Cumprod of alphas for q(x_t | x_0) calculations
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        
        # Previous alpha cumulative product for posterior variance calculation
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Pre-compute values for diffusion process and sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1)
        
        # Posterior variance calculation for q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        
    def q_sample(self, x_0, t, noise=None):
        """
        Forward diffusion process: Add noise to the original data.
        
        Args:
            x_0 (tensor): Original clean image [batch_size, channels, height, width]
            t (tensor): Timesteps [batch_size]
            noise (tensor, optional): Noise to add, if None random noise is generated
            
        Returns:
            tensor: Noisy image at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Reshape coefficients for broadcasting
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        
        # Apply forward diffusion equation: x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε,
        # where ᾱ_t is the cumulative product of the alphas up to timestep t
        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_losses(self, denoise_model, x_0, t, noise=None):
        """
        Calculate loss for training the denoising model.
        
        Args:
            denoise_model (nn.Module): The UNet model to predict noise
            x_0 (tensor): Original clean images [batch_size, channels, height, width]
            t (tensor): Timesteps [batch_size]
            noise (tensor, optional): Noise to add, if None random noise is generated
            
        Returns:
            tensor: MSE loss between predicted and actual noise
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Add noise to input
        x_noisy = self.q_sample(x_0, t, noise=noise)
        
        # Model predicts the noise added
        predicted_noise = denoise_model(x_noisy, t)
        
        # Loss is MSE between predicted and actual noise
        loss = F.mse_loss(predicted_noise, noise)
        
        return loss
    
    def ddim_sample(self, model, shape, n_steps=150, eta=0.0):
        """
        DDIM sampling for accelerated image generation.
        
        Args:
            model (nn.Module): Trained UNet model
            shape (tuple): Shape of samples to generate [batch_size, channels, height, width]
            n_steps (int): Number of sampling steps (less than self.timesteps)
            eta (float): Parameter controlling stochasticity (0 = deterministic, 1 = DDPM)
            
        Returns:
            tensor: Generated samples
        """
        device = next(model.parameters()).device
        b = shape[0]
        
        # Start from pure noise
        img = torch.randn(shape, device=device)
        
        # Select subset of timesteps for DDIM sampling
        timesteps = np.linspace(0, self.timesteps - 1, n_steps, dtype=int)[::-1]
        
        for i in tqdm(range(len(timesteps) - 1), desc='DDIM Sampling'):
            t_current = torch.full((b,), timesteps[i], device=device, dtype=torch.long)
            t_next = torch.full((b,), timesteps[i + 1], device=device, dtype=torch.long)
            
            # Predict noise
            with torch.no_grad():
                predicted_noise = model(img, t_current)
            
            # Extract x0 from xt using the predicted noise
            alpha_cumprod_t = self.alphas_cumprod[t_current].reshape(-1, 1, 1, 1)
            alpha_cumprod_next = self.alphas_cumprod[t_next].reshape(-1, 1, 1, 1)
            
            sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
            sqrt_one_minus_alpha_cumprod_t = torch.sqrt(1 - alpha_cumprod_t)
            
            # Predict x0 from xt and predicted noise
            predicted_x0 = (img - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / sqrt_alpha_cumprod_t
            
            # Calculate the coefficient for predicted_x0
            sqrt_alpha_cumprod_next = torch.sqrt(alpha_cumprod_next)
            
            # Calculate the coefficient for direction pointing to xt
            # Interpolate between DDPM and DDIM using eta
            sigma_t = eta * torch.sqrt((1 - alpha_cumprod_next) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_next))
            
            # Get noise for stochastic part (if eta > 0)
            noise = torch.randn_like(img) if eta > 0 else 0
            
            # Compute the next sample using DDIM formula
            img = sqrt_alpha_cumprod_next * predicted_x0 + \
                  torch.sqrt(1 - alpha_cumprod_next - sigma_t**2) * predicted_noise + \
                  sigma_t * noise
            
        return img

In [None]:
# Evaluation Metrics                     
def two_point_correlation(image, max_distance=None):
    """
    Compute two-point correlation function for a binary image.
    This is a common metric for microstructure characterization.
    
    Args:
        image (tensor or array): Binary image
        max_distance (int, optional): Maximum distance to compute correlation
        
    Returns:
        tuple: (distances, correlation values)
    """
    # Ensure binary image and convert to numpy
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    
    image = image.squeeze()
    if image.max() > 1:
        image = image / 255.0
    image = (image > 0.5).astype(np.float32)
    
    h, w = image.shape
    if max_distance is None:
        max_distance = min(h, w) // 2
    
    # Calculate autocorrelation using FFT for efficiency (Wiener-Khinchin theorem).
    # The result is shifted so that zero displacement sits at (h//2, w//2),
    # matching the centre used for the radial distances below.
    ft = np.fft.fft2(image)
    autocorr = np.fft.fftshift(np.real(np.fft.ifft2(ft * np.conj(ft))))
    center = autocorr[h//2, w//2]
    
    # Compute radial average (azimuthal integration)
    y, x = np.indices((h, w))
    r = np.sqrt((x - w//2)**2 + (y - h//2)**2).astype(np.int32)
    r_max = min(h//2, w//2, max_distance)
    
    tbin = np.bincount(r.ravel(), autocorr.ravel())
    nr = np.bincount(r.ravel())
    radial_profile = tbin / nr
    
    # Normalize by the zero-displacement value (guard against an empty phase)
    if center == 0:
        return np.arange(r_max), np.zeros(r_max)
    radial_profile = radial_profile[:r_max] / center
    
    return np.arange(r_max), radial_profile

def lineal_path_function(image, max_distance=None):
    """
    Compute lineal path function for a binary image.
    This measures the probability of having a continuous line segment
    completely inside one phase of the microstructure.
    
    Args:
        image (tensor or array): Binary image
        max_distance (int, optional): Maximum distance to compute lineal path
        
    Returns:
        tuple: (distances, lineal path values)
    """
    # Ensure binary image and convert to numpy
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    
    image = image.squeeze()
    if image.max() > 1:
        image = image / 255.0
    image = (image > 0.5).astype(np.float32)
    
    h, w = image.shape
    if max_distance is None:
        max_distance = min(h, w) // 4  # Limiting to 1/4 for computational efficiency
    
    # Initialize storage for lineal path function
    lp = np.zeros(max_distance)
    counts = np.zeros(max_distance)
    
    # Sample a subset of points for efficiency
    n_samples = 1000
    y_samples = np.random.randint(0, h, n_samples)
    x_samples = np.random.randint(0, w, n_samples)
    
    # Check horizontal lines (a segment of length l spans l + 1 pixels,
    # so l = 0 is the single start pixel and lp[0] estimates the phase fraction)
    for i, x in enumerate(x_samples):
        y = y_samples[i]
        if x + max_distance <= w:
            line = image[y, x:x+max_distance]
            for l in range(max_distance):
                if np.all(line[:l+1] > 0.5):
                    lp[l] += 1
                counts[l] += 1
    
    # Check vertical lines
    for i, y in enumerate(y_samples):
        x = x_samples[i]
        if y + max_distance <= h:
            line = image[y:y+max_distance, x]
            for l in range(max_distance):
                if np.all(line[:l+1] > 0.5):
                    lp[l] += 1
                counts[l] += 1
    
    # Normalize by counts
    lp = np.divide(lp, counts, out=np.zeros_like(lp), where=counts!=0)
    
    return np.arange(max_distance), lp

def evaluate_microstructures(original_samples, generated_samples, max_distance=50):
    """
    Evaluate generated microstructures against original ones using
    multiple metrics: SSIM, two-point correlation, and lineal path function.
    
    Args:
        original_samples (list or tensor): Original microstructure samples
        generated_samples (list or tensor): Generated microstructure samples
        max_distance (int): Maximum distance for correlation metrics
        
    Returns:
        dict: Dictionary containing evaluation metrics
    """
    metrics = {
        'ssim_values': [],
        's2_discrepancy': [],
        'lineal_discrepancy': []
    }
    
    for i in range(len(original_samples)):
        # Convert to numpy arrays in [0, 1] range
        if isinstance(original_samples[i], torch.Tensor):
            orig = ((original_samples[i] + 1) / 2).squeeze().cpu().numpy()
        else:
            orig = original_samples[i].squeeze()
            
        if isinstance(generated_samples[i], torch.Tensor):
            gen = ((generated_samples[i] + 1) / 2).squeeze().cpu().numpy()
        else:
            gen = generated_samples[i].squeeze()
        
        # Compute SSIM (Structural Similarity Index)
        ssim_val = ssim(orig, gen, data_range=1.0)
        metrics['ssim_values'].append(ssim_val)
        
        # Compute two-point correlation function (S2)
        r_orig, s2_orig = two_point_correlation(orig, max_distance)
        r_gen, s2_gen = two_point_correlation(gen, max_distance)
        
        # Compute normalized area difference between curves
        area_orig = np.trapz(s2_orig, r_orig)
        area_between = np.trapz(np.abs(s2_orig - s2_gen), r_orig)
        if area_orig > 0:
            s2_discrepancy = (area_between / area_orig) * 100
        else:
            s2_discrepancy = 0
        metrics['s2_discrepancy'].append(s2_discrepancy)
        
        # Compute lineal path function
        r_orig, lp_orig = lineal_path_function(orig, max_distance)
        r_gen, lp_gen = lineal_path_function(gen, max_distance)
        
        # Compute normalized area difference between curves
        area_orig = np.trapz(lp_orig, r_orig)
        area_between = np.trapz(np.abs(lp_orig - lp_gen), r_orig)
        if area_orig > 0:
            lineal_discrepancy = (area_between / area_orig) * 100
        else:
            lineal_discrepancy = 0
        metrics['lineal_discrepancy'].append(lineal_discrepancy)
    
    # Compute averages for summary metrics
    metrics['avg_ssim'] = np.mean(metrics['ssim_values'])
    metrics['avg_s2_discrepancy'] = np.mean(metrics['s2_discrepancy'])
    metrics['avg_lineal_discrepancy'] = np.mean(metrics['lineal_discrepancy'])
    
    return metrics

In [None]:
# Visualization Functions                
def show_microstructure_batch(data, n=4, title=None, save_path=None):
    """
    Display a batch of microstructures in a grid.
    
    Args:
        data (tensor or list): Batch of microstructure images
        n (int): Number of images to display
        title (str, optional): Title for the figure
        save_path (str, optional): Path to save the figure
    """
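    # Size the grid from the number of images instead of hard-coding 2x2,
    # so that n > 4 does not raise an error in plt.subplot
    count = min(n, len(data))
    cols = max(1, int(np.ceil(np.sqrt(count))))
    rows = max(1, int(np.ceil(count / cols)))
    
    plt.figure(figsize=(10, 10))
    for i in range(count):
        # Get image and convert to numpy
        if isinstance(data, torch.Tensor):
            image = data[i].squeeze().cpu().detach().numpy()
        else:
            image = data[i].squeeze()
            
        # Convert from [-1, 1] to [0, 1] if needed
        if image.min() < 0:
            image = (image + 1) / 2
            
        plt.subplot(rows, cols, i+1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')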
        
    if title:
        plt.suptitle(title)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

In [None]:
# Hyperparameter Tuning Functions        
# Convert NumPy types to standard Python types for JSON serialization
def convert_to_serializable(obj):
    """
    Convert NumPy types to standard Python types for JSON serialization.
    
    Args:
        obj: Object to convert (can be NumPy scalar, array, list, or dict)
        
    Returns:
        Object with NumPy types converted to standard Python types
    """
    # np.integer and np.floating cover all NumPy scalar widths; the np.float_
    # alias was removed in NumPy 2.0 and would raise an AttributeError there
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, (np.ndarray,)):
        return obj.tolist()
    elif isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    return obj


def simple_grid_search(param_grid, base_config, num_epochs=30):
    """
    Simple grid search implementation for diffusion models.
    
    Args:
        param_grid: Dictionary with parameter names as keys and lists of values to try
        base_config: Base configuration with fixed parameters
        num_epochs: Number of epochs to train each configuration
        
    Returns:
        Dictionary with best parameters and all results
    """
    # Create timestamp for saving results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(base_config.get('save_dir', './'), f"grid_search_{timestamp}")
    os.makedirs(save_dir, exist_ok=True)
    
    # Convert param_grid values to standard Python types for JSON serialization
    serializable_param_grid = convert_to_serializable(param_grid)
    
    # Log the parameter grid
    with open(os.path.join(save_dir, 'param_grid.json'), 'w') as f:
        json.dump(serializable_param_grid, f, indent=4)
    
    # Create all parameter combinations
    keys = list(param_grid.keys())
    values = list(param_grid.values())
    param_combinations = list(itertools.product(*values))
    
    # Create the dataset
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, scale=(0.9, 1.1)),
    ]) if base_config.get('use_augmentation', True) else None
    
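    # Augmentation should only touch training samples, so build one augmented
    # view for training and one untransformed view for validation
    # (note: this opens the HDF5 data twice)
    train_source = Micro2DKeyDataset(
        file_path=base_config['file_path'],
        key=base_config['microstructure_class'],
        transform=transform
    )
    val_source = Micro2DKeyDataset(
        file_path=base_config['file_path'],
        key=base_config['microstructure_class'],
        transform=None
    )
    
    # Create a validation split with a fixed seed; both views share the same indices
    train_size = int(0.9 * len(train_source))
    indices = torch.randperm(
        len(train_source), generator=torch.Generator().manual_seed(42)
    ).tolist()
    train_dataset = torch.utils.data.Subset(train_source, indices[:train_size])
    val_dataset = torch.utils.data.Subset(val_source, indices[train_size:])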
    
    # Track results
    all_results = []
    best_val_loss = float('inf')
    best_params = None
    
    # Test each parameter combination
    print(f"Testing {len(param_combinations)} parameter combinations")
    
    for i, combination in enumerate(param_combinations):
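        # Create parameter dictionary with native Python types; NumPy scalars
        # (e.g. an np.int64 batch size) can be rejected by DataLoader
        params = convert_to_serializable(dict(zip(keys, combination)))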
        run_config = base_config.copy()
        run_config.update(params)
        
        print(f"\nTesting combination {i+1}/{len(param_combinations)}:")
        for key, value in params.items():
            print(f"  {key}: {value}")
        
        # Create run directory
        run_dir = os.path.join(save_dir, f"run_{i+1}")
        os.makedirs(run_dir, exist_ok=True)
        run_config['save_dir'] = run_dir
        
        # Save configuration
        with open(os.path.join(run_dir, 'config.json'), 'w') as f:
            # Convert any non-serializable objects or NumPy types to standard Python types
            serializable_config = convert_to_serializable(run_config)
            json.dump(serializable_config, f, indent=4)
        
        # Create model (run_config merges base_config with the tuned params,
        # so fixed values set in base_config are honored as well)
        model = UNet(
            in_channels=1, 
            out_channels=1,
            time_emb_dim=run_config.get('time_emb_dim', 128)
        )
        
        # Create diffusion process
        diffusion = DiffusionModel(
            timesteps=run_config.get('timesteps', 1000),
            device=device
        )
        
        # Create optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=run_config.get('learning_rate', 1e-4),
            weight_decay=run_config.get('weight_decay', 1e-5)
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=run_config.get('batch_size', 8),
            shuffle=True,
            num_workers=run_config.get('num_workers', 2),
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=run_config.get('batch_size', 8),
            shuffle=False,
            num_workers=run_config.get('num_workers', 2),
            pin_memory=True
        )
        
        # Training loop
        model.to(device)
        model.train()
        
        train_losses = []
        val_losses = []
        
        best_epoch_val_loss = float('inf')
        patience = 5  # Number of epochs with no improvement after which training will be stopped
        patience_counter = 0
        
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            train_loss = 0.0
            
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch in progress_bar:
                optimizer.zero_grad()
                
                batch = batch.to(device)
                batch_size = batch.shape[0]
                
                # Sample random timesteps
                t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device).long()
                
                # Calculate loss
                loss = diffusion.p_losses(model, batch, t)
                
                # Backward pass
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                progress_bar.set_postfix({"loss": loss.item()})
            
            avg_train_loss = train_loss / len(train_loader)
            train_losses.append(avg_train_loss)
            
            # Validation phase
            model.eval()
            val_loss = 0.0
            
            with torch.no_grad():
                for batch in val_loader:
                    batch = batch.to(device)
                    batch_size = batch.shape[0]
                    
                    # Sample random timesteps
                    t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device).long()
                    
                    # Calculate loss
                    loss = diffusion.p_losses(model, batch, t)
                    val_loss += loss.item()
            
            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)
            
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")
            
            # Check for early stopping
            if avg_val_loss < best_epoch_val_loss:
                best_epoch_val_loss = avg_val_loss
                patience_counter = 0
                
                # Save best model for this run
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_epoch_val_loss
                }, os.path.join(run_dir, "best_model.pt"))
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    break
        
        # Plot loss curves
        plt.figure(figsize=(10, 5))
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.savefig(os.path.join(run_dir, 'loss_curves.png'))
        plt.close()
        
        # Generate samples with the best model
        checkpoint = torch.load(os.path.join(run_dir, "best_model.pt"), map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        with torch.no_grad():
            # Generate samples using DDIM sampling
            samples = diffusion.ddim_sample(
                model,
                (4, 1, 256, 256),  # Generate 4 samples of 256x256 size
                n_steps=run_config.get('sample_steps', 150),
                eta=run_config.get('eta', 0.0)
            )
            
            # Display and save samples
            plt.figure(figsize=(10, 10))
            for j in range(4):  # j, so the outer run index i is not overwritten
                plt.subplot(2, 2, j+1)
                plt.imshow(((samples[j] + 1) / 2).squeeze().cpu().numpy(), cmap='gray')
                plt.title(f"Sample {j+1}")
                plt.axis('off')
            plt.suptitle("Generated Samples")
            plt.tight_layout()
            plt.savefig(os.path.join(run_dir, "samples.png"))
            plt.close()
        
        # Calculate evaluation metrics on validation set
        val_batch = next(iter(val_loader))[:4].to(device)
        metrics = evaluate_microstructures(val_batch, samples)
        
        # Save metrics
        with open(os.path.join(run_dir, 'metrics.json'), 'w') as f:
            # Convert metrics to serializable format
            serializable_metrics = convert_to_serializable(metrics)
            json.dump(serializable_metrics, f, indent=4)
        
        # Record results
        result = {
            'params': convert_to_serializable(params),
            'best_val_loss': best_epoch_val_loss,
            'metrics': {
                'ssim': float(metrics['avg_ssim']),
                's2_discrepancy': float(metrics['avg_s2_discrepancy']),
                'lineal_discrepancy': float(metrics['avg_lineal_discrepancy'])
            }
        }
        all_results.append(result)
        
        # Update best parameters if this run has lower validation loss
        if best_epoch_val_loss < best_val_loss:
            best_val_loss = best_epoch_val_loss
            best_params = params.copy()
            
            # Save best model overall
            checkpoint = torch.load(os.path.join(run_dir, "best_model.pt"))
            torch.save(checkpoint, os.path.join(save_dir, "best_model.pt"))
    
    # Save all results
    with open(os.path.join(save_dir, 'all_results.json'), 'w') as f:
        json.dump(convert_to_serializable(all_results), f, indent=4)
    
    # Save best parameters
    with open(os.path.join(save_dir, 'best_params.json'), 'w') as f:
        json.dump(convert_to_serializable(best_params), f, indent=4)
    
    # Create results summary
    create_results_summary(all_results, save_dir)
    
    return {
        'best_params': best_params,
        'all_results': all_results,
        'save_dir': save_dir
    }

def create_results_summary(all_results, save_dir):
    """
    Create summary visualizations of all hyperparameter search results.
    
    Args:
        all_results (list): List of dictionaries containing results from each run
        save_dir (str): Directory to save the summary visualizations
    """
    # Nothing to summarize if no run finished
    if not all_results:
        return
    
    # Extract data for plots
    param_names = list(all_results[0]['params'].keys())
    
    # Create summaries for each parameter
    for param_name in param_names:
        values = []
        val_losses = []
        ssim_scores = []
        
        for result in all_results:
            values.append(result['params'][param_name])
            val_losses.append(result['best_val_loss'])
            ssim_scores.append(result['metrics']['ssim'])
        
        # Create plots - convert to common data type
        values = [float(v) if isinstance(v, (int, float)) else str(v) for v in values]
        
        # If we have numeric values, sort by them
        if all(isinstance(v, (int, float)) for v in values):
            # Sort all lists by parameter value
            sorted_indices = np.argsort(values)
            sorted_values = [values[i] for i in sorted_indices]
            sorted_val_losses = [val_losses[i] for i in sorted_indices]
            sorted_ssim_scores = [ssim_scores[i] for i in sorted_indices]
            
            # Plot val loss vs parameter (markers only: several runs can share
            # the same parameter value, so connecting lines would zigzag)
            plt.figure(figsize=(10, 5))
            plt.plot(sorted_values, sorted_val_losses, 'o')
            plt.xlabel(param_name)
            plt.ylabel('Validation Loss')
            plt.title(f'Validation Loss vs {param_name}')
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, f'{param_name}_val_loss.png'))
            plt.close()
            
            # Plot SSIM vs parameter
            plt.figure(figsize=(10, 5))
            plt.plot(sorted_values, sorted_ssim_scores, 'o')
            plt.xlabel(param_name)
            plt.ylabel('SSIM Score')
            plt.title(f'SSIM Score vs {param_name}')
            plt.grid(True)
            plt.savefig(os.path.join(save_dir, f'{param_name}_ssim.png'))
            plt.close()
    
    # Create table of top 5 best results
    sorted_results = sorted(all_results, key=lambda x: x['best_val_loss'])
    top5 = sorted_results[:5]
    
    with open(os.path.join(save_dir, 'top5_results.txt'), 'w') as f:
        f.write("Top 5 Best Parameter Combinations:\n\n")
        for i, result in enumerate(top5):
            f.write(f"Rank {i+1}:\n")
            f.write(f"  Validation Loss: {result['best_val_loss']:.6f}\n")
            f.write(f"  SSIM: {result['metrics']['ssim']:.4f}\n")
            f.write(f"  S2 Discrepancy: {result['metrics']['s2_discrepancy']:.2f}%\n")
            f.write(f"  Lineal Discrepancy: {result['metrics']['lineal_discrepancy']:.2f}%\n")
            f.write("  Parameters:\n")
            for key, value in result['params'].items():
                f.write(f"    {key}: {value}\n")
            f.write("\n")

def random_search(param_ranges, base_config, num_trials=20, num_epochs=30):
    """
    Random search implementation for diffusion models.
    
    Args:
        param_ranges: Dictionary with parameter names as keys and (min, max) or list of values
        base_config: Base configuration with fixed parameters
        num_trials: Number of random configurations to try
        num_epochs: Number of epochs to train each configuration
        
    Returns:
        Dictionary with best parameters and all results
    """
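    # Draw one complete configuration per trial. Passing the sampled value lists
    # to simple_grid_search directly would expand them into a Cartesian product
    # (num_trials ** num_params runs) instead of num_trials runs.
    trial_params = []
    for _ in range(num_trials):
        params = {}
        for param_name, param_range in param_ranges.items():
            if isinstance(param_range, tuple) and len(param_range) == 2:
                # Range given as (min, max)
                min_val, max_val = param_range
                if param_name in ('learning_rate', 'weight_decay'):
                    # Log-uniform sampling for learning rate and weight decay
                    value = float(np.exp(np.random.uniform(np.log(min_val), np.log(max_val))))
                elif isinstance(min_val, int) and isinstance(max_val, int):
                    # Integer parameter, upper bound inclusive
                    value = int(np.random.randint(min_val, max_val + 1))
                else:
                    # Uniform sampling for other continuous parameters
                    value = float(np.random.uniform(min_val, max_val))
            elif isinstance(param_range, list):
                # Discrete choices; indexing keeps the element's native Python type
                value = param_range[np.random.randint(len(param_range))]
            else:
                raise ValueError(f"Unsupported parameter range type for {param_name}: {type(param_range)}")
            params[param_name] = value
        trial_params.append(params)
    
    # Parent directory that collects every trial
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    search_dir = os.path.join(base_config.get('save_dir', './'), f"random_search_{timestamp}")
    os.makedirs(search_dir, exist_ok=True)
    
    all_results = []
    best_result = None
    best_trial_dir = None
    
    for trial, params in enumerate(trial_params):
        print(f"\nRandom search trial {trial+1}/{num_trials}")
        trial_config = base_config.copy()
        trial_config['save_dir'] = os.path.join(search_dir, f"trial_{trial+1}")
        
        # A grid with a single value per parameter trains exactly one configuration
        single_point_grid = {name: [value] for name, value in params.items()}
        outcome = simple_grid_search(single_point_grid, trial_config, num_epochs)
        
        result = outcome['all_results'][0]
        all_results.append(result)
        if best_result is None or result['best_val_loss'] < best_result['best_val_loss']:
            best_result = result
            best_trial_dir = outcome['save_dir']
    
    best_params = best_result['params'] if best_result is not None else None
    
    # Save aggregated results across all trials
    with open(os.path.join(search_dir, 'all_results.json'), 'w') as f:
        json.dump(convert_to_serializable(all_results), f, indent=4)
    with open(os.path.join(search_dir, 'best_params.json'), 'w') as f:
        json.dump(convert_to_serializable(best_params), f, indent=4)
    if all_results:
        create_results_summary(all_results, search_dir)
        print(f"Best trial artifacts saved in: {best_trial_dir}")
    
    return {
        'best_params': best_params,
        'all_results': all_results,
        'save_dir': search_dir
    }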

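Random search draws one complete configuration per trial, whereas grid search evaluates the Cartesian product of all value lists. The following is a minimal, standard-library-only sketch of that sampling step. The helper name sample_configs, its log_scale argument, and the seed are assumptions made for illustration and are not part of the notebook.

```python
import math
import random


def sample_configs(param_ranges, num_trials,
                   log_scale=("learning_rate", "weight_decay"), seed=0):
    """Draw `num_trials` independent configurations from `param_ranges`."""
    rng = random.Random(seed)
    configs = []
    for _ in range(num_trials):
        config = {}
        for name, spec in param_ranges.items():
            if isinstance(spec, list):
                # Discrete choices
                config[name] = rng.choice(spec)
                continue
            low, high = spec
            if name in log_scale:
                # Log-uniform: uniform in log space, then exponentiate
                config[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
            elif isinstance(low, int) and isinstance(high, int):
                # Integer range, both bounds inclusive
                config[name] = rng.randint(low, high)
            else:
                # Plain uniform range
                config[name] = rng.uniform(low, high)
        configs.append(config)
    return configs


ranges = {
    "timesteps": (500, 2000),
    "learning_rate": (1e-5, 5e-4),
    "batch_size": [4, 8, 16],
    "eta": (0.0, 0.5),
}
configs = sample_configs(ranges, num_trials=20)
print(len(configs))       # 20 configurations, one per trial
print(20 ** len(ranges))  # 160000 runs if the 20 values per parameter were crossed
```

The two printed numbers show why the sampled values should be kept together per trial: crossing 20 sampled values over four parameters would already require 160000 training runs.
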
In [None]:
# Main Execution                         
if __name__ == "__main__":
    # Define parameter grid for grid search
    param_grid = {
        'timesteps': [500, 1000, 1500],          # Number of diffusion steps
        'sample_steps': [100, 150, 200],         # Number of sampling steps (DDIM)
        'learning_rate': [1e-4, 3e-4, 5e-4],     # Learning rate for optimizer
        'batch_size': [4, 8, 16],                # Batch size for training
        'time_emb_dim': [64, 128, 256],          # Time embedding dimension
        'eta': [0.0, 0.2, 0.5]                   # DDIM stochasticity parameter
    }
    
    # Define parameter ranges for random search
    param_ranges = {
        'timesteps': (500, 2000),               # Integer range of diffusion steps (inclusive)
        'sample_steps': (100, 250),             # Integer range of sampling steps (inclusive)
        'learning_rate': (1e-5, 5e-4),          # Log-uniform range for learning rate
        'batch_size': [4, 8, 16],               # Discrete options for batch size
        'time_emb_dim': [64, 128, 256],         # Discrete options for embedding dimension
        'weight_decay': (1e-6, 1e-4),           # Log-uniform range for weight decay
        'eta': (0.0, 0.5)                       # Uniform range for DDIM stochasticity
    }
    
    # Base configuration with fixed parameters
    base_config = {
        'file_path': '/lustre/uschill-lab/users/3782/diffusion/New_Diff/MICRO2D_homogenized.h5',
        'microstructure_class': 'NBSA',         # Type of microstructure to model
        'save_dir': f'./tuning_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
        'use_augmentation': True,               # Whether to use data augmentation
        'num_workers': 2                        # Number of workers for data loading
    }
    
    # Choose search method
    search_method = "random"  # "grid" or "random"
    
    if search_method == "grid":
        print("Running grid search...")
        results = simple_grid_search(param_grid, base_config, num_epochs=30)
    else:
        print("Running random search...")
        results = random_search(param_ranges, base_config, num_trials=20, num_epochs=30)
    
    print("\nSearch completed!")
    print("Best parameters found:")
    for param, value in results['best_params'].items():
        print(f"{param}: {value}")
    print(f"Results saved to: {results['save_dir']}")
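The grid defined in the main cell has six parameters with three values each, so it is worth counting the runs before launching a grid search. The sketch below does this with the standard library only. The epoch figure is an upper bound, because early stopping can end a run sooner.

```python
import math

# Same grid as in the main cell
param_grid = {
    'timesteps': [500, 1000, 1500],
    'sample_steps': [100, 150, 200],
    'learning_rate': [1e-4, 3e-4, 5e-4],
    'batch_size': [4, 8, 16],
    'time_emb_dim': [64, 128, 256],
    'eta': [0.0, 0.2, 0.5],
}
num_epochs = 30

# Number of combinations is the product of the list lengths
num_runs = math.prod(len(values) for values in param_grid.values())
max_epochs = num_runs * num_epochs

print(num_runs)    # 729
print(max_epochs)  # 21870
```

Note that sample_steps and eta only affect DDIM sampling, not training, so crossing them with the training parameters repeats otherwise identical training runs.
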