Lab Notebook: Diffusion Models for Microstructure Generation (Class: VoidSmall)

This notebook implements a diffusion model pipeline for generating synthetic material microstructures. The implementation includes a complete workflow from data loading to evaluation.

Technical Background

Diffusion Models

Diffusion models work through a forward process that gradually adds noise to data, followed by a reverse denoising process that reconstructs the data. The model is trained to predict the noise added at each step, enabling it to generate new samples starting from pure noise.

Key features of the implementation:

Forward diffusion process with linear beta schedule
Reverse denoising using a UNet architecture with self-attention
DDIM (Denoising Diffusion Implicit Models) sampling for faster generation
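The forward process has a closed form, so a noisy sample at any timestep can be drawn directly from the clean image without iterating. The NumPy sketch below mirrors the notebook's `linear_beta_schedule` and `DiffusionModel.q_sample`. It makes two illustrative assumptions: a toy binary image, and NumPy in place of PyTorch. T = 2000 is taken from the configuration used later.

```python
import numpy as np

def linear_beta_schedule(timesteps, start=1e-4, end=0.02):
    # Linear variance schedule, same defaults as the notebook's PyTorch version
    return np.linspace(start, end, timesteps)

def q_sample(x0, t, alphas_cumprod, rng):
    # Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    eps = rng.standard_normal(x0.shape)
    abar_t = alphas_cumprod[t]
    return np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * eps, eps

T = 2000
betas = linear_beta_schedule(T)
alphas_cumprod = np.cumprod(1.0 - betas)  # abar_t: fraction of signal variance left at step t

rng = np.random.default_rng(0)
x0 = np.where(rng.random((64, 64)) < 0.2, 1.0, -1.0)  # toy two-phase image scaled to [-1, 1]
for t in (0, 500, 1999):
    xt, _ = q_sample(x0, t, alphas_cumprod, rng)
    print(f"t={t:4d}  sqrt(abar_t)={np.sqrt(alphas_cumprod[t]):.4f}  std(x_t)={xt.std():.3f}")
```

With this schedule, abar_T at the last step falls below 1e-8. x_T is therefore essentially pure Gaussian noise, which is why sampling can start from `torch.randn`.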

Microstructure Characterization
Microstructures are characterized using specialized metrics:

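Two-point correlation function (S2): Probability that two points separated by a vector r both lie in the phase of interest; measures spatial correlations
Lineal path function: Probability that a line segment of length r lies entirely within one phase; captures phase connectivity in the material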
Structural Similarity Index (SSIM): Quantifies visual similarity

These metrics quantify structure beyond visual comparison and are commonly used to relate a microstructure to its material properties.
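The sketch below illustrates S2 by computing the two-point correlation map of a binary phase image with the FFT (Wiener-Khinchin theorem). It assumes periodic boundaries and a 0/1 phase indicator (for example, voids = 1 after thresholding). The function name and these conventions are assumptions and may differ from the notebook's own metric code.

```python
import numpy as np

def two_point_correlation(phase):
    """S2 map for a binary phase indicator, assuming periodic boundaries.

    s2[dy, dx] = fraction of positions x with phase[x] == 1 and phase[x + (dy, dx)] == 1.
    """
    ind = phase.astype(np.float64)
    f = np.fft.fft2(ind)
    # Autocorrelation = inverse FFT of the power spectrum, normalized by pixel count
    return np.fft.ifft2(f * np.conj(f)).real / ind.size

rng = np.random.default_rng(0)
img = (rng.random((128, 128)) < 0.3).astype(np.uint8)  # toy uncorrelated two-phase image
s2 = two_point_correlation(img)
phi = img.mean()
print(f"S2(0) = {s2[0, 0]:.4f}, volume fraction = {phi:.4f}")
print(f"S2 at dx=10: {s2[0, 10]:.4f}, phi^2 = {phi**2:.4f}")
```

S2(0) equals the volume fraction phi. For an uncorrelated image, S2 drops to phi^2 at any nonzero offset. A correlated microstructure approaches phi^2 only beyond its characteristic length scale.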
Implementation Details
Dataset
The Micro2DKeyDataset class handles loading microstructure data from HDF5 files:

Normalizes pixel values to [-1, 1] range
Optionally applies a transform for data augmentation (e.g., flips and affine transformations)
Enforces value clamping to maintain numerical stability
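The per-sample preprocessing can be reproduced in NumPy to check the value range and the role of the epsilon. `preprocess` is a hypothetical mirror of `Micro2DKeyDataset.__getitem__` that omits the optional transform.

```python
import numpy as np

def preprocess(img, eps=1e-8):
    # Hypothetical NumPy mirror of Micro2DKeyDataset.__getitem__ (no transform)
    img = img.astype(np.float32)
    img = (img - img.min()) / (img.max() - img.min() + eps)  # per-image min-max to [0, 1]
    img = 2.0 * img - 1.0                                      # affine map to [-1, 1]
    img = np.clip(img, -1.0, 1.0)                              # guard against rounding overshoot
    return img[None, ...]                                      # add channel axis -> [1, H, W]

raw = np.array([[0, 128], [255, 64]], dtype=np.uint8)
out = preprocess(raw)
print(out.shape, out.min(), out.max())

# A constant image would divide by zero without eps; with it, every pixel maps to -1
flat = preprocess(np.full((4, 4), 7, dtype=np.uint8))
print(np.unique(flat))
```

Because normalization is per image, absolute intensity differences between samples are discarded. This is harmless for segmented (binary) microstructures.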

Model Architecture
The UNet architecture includes:

Time embedding via sinusoidal position embeddings
Downsample path with 5 blocks (128 → 512 channels)
Self-attention on the 32×32, 256-channel feature map after the third downsampling block
Upsample path with skip connections
Batch normalization after each 3×3 convolution (applied after the ReLU) for training stability
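The time embedding is easier to see in isolation. The NumPy port below mirrors `SinusoidalPositionEmbeddings.forward` and is for illustration only. Each timestep becomes a 128-dimensional vector (128 is the UNet's default `time_emb_dim`). Low-index channels oscillate quickly and high-index channels slowly, so nearby timesteps get similar but distinguishable codes.

```python
import numpy as np

def sinusoidal_embedding(timesteps, dim=128):
    # NumPy port of SinusoidalPositionEmbeddings.forward
    half_dim = dim // 2
    scale = np.log(10000) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim) * -scale)          # log-spaced from 1 down to 1e-4
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)  # [batch, dim]

emb = sinusoidal_embedding([0, 1, 1999])
print(emb.shape)  # (3, 128)
```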

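The self-attention layer lets every position in the 32×32 map attend to every other position. This helps capture long-range dependencies in microstructure patterns that local 3×3 convolutions can reach only through many stacked layers.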
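The channel bookkeeping, including the skip concatenations, can be checked without PyTorch. The pure-Python trace below transcribes the channel counts from `UNet.__init__`; the function itself is hypothetical. It shows two things. First, the up-block input widths that the constructor must declare. Second, why the input side length must be divisible by 2^5 = 32: that is what makes the skip connections line up and the output match the input resolution.

```python
def unet_shapes(size=256):
    """Trace (channels, spatial size) through the notebook's UNet for a square input."""
    enc_out = [128, 256, 256, 512, 512]   # output channels of down1..down5
    dec_out = [512, 256, 256, 128, 128]   # output channels of up1..up5
    skips, s = [], size
    for ch in enc_out:
        s //= 2                             # each down Block ends in a stride-2 conv
        skips.append((ch, s))
    x_ch, x_s = skips.pop()                 # x5 (8x8 for a 256 input) feeds up1 directly
    up_in = []
    for i, ch in enumerate(dec_out):
        if i == 0:
            in_ch = x_ch
        else:
            skip_ch, skip_s = skips.pop()   # x4, x3, x2, x1 in that order
            if skip_s != x_s:
                raise ValueError(f"skip is {skip_s}px but upsampled map is {x_s}px")
            in_ch = x_ch + skip_ch          # torch.cat along the channel axis
        up_in.append(in_ch)
        x_ch, x_s = ch, x_s * 2             # each up Block ends in a stride-2 transposed conv
    return up_in, (x_ch, x_s)

up_in, final = unet_shapes(256)
print("up-block input channels:", up_in)        # [512, 1024, 512, 512, 256]
print("final feature map:", final)              # (128, 256)
print("attention tokens:", (256 // 8) ** 2)     # 1024 tokens for the 32x32 map after down3
```

The attention sequence length grows with the square of the feature-map side. Placing attention at 32×32 rather than at full resolution keeps the 1024×1024 attention matrix tractable.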
Diffusion Process
The DiffusionModel class implements:

Forward diffusion with a linear noise schedule
Loss calculation for noise prediction
DDIM sampling with controllable stochasticity parameter (η)

DDIM sampling accelerates generation by running the reverse process on an evenly spaced subset of the training timesteps (sample_steps out of timesteps), typically with little loss in sample quality.
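Writing out a single DDIM update in NumPy shows how eta enters. The sketch reproduces the formulas in `DiffusionModel.ddim_sample`, with T = 2000 and sample_steps = 150 taken from the configuration. The helper names are hypothetical. With these settings the sampling loop calls the UNet 149 times (once per consecutive pair of selected timesteps) instead of 2000.

```python
import numpy as np

def ddim_timesteps(T, n_steps):
    # Evenly spaced, descending subset of [0, T-1], as in ddim_sample
    return np.linspace(0, T - 1, n_steps, dtype=int)[::-1]

def ddim_step(x_t, eps_pred, abar_t, abar_next, eta, rng):
    # 1. Estimate the clean image from x_t and the predicted noise
    x0_pred = (x_t - np.sqrt(1 - abar_t) * eps_pred) / np.sqrt(abar_t)
    # 2. eta = 0 gives deterministic DDIM; eta = 1 gives DDPM-like stochastic sampling
    sigma = eta * np.sqrt((1 - abar_next) / (1 - abar_t) * (1 - abar_t / abar_next))
    noise = rng.standard_normal(x_t.shape) if eta > 0 else 0.0
    # 3. Re-noise x0_pred to the (less noisy) next timestep
    return (np.sqrt(abar_next) * x0_pred
            + np.sqrt(1 - abar_next - sigma**2) * eps_pred
            + sigma * noise)

T = 2000
abar = np.cumprod(1 - np.linspace(1e-4, 0.02, T))
ts = ddim_timesteps(T, 150)
print(len(ts) - 1, "UNet evaluations; first timesteps:", ts[:3])
```

Suppose the network predicted the noise exactly. A deterministic (eta = 0) step would then land exactly on the forward-process sample at the next timestep with the same noise draw. This consistency is what allows DDIM to skip timesteps.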
Training Pipeline
The train_diffusion_model function manages the training loop:

Randomly samples timesteps during training
Calculates loss by comparing predicted and actual noise
Saves a model checkpoint and generates four DDIM samples every 10 epochs
Visualizes progress throughout training
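Reproducing the per-step loss on random data shows what values to expect. This NumPy sketch uses a stand-in batch of 16 random images, no network, and hypothetical variable names. It samples one timestep per image, builds x_t as `p_losses` does, and evaluates the MSE for two trivial predictors.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 2000
abar = np.cumprod(1 - np.linspace(1e-4, 0.02, T))

batch = rng.uniform(-1, 1, size=(16, 1, 32, 32))       # stand-in for a [B, C, H, W] batch
t = rng.integers(0, T, size=batch.shape[0])            # one random timestep per image
noise = rng.standard_normal(batch.shape)
a = abar[t].reshape(-1, 1, 1, 1)                       # per-sample coefficient, broadcast over C, H, W
x_noisy = np.sqrt(a) * batch + np.sqrt(1 - a) * noise  # the UNet's input

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

loss_oracle = mse(noise, noise)                # a perfect noise predictor
loss_zero = mse(np.zeros_like(noise), noise)   # always predicting zero noise
print(f"oracle: {loss_oracle:.3f}, zero predictor: {loss_zero:.3f}")
```

The target noise has unit variance, so a loss that stays near 1.0 means the UNet is doing no better than predicting zeros. Training progress should show up as the epoch-average loss falling clearly below that baseline.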

Evaluation Metrics

Multiple evaluation approaches are implemented:

Material-specific metrics:

Two-point correlation function discrepancy
Lineal path function discrepancy
SSIM for visual similarity
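For the lineal path function, one common discrete definition is the fraction of pixel positions from which a horizontal segment of r + 1 pixels lies entirely in the phase of interest. The NumPy sketch below assumes that definition, periodic boundaries, and a binary image. The discrepancy shown (mean absolute difference between two L(r) curves) is an illustrative choice; the notebook's actual discrepancy measure is not shown in this section.

```python
import numpy as np

def lineal_path_x(phase, max_r):
    """L(r) along x for a binary image, r = 0..max_r, with periodic wrap-around.

    L[r] = fraction of pixels (i, j) such that phase[i, j..j+r] are all 1.
    """
    ind = phase.astype(bool)
    run = ind.copy()
    out = [run.mean()]
    for r in range(1, max_r + 1):
        run &= np.roll(ind, -r, axis=1)   # the segment must also cover pixel j + r
        out.append(run.mean())
    return np.array(out)

def path_discrepancy(real, generated, max_r=32):
    # Hypothetical discrepancy: mean absolute gap between the two L(r) curves
    return float(np.mean(np.abs(lineal_path_x(real, max_r) - lineal_path_x(generated, max_r))))

stripes = np.tile(np.array([1, 1, 1, 1, 0, 0, 0, 0]), (16, 4))  # 4-pixel-wide bands, period 8
print(lineal_path_x(stripes, 5))  # values: 0.5, 0.375, 0.25, 0.125, 0.0, 0.0
```

Unlike S2, L(r) reaches zero once r exceeds the longest straight run through a phase. This makes it sensitive to whether generated voids have the right extent.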


General-purpose metrics:

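Fréchet Inception Distance (FID): Fréchet distance between Gaussian fits to the Inception features of real and generated images (lower is better)
Inception Score (IS): Assesses quality and diversity of generated samples from Inception class predictions (higher is better)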
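Both general-purpose metrics reduce to short formulas once Inception outputs are available. FID compares the mean and covariance of pooled features: ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)). IS needs the class probabilities p(y|x) rather than the pooled features. The NumPy sketch below takes precomputed statistics and probabilities as inputs. It uses an eigendecomposition in place of `scipy.linalg.sqrtm` (which the notebook imports), and the function names are hypothetical.

```python
import numpy as np

def _sqrtm_psd(m):
    # Square root of a symmetric positive semi-definite matrix via eigh
    m = (m + m.T) / 2.0                       # remove tiny asymmetries from round-off
    w, v = np.linalg.eigh(m)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})
    # Tr((S1 S2)^{1/2}) equals Tr((S1^{1/2} S2 S1^{1/2})^{1/2}), a symmetric PSD form
    s1_half = _sqrtm_psd(sigma1)
    tr_covmean = np.trace(_sqrtm_psd(s1_half @ sigma2 @ s1_half))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def inception_score(probs, eps=1e-12):
    # probs: [N, K] softmax outputs p(y|x); IS = exp(mean_x KL(p(y|x) || p(y)))
    p_y = probs.mean(axis=0, keepdims=True)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

rng = np.random.default_rng(0)
feats = rng.standard_normal((500, 8))         # stand-in for 2048-d Inception features
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
print(fid_from_stats(mu, sigma, mu + 0.5, sigma))  # pure mean shift: approximately 8 * 0.5**2 = 2.0
```

IS ranges from 1 (every image gets the same class distribution) up to the number of classes. Because it relies on ImageNet classes, it is usually less informative than FID for grayscale microstructures.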



Experimental Setup

Configuration

Training is configured through a single dictionary; the key parameters are:
config = {
    'file_path': None,       # Set to the path of the HDF5 data file
    'microstructure_class': 'VoidSmall',  # Target microstructure class
    'batch_size': 16,
    'epochs': 50,
    'learning_rate': 5e-5,
    'weight_decay': 1e-5,
    'timesteps': 2000,       # Number of diffusion steps
    'sample_steps': 150,     # Steps for DDIM sampling
    'use_augmentation': True,
    'num_workers': 2,
    'save_dir': None,        # Set to the directory for checkpoints and sample images
}
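`file_path` and `save_dir` must be filled in by hand, so a small validation step before training can catch an incomplete dictionary early. The `check_config` function below is hypothetical (the notebook does not define it). The placeholder values in the example are for illustration only.

```python
def check_config(config):
    """Hypothetical helper: fail fast on settings that would break training or sampling."""
    required = ['file_path', 'microstructure_class', 'batch_size', 'epochs',
                'learning_rate', 'timesteps', 'sample_steps', 'save_dir']
    missing = [k for k in required if config.get(k) is None]
    if missing:
        raise ValueError(f"missing config values: {missing}")
    # ddim_sample loops over n_steps - 1 timestep pairs, so n_steps = 1 would return pure noise
    if not 2 <= config['sample_steps'] <= config['timesteps']:
        raise ValueError("sample_steps must be between 2 and timesteps")
    return config['sample_steps'] - 1   # number of UNet evaluations per DDIM sample

example = {
    'file_path': 'data.h5',            # placeholder path for illustration
    'microstructure_class': 'VoidSmall',
    'batch_size': 16, 'epochs': 50, 'learning_rate': 5e-5,
    'timesteps': 2000, 'sample_steps': 150,
    'save_dir': 'results',             # placeholder directory for illustration
}
print(check_config(example))  # 149
```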
Hardware and Environment

Uses PyTorch with CUDA acceleration when available
Parallelizes data loading across background worker processes via the DataLoader's num_workers setting
Employs tqdm for progress tracking during training and sampling

Improvements and Robust Implementation
Several robustness improvements are included:

Error handling in metric calculation:

Safely handles edge cases in two-point correlation function
Uses numerical stability techniques (e.g., epsilon values for division)
Returns fallback values instead of raising exceptions


Inception metrics:

Handles grayscale to RGB conversion for the inception model
Processes images in batches to prevent memory issues
Robust feature extraction for FID calculation


Visualization safety:

Ensures proper normalization for visualization
Handles tensor to numpy conversion
Provides options for both display and saving
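The bullets above describe defensive patterns rather than code shown in this section. The helpers below are therefore hypothetical NumPy sketches of those patterns: grayscale-to-RGB channel repetition for Inception, batched iteration, a fallback wrapper for metric computations, and clipped rescaling for display.

```python
import numpy as np

def to_rgb(batch):
    # [B, 1, H, W] grayscale -> [B, 3, H, W] by repeating the single channel
    return np.repeat(batch, 3, axis=1) if batch.shape[1] == 1 else batch

def batched(images, batch_size):
    # Yield fixed-size chunks so feature extraction never holds every image at once
    for start in range(0, len(images), batch_size):
        yield images[start:start + batch_size]

def safe_metric(fn, *args, fallback=float('nan')):
    # Return a fallback instead of raising, and reject non-finite results
    try:
        value = fn(*args)
    except (ValueError, ZeroDivisionError, np.linalg.LinAlgError):
        return fallback
    return value if np.all(np.isfinite(value)) else fallback

def to_display(img):
    # Map model output from [-1, 1] to [0, 1] and clip overshoot before plotting
    return np.clip((np.asarray(img) + 1.0) / 2.0, 0.0, 1.0)

images = np.random.default_rng(0).uniform(-1, 1, size=(10, 1, 16, 16))
print([to_rgb(chunk).shape for chunk in batched(images, 4)])
```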

Usage Guide
To train a new diffusion model:

Configure the parameters in the config dictionary

Call the run_training function

Examine the results in the specified save directory:

Saved model weights

Generated samples

Evaluation metrics and visualizations

For inference with a trained model:

Load the checkpoint with torch.load and restore the weights with model.load_state_dict(checkpoint['model_state_dict'])

Use the ddim_sample method of the diffusion model to generate samples

Apply the evaluate_microstructures function to assess quality

In [None]:
"""
Diffusion Model for Microstructure Generation - VoidSmall

This implementation creates a complete diffusion model pipeline for generating
synthetic material microstructures. The model uses a UNet with self-attention
to reverse a noise diffusion process, enabling high-quality generation of material
structures with physically relevant properties.

Key components:
1. Custom dataset loading for microstructure data from HDF5 files
2. UNet architecture with timestep conditioning and self-attention
3. Diffusion model with forward and reverse processes
4. Material-specific evaluation metrics (two-point correlation, lineal path)
5. Training and visualization utilities
"""
import os
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange  # For tensor reshaping operations
from datetime import datetime
from tqdm.auto import tqdm  
from skimage.metrics import structural_similarity as ssim  # For image similarity assessment
import torchvision.models as models  # For Inception model (FID/IS metrics)
from scipy import linalg  # For matrix operations in FID calculation
from torch.nn.functional import adaptive_avg_pool2d
from scipy import integrate  # For numerical integration in evaluation metrics

In [None]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Dataset Implementation 
class Micro2DKeyDataset(Dataset):
    """
    Custom dataset for loading 2D microstructure data from HDF5 files.
    
    Args:
        file_path (str): Path to the HDF5 data file
        key (str): Specific microstructure class to load from the file
        transform (callable, optional): Optional transformations to apply
    """
    def __init__(self, file_path, key, transform=None):
        self.file_path = file_path
        self.key = key
        self.transform = transform
        
        # Load data from HDF5 file
        with h5py.File(self.file_path, 'r') as f:
            self.data = f[key][key][:]
        self.data = self.data.astype(np.float32)  # Convert to float32 for PyTorch
        
    def __len__(self):
        """Return the total number of samples in the dataset"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        Get a microstructure sample by index
        
        Processing steps:
        1. Normalize the image to [0,1] range
        2. Convert to PyTorch tensor
        3. Scale to [-1,1] range (standard for diffusion models)
        4. Apply any specified transformations
        """
        img = self.data[idx]
        # Normalize to [0,1] range with small epsilon to prevent division by zero
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        # Convert to tensor and add channel dimension
        img = torch.tensor(img).unsqueeze(0)
        # Scale to [-1,1] range
        img = 2 * img - 1
        # Ensure values are strictly within range
        img = torch.clamp(img, -1.0, 1.0)
        # Apply transformations if specified
        if self.transform:
            img = self.transform(img)
        return img

In [None]:
# Visualization Functions 
def show_microstructure_batch(data, n=4, title=None, save_path=None):
    """
    Display or save a batch of microstructure images in a grid.
    
    Args:
        data (torch.Tensor or list): Batch of microstructure images
        n (int): Number of images to display (max 4)
        title (str, optional): Title for the figure
        save_path (str, optional): Path to save the figure instead of displaying
    """
    plt.figure(figsize=(10, 10))
    for i in range(min(n, len(data))):
        # Get image and convert to numpy
        if isinstance(data, torch.Tensor):
            image = data[i].squeeze().cpu().detach().numpy()
        else:
            image = data[i].squeeze()
            
        # Convert from [-1, 1] to [0, 1] if needed
        if image.min() < 0:
            image = (image + 1) / 2
            
        plt.subplot(2, 2, i+1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
        
    if title:
        plt.suptitle(title)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

def print_debug(message):
    """Utility function for debug printing with a consistent format"""
    print(f"DEBUG: {message}")

In [None]:
# UNet Model Components 
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Positional embeddings for timestep conditioning in the diffusion model.
    
    Uses sinusoidal embeddings similar to those in the Transformer architecture
    to convert scalar timesteps into high-dimensional feature vectors.
    
    Args:
        dim (int): Dimension of the embedding
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Convert timesteps to sinusoidal embeddings
        
        Args:
            time (tensor): Tensor of timesteps [batch_size]
            
        Returns:
            tensor: Sinusoidal embeddings [batch_size, dim]
        """
        device = time.device
        half_dim = self.dim // 2
        # Create log-spaced frequency bands
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Create embeddings through sinusoidal functions
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class Block(nn.Module):
    """
    Basic UNet block with time conditioning for diffusion model.
    
    This block performs either downsampling or upsampling based on the 'up' parameter.
    Each block includes convolutions, batch normalization, and time embedding injection.
    
    Args:
        in_ch (int): Number of input channels
        out_ch (int): Number of output channels
        time_emb_dim (int): Dimension of time embedding
        up (bool): Whether this is an upsampling block (False = downsampling)
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Time embedding projection
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.up = up
        
        if up:
            # For upsampling blocks
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            # Transposed convolution for upsampling (stride 2)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            # For downsampling blocks
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            # Strided convolution for downsampling (stride 2)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
            
        # Second convolution after time conditioning
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # Batch normalization for training stability
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        
    def forward(self, x, t):
        """
        Forward pass through the block
        
        Args:
            x (tensor): Input feature map [B, in_ch, H, W]
            t (tensor): Time embedding [B, time_emb_dim]
            
        Returns:
            tensor: Output feature map [B, out_ch, H/2, W/2] for down blocks
                   or [B, out_ch, H*2, W*2] for up blocks
        """
        # First Conv -> ReLU -> BatchNorm (normalization is applied after the activation)
        h = self.bnorm1(self.relu(self.conv1(x)))
        
        # Project time embedding to channel dimension
        time_emb = self.relu(self.time_mlp(t))
        
        # Extend time embedding dimensions to match spatial dimensions (add H, W dims)
        time_emb = time_emb[(..., ) + (None, ) * 2]
        
        # Add time embedding to feature map (additive conditioning)
        h = h + time_emb
        
        # Second Conv -> ReLU -> BatchNorm
        h = self.bnorm2(self.relu(self.conv2(h)))
        
        # Down or Upsample using transform layer
        return self.transform(h)

class SelfAttention(nn.Module):
    """
    Self-attention module for capturing long-range dependencies in images.
    
    Implements a form of self-attention similar to the one in the Transformer
    architecture, adapted for 2D feature maps.
    
    Args:
        channels (int): Number of input channels
    """
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels        
        # Multi-head attention with one head
        self.mha = nn.MultiheadAttention(channels, 1, batch_first=True)
        # Layer normalization
        self.ln = nn.LayerNorm([channels])
        # Feed-forward network after attention
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),  # GELU activation for better gradient properties
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """
        Apply self-attention to input feature map
        
        Args:
            x (tensor): Input feature map [B, C, H, W]
            
        Returns:
            tensor: Output feature map with same shape but enhanced with attention
        """
        # Store original spatial dimensions
        size = x.shape[-2:]
        
        # Reshape tensor for attention operation: [B, C, H, W] -> [B, H*W, C]
        # This treats each pixel as a "token" with C-dimensional features
        x = rearrange(x, 'b c h w -> b (h w) c')
        
        # Apply layer normalization
        x_ln = self.ln(x)
        
        # Apply multi-head self-attention
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        
        # First residual connection
        attention_value = attention_value + x
        
        # Feed-forward network with second residual connection
        attention_value = self.ff_self(attention_value) + attention_value
        
        # Reshape back to original spatial dimensions: [B, H*W, C] -> [B, C, H, W]
        return rearrange(attention_value, 'b (h w) c -> b c h w', h=size[0], w=size[1])

class UNet(nn.Module):
    """
    UNet architecture for diffusion model with time conditioning and self-attention.
    
    The UNet follows a standard encoder-decoder structure with skip connections,
    but includes timestep conditioning and self-attention mechanisms which are
    crucial for diffusion models.
    
    Args:
        in_channels (int): Number of input image channels (default=1 for grayscale)
        out_channels (int): Number of output channels (default=1)
        time_emb_dim (int): Dimension of time embedding (default=128)
    """
    def __init__(self, in_channels=1, out_channels=1, time_emb_dim=128):
        super().__init__()
        
        # For debugging feature map dimensions
        self.debug = False
        
        # Time embedding with sinusoidal positional encoding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        
        # Initial projection from image to feature map
        self.conv0 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        
        # Downsample path (encoder)
        self.down1 = Block(128, 128, time_emb_dim)      # 256x256 -> 128x128
        self.down2 = Block(128, 256, time_emb_dim)      # 128x128 -> 64x64
        self.down3 = Block(256, 256, time_emb_dim)      # 64x64 -> 32x32
        self.down4 = Block(256, 512, time_emb_dim)      # 32x32 -> 16x16
        self.down5 = Block(512, 512, time_emb_dim)      # 16x16 -> 8x8
        
        # Self-attention module applied after down3 (256 channels, 32x32 feature map)
        # This helps capture long-range dependencies, which is crucial for
        # maintaining global structure in the generated microstructures
        self.attention = SelfAttention(256)
        
        # Upsample path (decoder) with skip connections
        self.up1 = Block(512, 512, time_emb_dim, up=True)               # 8x8 -> 16x16
        self.up2 = Block(512 + 512, 256, time_emb_dim, up=True)         # 16x16 -> 32x32
        self.up3 = Block(256 + 256, 256, time_emb_dim, up=True)         # 32x32 -> 64x64
        self.up4 = Block(256 + 256, 128, time_emb_dim, up=True)         # 64x64 -> 128x128
        self.up5 = Block(128 + 128, 128, time_emb_dim, up=True)         # 128x128 -> 256x256
        
        # Final 1x1 convolution to predicted noise
        self.output = nn.Conv2d(128, out_channels, kernel_size=1)
        
    def forward(self, x, timestep):
        """
        Forward pass through UNet with skip connections
        
        Args:
            x (tensor): Input noisy image [B, C, H, W]
            timestep (tensor): Diffusion timesteps [B]
            
        Returns:
            tensor: Predicted noise [B, C, H, W]
        """
        # Get time embedding
        t = self.time_mlp(timestep)
        
        # Initial convolution
        x0 = self.conv0(x)  # 256x256
        if self.debug:
            print(f"x0 (initial conv): {x0.shape}")
        
        # Downsample path (encoder) with debugging prints
        x1 = self.down1(x0, t)  # 128x128
        if self.debug:
            print(f"x1 (down1): {x1.shape}")
            
        x2 = self.down2(x1, t)  # 64x64
        if self.debug:
            print(f"x2 (down2): {x2.shape}")
            
        x3 = self.down3(x2, t)  # 32x32
        # Apply self-attention to the 32x32 feature map (the 8x8 bottleneck is x5)
        x3 = self.attention(x3)
        if self.debug:
            print(f"x3 (down3 + attention): {x3.shape}")
            
        x4 = self.down4(x3, t)  # 16x16
        if self.debug:
            print(f"x4 (down4): {x4.shape}")
            
        x5 = self.down5(x4, t)  # 8x8 (bottleneck)
        if self.debug:
            print(f"x5 (down5): {x5.shape}")
        
        # Upsample path (decoder) with skip connections
        u1 = self.up1(x5, t)  # 16x16
        if self.debug:
            print(f"u1 (up1): {u1.shape}")
            print(f"x4 for concat: {x4.shape}")
            
        # Concatenate with skip connections
        u2 = self.up2(torch.cat([u1, x4], dim=1), t)  # 32x32
        if self.debug:
            print(f"u2 (up2): {u2.shape}")
            print(f"x3 for concat: {x3.shape}")
            
        u3 = self.up3(torch.cat([u2, x3], dim=1), t)  # 64x64
        if self.debug:
            print(f"u3 (up3): {u3.shape}")
            print(f"x2 for concat: {x2.shape}")
            
        u4 = self.up4(torch.cat([u3, x2], dim=1), t)  # 128x128
        if self.debug:
            print(f"u4 (up4): {u4.shape}")
            print(f"x1 for concat: {x1.shape}")
            
        u5 = self.up5(torch.cat([u4, x1], dim=1), t)  # 256x256
        if self.debug:
            print(f"u5 (up5): {u5.shape}")
        
        # Final 1x1 convolution to predicted noise
        output = self.output(u5)
        if self.debug:
            print(f"output: {output.shape}")
            
        return output

In [None]:
# Diffusion Model Implementation 
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """
    Linear noise schedule for diffusion process.
    
    Creates a linear schedule for beta values from start to end.
    Beta controls the noise level added at each timestep.
    
    Args:
        timesteps (int): Number of diffusion timesteps
        start (float): Starting beta value (small)
        end (float): Ending beta value (larger)
        
    Returns:
        torch.Tensor: Linear schedule of beta values
    """
    return torch.linspace(start, end, timesteps)

class DiffusionModel:
    """
    Implements the diffusion process for image generation.
    
    This class handles both the forward process (adding noise) and provides
    methods for the reverse process (removing noise) via sampling.
    
    Args:
        timesteps (int): Number of diffusion steps (default=4000)
        beta_schedule (str): Type of beta schedule (only 'linear' supported)
        device (str): Device to run the model on (default='cuda')
    """
    def __init__(self, timesteps=4000, beta_schedule='linear', device='cuda'):
        self.timesteps = timesteps
        self.device = device
        
        # Define noise schedule (variance of noise added at each step)
        self.betas = linear_beta_schedule(timesteps).to(device)
        
        # Define alphas = 1 - betas (signal proportion kept at each step)
        self.alphas = 1. - self.betas
        
        # Cumulative product of alphas (for computing statistics of q(x_t | x_0))
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        
        # Previous alpha cumulative product (for posterior variance calculation)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Pre-compute values for diffusion process and sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1)
        
        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        
    def q_sample(self, x_0, t, noise=None):
        """
        Forward diffusion process: Add noise to the original data.
        
        Implements q(x_t | x_0) - adding noise over multiple timesteps.
        
        Args:
            x_0 (tensor): Original clean images [B, C, H, W]
            t (tensor): Timesteps [B]
            noise (tensor, optional): Noise to add (random if None)
            
        Returns:
            tensor: Noisy images at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Extract the appropriate alpha and sigma for the given timesteps
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        
        # Apply forward diffusion formula: x_t = sqrt(ᾱ_t)·x_0 + sqrt(1-ᾱ_t)·ε, where ᾱ_t is the cumulative product of alphas
        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_losses(self, denoise_model, x_0, t, noise=None):
        """
        Calculate loss for training the denoising model.
        
        The loss is based on predicting the noise that was added at timestep t.
        
        Args:
            denoise_model (nn.Module): UNet model to predict noise
            x_0 (tensor): Original clean images [B, C, H, W]
            t (tensor): Timesteps [B]
            noise (tensor, optional): Noise to add (random if None)
            
        Returns:
            tensor: Mean squared error between predicted and actual noise
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Create noisy samples at timestep t
        x_noisy = self.q_sample(x_0, t, noise=noise)
        
        # Get model's predicted noise
        predicted_noise = denoise_model(x_noisy, t)
        
        # Loss is MSE between predicted and actual noise
        loss = F.mse_loss(predicted_noise, noise)
        
        return loss
    
    def ddim_sample(self, model, shape, n_steps=150, eta=0.0):
        """
        DDIM sampling for accelerated image generation.
        
        DDIM (Denoising Diffusion Implicit Models) allows for much faster
        sampling than DDPM while maintaining quality. The eta parameter
        controls stochasticity (0 = deterministic, 1 = DDPM-like).
        
        Args:
            model (nn.Module): Trained UNet denoising model
            shape (tuple): Shape of samples to generate [B, C, H, W]
            n_steps (int): Number of sampling steps (fewer than training timesteps)
            eta (float): Controls stochasticity (0 = deterministic, 1 = DDPM-like)
            
        Returns:
            tensor: Generated samples [B, C, H, W]
        """
        device = next(model.parameters()).device
        b = shape[0]
        
        # Start from pure noise
        img = torch.randn(shape, device=device)
        
        # Select subset of timesteps for DDIM sampling (evenly spaced)
        timesteps = np.linspace(0, self.timesteps - 1, n_steps, dtype=int)[::-1]
        
        # Progressively denoise the image
        for i in tqdm(range(len(timesteps) - 1), desc='DDIM Sampling'):
            t_current = torch.full((b,), timesteps[i], device=device, dtype=torch.long)
            t_next = torch.full((b,), timesteps[i + 1], device=device, dtype=torch.long)
            
            # Predict noise
            with torch.no_grad():
                predicted_noise = model(img, t_current)
            
            # Extract x0 from xt using the predicted noise
            alpha_cumprod_t = self.alphas_cumprod[t_current].reshape(-1, 1, 1, 1)
            alpha_cumprod_next = self.alphas_cumprod[t_next].reshape(-1, 1, 1, 1)
            
            sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
            sqrt_one_minus_alpha_cumprod_t = torch.sqrt(1 - alpha_cumprod_t)
            
            # Predict x0 from xt and predicted noise
            predicted_x0 = (img - sqrt_one_minus_alpha_cumprod_t * predicted_noise) / sqrt_alpha_cumprod_t
            
            # Calculate coefficient for predicted_x0
            sqrt_alpha_cumprod_next = torch.sqrt(alpha_cumprod_next)
            
            # Calculate coefficient for direction pointing to xt
            # Interpolate between DDPM and DDIM using eta
            sigma_t = eta * torch.sqrt((1 - alpha_cumprod_next) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_next))
            
            # Get noise for stochastic part (if eta > 0)
            noise = torch.randn_like(img) if eta > 0 else 0
            
            # Compute the next sample using DDIM formula
            img = sqrt_alpha_cumprod_next * predicted_x0 + \
                  torch.sqrt(1 - alpha_cumprod_next - sigma_t**2) * predicted_noise + \
                  sigma_t * noise
            
        return img

In [None]:
# Training Function
def train_diffusion_model(dataloader, model, diffusion, optimizer, config):
    """
    Train the diffusion model.
    
    Args:
        dataloader: DataLoader containing training data
        model: UNet model to train
        diffusion: DiffusionModel instance
        optimizer: Optimizer instance (e.g., Adam)
        config: Dictionary with training configuration
        
    Returns:
        Trained model
    """
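    # Make sure the output directory exists before checkpoints are written
    os.makedirs(config['save_dir'], exist_ok=True)
    # Move model to appropriate device
    model.to(device)
    model.train()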
    
    # Training loop over epochs
    for epoch in range(config['epochs']):
        epoch_loss = 0.0
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{config['epochs']}")
        
        for step, batch in progress_bar:
            optimizer.zero_grad()
            
            # Move batch to device
            batch = batch.to(device)
            batch_size = batch.shape[0]
            
            # Sample random timesteps for each image in batch
            t = torch.randint(0, diffusion.timesteps, (batch_size,), device=device).long()
            
            # Calculate loss (predict noise added at timestep t)
            loss = diffusion.p_losses(model, batch, t)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Track loss
            epoch_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})
            
        # Print average loss for epoch
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{config['epochs']}, Average Loss: {avg_loss:.6f}")
        
        # Generate and save samples every 10 epochs
        if (epoch + 1) % 10 == 0:
            # Save model checkpoint
            save_path = os.path.join(config['save_dir'], f"model_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss
            }, save_path)
            
            # Generate samples
            model.eval()
            with torch.no_grad():
                # Generate 4 sample images using DDIM sampling
                samples = diffusion.ddim_sample(
                    model, 
                    (4, 1, 256, 256), 
                    n_steps=config['sample_steps']
                )
                
                # Save samples as image
                samples_path = os.path.join(config['save_dir'], f"samples_epoch_{epoch+1}.png")
                show_microstructure_batch(samples, n=4, title=f"Epoch {epoch+1} Samples", save_path=samples_path)
                
                # Also display the samples
                plt.figure(figsize=(10, 10))
                for i in range(4):
                    plt.subplot(2, 2, i+1)
                    plt.imshow(((samples[i] + 1) / 2).squeeze().cpu().numpy(), cmap='gray')
                    plt.title(f"Sample {i+1}")
                    plt.axis('off')
                plt.suptitle(f"Generated Samples at Epoch {epoch+1}")
                plt.tight_layout()
                plt.show()
                
            # Return to training mode
            model.train()
    
    # Save final model
    save_path = os.path.join(config['save_dir'], "final_model.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
    }, save_path)
    
    return model

In [None]:
# Evaluation Metrics - FID/IS
class InceptionV3(nn.Module):
    """
    Pretrained InceptionV3 network for FID and IS computation.
    
    This class extracts features from the InceptionV3 model, split into
    specific blocks corresponding to the layers needed for FID and IS metrics.
    
    FID (Fréchet Inception Distance) and IS (Inception Score) are common metrics
    for evaluating the quality of generated images.
    """
    def __init__(self):
        super(InceptionV3, self).__init__()
        # Load ImageNet-pretrained InceptionV3 (weights API, torchvision >= 0.13)
        inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
        
        # Split the model into blocks for feature extraction
        # Block 1: Initial convolutions and pooling
        self.block1 = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        
        # Block 2: More convolutions and pooling
        self.block2 = nn.Sequential(
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        
        # Block 3: Mixed inception modules (5b-6e)
        self.block3 = nn.Sequential(
            inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e
        )
        
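        # Block 4: Final inception modules, global average pooling and flattening
        # This produces the 2048-d pool features used for FID.
        # nn.Sequential only accepts modules, so use nn.AdaptiveAvgPool2d rather
        # than the functional adaptive_avg_pool2d
        self.block4 = nn.Sequential(
            inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten()
        )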
        
    def forward(self, x):
        """
        Extract features for FID calculation.
        
        Args:
            x (tensor): Input images [B, C, H, W] with values in [0, 1]
            
        Returns:
            tensor: Pool features for FID computation [B, 2048]
        """
        # Resize input to 299x299 as required by InceptionV3
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        
        # If grayscale, repeat to 3 channels (InceptionV3 expects RGB)
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        
        # The pretrained weights expect inputs scaled to [-1, 1]. The model's own
        # forward() does this through transform_input, but it is bypassed here
        x = 2 * x - 1
        
        # Forward pass through blocks
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x
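Why the forward pass rescales inputs: torchvision's pretrained InceptionV3 expects ImageNet-normalized input and, with `transform_input=True`, remaps it inside its own `forward()` to `(raw - 0.5) / 0.5` per channel. Because the blocks above are called directly, that remapping never runs unless the input is rescaled explicitly. The NumPy-only sketch below re-implements the remapping as an assumption (the formula is taken from torchvision's `_transform_input`, not imported from it) and checks numerically that it equals `2 * x - 1` for images in [0, 1].

```python
import numpy as np

# ImageNet channel statistics used by torchvision's pretrained weights
MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def imagenet_normalize(x):
    """Standard ImageNet normalization of an [N, 3, H, W] array in [0, 1]."""
    return (x - MEAN) / STD

def transform_input(x_norm):
    """Assumed re-implementation of InceptionV3's transform_input step:
    maps ImageNet-normalized input back to (raw - 0.5) / 0.5 per channel."""
    return x_norm * (STD / 0.5) + (MEAN - 0.5) / 0.5

rng = np.random.default_rng(0)
raw = rng.random((2, 3, 8, 8))                 # images in [0, 1]
via_model = transform_input(imagenet_normalize(raw))
direct = 2 * raw - 1                            # rescaling applied before block1

print(np.allclose(via_model, direct))           # True
```

Algebraically, `(raw - MEAN) / STD * STD / 0.5 + (MEAN - 0.5) / 0.5` collapses to `(raw - 0.5) / 0.5`, so the two routes agree for any input.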

In [None]:
# FID and IS Calculation Functions
def calculate_activation_statistics(images, model, batch_size=64, device='cuda'):
    """
    Calculate mean and covariance of features for FID calculation.
    
    Args:
        images (tensor): Batch of images [B, C, H, W]
        model (nn.Module): Inception model for feature extraction
        batch_size (int): Batch size for feature extraction
        device (str): Device to run computation on
        
    Returns:
        tuple: (mean, covariance) of the activation statistics
    """
    model.eval()
    
    n = len(images)
    
    # Initialize storage for activations (2048 is the InceptionV3 feature dimension)
    act = np.empty((n, 2048))
    
    # Process in batches to avoid memory issues; stepping with range() covers
    # every image and never produces an empty final batch
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        
        # Extract features for current batch
        batch = images[start:end].to(device)
        with torch.no_grad():
            act[start:end] = model(batch).cpu().numpy()
    
    # Calculate statistics
    mu = np.mean(act, axis=0)  # Mean across samples
    sigma = np.cov(act, rowvar=False)  # Covariance matrix
    return mu, sigma

def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Calculate Fréchet Inception Distance between two sets of images.
    
    FID measures the distance between two distributions in feature space.
    Lower values indicate more similar distributions (better generation quality).
    
    Args:
        mu1, mu2 (ndarray): Mean feature vectors for real and generated images
        sigma1, sigma2 (ndarray): Covariance matrices for real and generated images
        eps (float): Small value to avoid numerical issues
        
    Returns:
        float: FID score
    """
    # Difference between the means (its squared norm is the first FID term)
    diff = mu1 - mu2
    
    # Calculate sqrt of product of covariances (may be numerically unstable)
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    
    # Ensure covmean is well-behaved (not containing NaN or Inf)
    if not np.isfinite(covmean).all():
        # Add small offset to diagonal for numerical stability
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    
    # Handle complex values from sqrt of matrices
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    # Calculate trace of covmean
    tr_covmean = np.trace(covmean)
    
    # FID formula: ||μ_1 - μ_2||^2 + Tr(Σ_1 + Σ_2 - 2√(Σ_1Σ_2))
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def calculate_inception_score(samples, model, batch_size=64, splits=10, device='cuda'):
    """
    Calculate Inception Score for a batch of generated images.
    
    IS measures both quality and diversity of generated images.
    Higher values indicate better generation quality.
    
    Args:
        samples (tensor): Batch of generated images [B, C, H, W]
        model (nn.Module): Inception classifier returning 1000-class ImageNet logits
        batch_size (int): Batch size for processing
        splits (int): Number of splits for calculating statistics
        device (str): Device to run computation on
        
    Returns:
        tuple: (mean, std) of the inception score
    """
    model.eval()
    N = len(samples)
    
    # Storage for class probabilities (1000 ImageNet classes)
    preds = np.zeros((N, 1000))
    
    # Process in batches
    for i in range(0, N, batch_size):
        # Extract current batch
        batch = samples[i:i+batch_size].to(device)
        
        # Convert grayscale to RGB if needed
        if batch.shape[1] == 1:
            batch = batch.repeat(1, 3, 1, 1)
            
        # Resize to InceptionV3 input size
        batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
        
        # Get softmax predictions (class probabilities)
        with torch.no_grad():
            pred = F.softmax(model(batch), dim=1).cpu().numpy()
            
        preds[i:i+batch.shape[0]] = pred
    
    # Split predictions into `splits` equal parts (any remainder is dropped)
    scores = []
    split_size = N // splits
    eps = 1e-12  # guards against log(0) for zero probabilities
    for i in range(splits):
        # Take subset of predictions for current split
        part = preds[i * split_size:(i + 1) * split_size, :]
        
        # KL divergence between each p(y|x) and the split's marginal p(y)
        py = np.mean(part, axis=0, keepdims=True)
        kl = part * (np.log(part + eps) - np.log(py + eps))
        kl = np.mean(np.sum(kl, axis=1))
        
        # Inception score is exp of the mean KL divergence
        scores.append(np.exp(kl))
    
    # Return mean and standard deviation across splits
    return np.mean(scores), np.std(scores)

def compute_fid_and_is(real_images, generated_images, device='cuda'):
    """
    Compute FID and IS metrics for real and generated images.
    
    Args:
        real_images (tensor): Batch of real images [B, C, H, W] in range [-1, 1]
        generated_images (tensor): Batch of generated images [B, C, H, W] in range [-1, 1]
        device (str): Device to run models on
    
    Returns:
        dict: Dictionary with FID and IS scores
    """
    # Normalize to [0, 1] range if needed
    if real_images.min() < 0:
        real_images = (real_images + 1) / 2
    if generated_images.min() < 0:
        generated_images = (generated_images + 1) / 2
    
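    # Feature extractor (2048-d pool features) for FID
    inception_model = InceptionV3().to(device)
    inception_model.eval()
    
    # Calculate FID
    mu_real, sigma_real = calculate_activation_statistics(real_images, inception_model, device=device)
    mu_gen, sigma_gen = calculate_activation_statistics(generated_images, inception_model, device=device)
    fid_value = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
    
    # IS needs 1000-class ImageNet logits, which the feature extractor does not
    # return, so use the full torchvision classifier. With pretrained weights it
    # expects ImageNet-normalized RGB input, hence the Normalize step (applied after
    # calculate_inception_score has converted to 3 channels and resized)
    classifier = nn.Sequential(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1),
    ).to(device)
    classifier.eval()
    
    # Calculate IS for generated images
    is_mean, is_std = calculate_inception_score(generated_images, classifier, device=device)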
    
    return {
        'fid': fid_value,
        'is_mean': is_mean,
        'is_std': is_std
    }
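The FID and IS helpers in this cell can be sanity-checked on inputs whose scores are known in closed form. The NumPy-only sketch below uses hypothetical helpers (`sqrtm_psd`, `fid_from_stats`, `inception_score_from_probs`, `covariance_rank`) that are not part of this notebook. `fid_from_stats` computes the trace term through the symmetric matrix `S1^(1/2) S2 S1^(1/2)`, which has the same eigenvalues as `S1 S2` and hence the same trace of its square root, so no complex intermediate values appear. The IS check uses the known bounds: identical uniform predictions give 1, and confident predictions spread evenly over K classes give K. The rank check illustrates that with fewer images than feature dimensions the sample covariance is singular (rank N - 1), so 2048-d Inception features need more than 2048 images for a full-rank covariance; a singular covariance is one likely source of the unstable `sqrtm` that `calculate_fid` guards against.

```python
import numpy as np

def sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0, None)  # clip tiny negative eigenvalues caused by round-off
    return (v * np.sqrt(w)) @ v.T

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """FID using Tr sqrt(S1 S2) = Tr sqrt(S1^(1/2) S2 S1^(1/2))."""
    diff = mu1 - mu2
    s1_half = sqrtm_psd(sigma1)
    tr_covmean = np.trace(sqrtm_psd(s1_half @ sigma2 @ s1_half))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

def inception_score_from_probs(preds, eps=1e-12):
    """exp(E_x[KL(p(y|x) || p(y))]) for one split of softmax outputs [N, K]."""
    py = preds.mean(axis=0, keepdims=True)  # marginal p(y)
    kl = np.sum(preds * (np.log(preds + eps) - np.log(py + eps)), axis=1)
    return float(np.exp(kl.mean()))

def covariance_rank(n_samples, dim, seed=0):
    """Rank of the sample covariance of n_samples random dim-dimensional features."""
    rng = np.random.default_rng(seed)
    act = rng.standard_normal((n_samples, dim))
    return int(np.linalg.matrix_rank(np.cov(act, rowvar=False)))

# Diagonal covariances have a closed form:
# FID = ||mu1 - mu2||^2 + sum_i (sqrt(a_i) - sqrt(b_i))^2 = 5 + 2 = 7
a, b = np.array([1.0, 4.0, 9.0]), np.array([4.0, 1.0, 9.0])
mu1, mu2 = np.zeros(3), np.array([1.0, 2.0, 0.0])
print(round(fid_from_stats(mu1, np.diag(a), mu2, np.diag(b)), 6))  # 7.0

K = 10
uniform = np.full((100, K), 1.0 / K)     # same uniform prediction for every image
one_hot = np.eye(K)[np.arange(100) % K]  # confident, evenly spread over K classes
print(round(inception_score_from_probs(uniform), 6))  # 1.0
print(round(inception_score_from_probs(one_hot), 6))  # 10.0

print(covariance_rank(100, 256))   # 99: fewer samples than dimensions, singular
print(covariance_rank(1000, 256))  # 256: full rank
```

A 256-d feature space stands in for the 2048-d Inception features only to keep the example fast; the rank argument is the same.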

In [None]:
# Material-Specific Evaluation Metrics
def two_point_correlation(image, max_distance=None, debug=False):
    """
    Compute two-point correlation function for a binary image with robust error handling.
    
    The two-point correlation function S2(r) measures the probability that two points
    separated by distance r both lie in the phase of interest (here, pixels above the
    0.5 threshold). It is a key statistical descriptor used in materials science to
    characterize microstructures.
    
    Args:
        image (tensor or ndarray): Input image to analyze
        max_distance (int, optional): Maximum distance to compute correlation
        debug (bool): Whether to print debugging information
        
    Returns:
        tuple: (distances, correlation values normalized by the zero-distance value S2(0))
    """
    # Convert tensor to numpy if needed
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    
    image = image.squeeze()
    
    # Normalize to [0,1] range if needed (before the threshold checks below)
    if image.max() > 1:
        image = image / 255.0
    
    # Handle edge case: nearly constant images
    if np.all(image < 0.05) or np.all(image > 0.95):
        if debug:
            print(f"Warning: Image is nearly constant (min={image.min()}, max={image.max()})")
        if max_distance is None:
            max_distance = min(image.shape) // 2
        # A single-phase image is fully correlated (normalized S2 = 1 everywhere);
        # an image with no phase pixels gets a flat zero curve
        fill = 1.0 if np.all(image > 0.95) else 0.0
        return np.arange(max_distance), np.full(max_distance, fill)
    
    # Binarize with a threshold (0.5) for two-phase microstructure
    binary_img = (image > 0.5).astype(np.float32)
    
    if debug:
        print(f"Binary image stats: min={binary_img.min()}, max={binary_img.max()}, mean={binary_img.mean()}")
    
    h, w = binary_img.shape
    if max_distance is None:
        max_distance = min(h, w) // 2
    
    try:
        # Calculate the autocorrelation of the phase indicator with the Fast Fourier
        # Transform (Wiener-Khinchin: inverse FFT of |F|^2). This is much faster than
        # direct calculation for large images and assumes periodic boundaries
        ft = np.fft.fft2(binary_img)
        autocorr = np.real(np.fft.ifft2(ft * np.conj(ft))) / (h * w)
        # Shift zero lag to (h//2, w//2) so it matches the radial grid below
        autocorr = np.fft.fftshift(autocorr)
        center = autocorr[h//2, w//2]  # S2(0), the volume fraction of the phase
        
        if debug:
            print(f"Autocorrelation stats: mean={np.mean(autocorr)}, center={center}")
        
        # Compute radial average (azimuthal integration) around zero lag
        y, x = np.indices((h, w))
        r = np.sqrt((x - w//2)**2 + (y - h//2)**2).astype(np.int32)
        r_max = min(h//2, w//2, max_distance)
        
        # Handle edge case: very small images
        if r_max < 2:
            if debug:
                print(f"Warning: r_max is too small ({r_max})")
            r_max = 2
        
        # Calculate binned average safely
        tbin = np.bincount(r.ravel(), autocorr.ravel())
        nr = np.bincount(r.ravel())
        
        # Check for division by zero
        valid_indices = np.where(nr[:r_max] > 0)[0]
        
        if len(valid_indices) == 0:
            if debug:
                print("Warning: No valid indices for radial profile")
            return np.arange(r_max), np.zeros(r_max)
        
        # Initialize array and populate only valid indices
        radial_profile = np.zeros(r_max)
        radial_profile[valid_indices] = tbin[valid_indices] / nr[valid_indices]
        
        # Normalize only if center is not zero
        if abs(center) > 1e-10:
            radial_profile = radial_profile / center
        
        # Replace any NaN or inf values with zeros for stability
        radial_profile = np.nan_to_num(radial_profile, nan=0.0, posinf=0.0, neginf=0.0)
        
        if debug:
            print(f"Radial profile stats: min={radial_profile.min()}, max={radial_profile.max()}, mean={radial_profile.mean()}")
        
        return np.arange(r_max), radial_profile
        
    except Exception as e:
        # Robust error handling - return zeros instead of crashing
        if debug:
            print(f"Error in two_point_correlation: {str(e)}")
        return np.arange(max_distance or 10), np.zeros(max_distance or 10)

def lineal_path_function(image, max_distance=None):
    """
    Compute lineal path function for a binary image.
    
    The lineal path function L(r) gives the probability that a line segment of
    length r lies entirely in the phase of interest (pixels above the 0.5
    threshold). It characterizes the connectivity of that phase in the
    microstructure.
    
    Args:
        image (tensor or ndarray): Input image to analyze
        max_distance (int, optional): Maximum distance to compute
        
    Returns:
        tuple: (distances, lineal path values)
    """
    # Convert tensor to numpy if needed
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    
    image = image.squeeze()
    if image.max() > 1:
        image = image / 255.0
    
    # Binarize the image
    image = (image > 0.5).astype(np.float32)
    
    h, w = image.shape
    if max_distance is None:
        # Limit to 1/4 of the smaller dimension for computational efficiency
        max_distance = min(h, w) // 4
    
    # Initialize storage for lineal path function
    lp = np.zeros(max_distance)
    counts = np.zeros(max_distance)
    
    # Sample a subset of points for efficiency (Monte Carlo approach)
    n_samples = 1000
    y_samples = np.random.randint(0, h, n_samples)
    x_samples = np.random.randint(0, w, n_samples)
    
    # Check horizontal lines (l = 0 is a single pixel, so L(0) is the volume fraction)
    for i, x in enumerate(x_samples):
        y = y_samples[i]
        if x + max_distance <= w:
            line = image[y, x:x+max_distance]
            for l in range(max_distance):
                # Check if all pixels in the line segment are in the phase (value > 0.5)
                if np.all(line[:l+1] > 0.5):
                    lp[l] += 1
                counts[l] += 1
    
    # Check vertical lines
    for i, y in enumerate(y_samples):
        x = x_samples[i]
        if y + max_distance <= h:
            line = image[y:y+max_distance, x]
            for l in range(max_distance):
                if np.all(line[:l+1] > 0.5):
                    lp[l] += 1
                counts[l] += 1
    
    # Normalize safely (avoid division by zero)
    with np.errstate(divide='ignore', invalid='ignore'):
        lp = np.divide(lp, counts, out=np.zeros_like(lp), where=counts!=0)
    
    return np.arange(max_distance), lp
        if area_orig > 0:
            lineal_discrepancy = (area_between / area_orig) * 100
        else:
            lineal_discrepancy = 0
        metrics['lineal_discrepancy'].append(lineal_discrepancy)
    
    # Compute average metrics
    metrics['avg_ssim'] = np.mean(metrics['ssim_values'])
    metrics['avg_s2_discrepancy'] = np.mean(metrics['s2_discrepancy'])
    metrics['avg_lineal_discrepancy'] = np.mean(metrics['lineal_discrepancy'])
    
    return metrics
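The microstructure metrics in this cell can be checked on small synthetic images. The NumPy-only sketch below uses hypothetical helper names that are not part of this notebook. `s2_fft` follows the FFT route of `two_point_correlation` (inverse FFT of |F|^2, the Wiener-Khinchin relation) and is compared with a direct count of pixel pairs at one lag under periodic boundaries; S2(0) should equal the volume fraction. `lineal_path_exact` is a deterministic alternative to the 1000-sample Monte Carlo estimate in `lineal_path_function`: it counts every horizontal and vertical segment that fits inside the image. `percent_area_discrepancy` assumes that `area_between` and `area_orig` above are trapezoidal integrals over r of |L_orig - L_gen| and L_orig respectively.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def s2_fft(binary):
    """Periodic S2 for every lag (dy, dx): inverse FFT of |F|^2 over the pixel count."""
    ft = np.fft.fft2(binary)
    return np.real(np.fft.ifft2(ft * np.conj(ft))) / binary.size

def s2_direct(binary, dy, dx):
    """Fraction of pixels p with p and p + (dy, dx) both in the phase (periodic)."""
    shifted = np.roll(binary, shift=(-dy, -dx), axis=(0, 1))
    return float(np.mean(binary * shifted))

def lineal_path_exact(binary, max_distance):
    """Exact L(r), r = 0..max_distance-1, pooling horizontal and vertical segments.

    A segment of length r spans r + 1 pixels and must fit inside the image (no wrap).
    """
    lp = np.zeros(max_distance)
    for r in range(max_distance):
        # The window minimum is 1 only if every pixel in the segment is in the phase
        horiz = sliding_window_view(binary, r + 1, axis=1).min(axis=-1)
        vert = sliding_window_view(binary, r + 1, axis=0).min(axis=-1)
        lp[r] = (horiz.sum() + vert.sum()) / (horiz.size + vert.size)
    return lp

def percent_area_discrepancy(curve_orig, curve_gen, r):
    """Trapezoidal area between two curves as a percentage of the area under curve_orig."""
    def trapezoid(y):
        return float(np.sum((y[1:] + y[:-1]) * np.diff(r)) / 2)
    area_orig = trapezoid(curve_orig)
    if area_orig <= 0:
        return 0.0
    return 100 * trapezoid(np.abs(curve_orig - curve_gen)) / area_orig

rng = np.random.default_rng(0)
img = (rng.random((32, 32)) > 0.6).astype(float)  # random two-phase image
s2 = s2_fft(img)
print(np.isclose(s2[0, 0], img.mean()))            # True: S2(0) is the volume fraction
print(np.isclose(s2[3, 5], s2_direct(img, 3, 5)))  # True

# Vertical stripes 4 px wide: half of the columns belong to the phase
stripes = np.tile((np.arange(16) // 4) % 2 == 0, (16, 1)).astype(float)
lp = lineal_path_exact(stripes, 8)
print(lp[0])  # 0.5

r = np.arange(5.0)
print(percent_area_discrepancy(np.ones(5), np.full(5, 0.5), r))  # 50.0
```

Counting every segment removes the sampling noise of the Monte Carlo estimate, at the cost of one sliding-window pass per distance.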

In [None]:
# Main execution function
def run_training(config):
    """Set up and run the training process."""
    print("Configuration:", config)
    
    # Create save directory
    os.makedirs(config['save_dir'], exist_ok=True)
    
    # Save configuration
    with open(os.path.join(config['save_dir'], 'config.txt'), 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")
    
    # Define transformations
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(degrees=10, scale=(0.9, 1.1)),
    ])
    
    # Load the dataset
    dataset = Micro2DKeyDataset(
        file_path=config['file_path'], 
        key=config['microstructure_class'],
        transform=transform if config['use_augmentation'] else None
    )
    
    data_loader = DataLoader(
        dataset, 
        batch_size=config['batch_size'], 
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    print(f"Dataset size: {len(dataset)}")
    
    # Show sample images
    sample_batch = next(iter(data_loader))
    show_microstructure_batch(
        sample_batch, 
        n=4, 
        title=f"{config['microstructure_class']} Samples",
        save_path=os.path.join(config['save_dir'], 'original_samples.png')
    )
    
    # Initialize model and diffusion
    model = UNet(in_channels=1, out_channels=1).to(device)
    diffusion = DiffusionModel(
        timesteps=config['timesteps'],
        device=device
    )
    
    # Set up optimizer
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    
    # Train the model
    trained_model = train_diffusion_model(
        dataloader=data_loader,
        model=model,
        diffusion=diffusion,
        optimizer=optimizer,
        config=config
    )
    # Evaluate and visualize the results
    metrics = evaluate_and_visualize(
        original_dataset=dataset,
        model=trained_model,
        diffusion=diffusion,
        config=config
    )
    
    print("Training complete!")
    print(f"Model and samples saved in: {config['save_dir']}")
    print(f"Average SSIM: {metrics['avg_ssim']:.4f}")
    print(f"Average Two-Point Correlation Discrepancy: {metrics['avg_s2_discrepancy']:.2f}%")
    print(f"Average Lineal Path Discrepancy: {metrics['avg_lineal_discrepancy']:.2f}%")


# Main execution
if __name__ == "__main__":
    # Configuration
    config = {
        'file_path': '/lustre/uschill-lab/users/3782/diffusion/New_Diff/MICRO2D_homogenized.h5',  # Path to HDF5 file
        'microstructure_class': 'VoidSmall',  # HDF5 key of the microstructure class
        'batch_size': 16,
        'epochs': 50,
        'learning_rate': 5e-5,
        'weight_decay': 1e-5,
        'timesteps': 2000,  # Number of diffusion steps
        'sample_steps': 150,  # Steps for DDIM sampling
        'use_augmentation': True,
        'num_workers': 2,
        'save_dir': f'/lustre/uschill-lab/users/3782/diffusion/New_Diff/diffusion_model_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    }
    
    run_training(config)
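The configuration trains with `timesteps=2000` but samples with `sample_steps=150` DDIM steps, so the sampler visits only a subsequence of the training timesteps. The sampler itself is defined elsewhere in the notebook. The sketch below shows one common way to pick such a subsequence, evenly spaced steps, as a hypothetical helper `ddim_timesteps`; the notebook's sampler may space them differently.

```python
import numpy as np

def ddim_timesteps(timesteps, sample_steps):
    """Hypothetical helper: evenly spaced, descending DDIM timestep subsequence.

    Assumes uniform spacing over [0, timesteps - 1].
    """
    seq = np.linspace(0, timesteps - 1, sample_steps).round().astype(int)
    return seq[::-1]  # denoise from the noisiest step down to t = 0

seq = ddim_timesteps(2000, 150)
print(len(seq), seq[0], seq[-1])  # 150 1999 0
```

With 2000 training steps and 150 sampling steps, consecutive sampled timesteps are about 13 apart, which is where DDIM's speed-up over full ancestral sampling comes from.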