In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from diffusers import UNet2DModel, DDIMScheduler
from transformers import get_cosine_schedule_with_warmup
import h5py
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create output directories
os.makedirs("../Conditional_Diff/generated_samples", exist_ok=True)
os.makedirs("../Conditional_Diff/models", exist_ok=True)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Improved dataset class
class MICRO2D_Dataset(Dataset):
    def __init__(self, file_path, microstructure_class='NBSA', property_type='mechanical', augment=True):
        self.file = h5py.File(file_path, 'r')
        self.images = self.file[microstructure_class][microstructure_class][:]
        
        if property_type == 'mechanical':
            self.properties = self.file[microstructure_class]['homogenized_mechanical'][:]
            self.prop_dim = 1
        else:
            self.properties = self.file[microstructure_class]['homogenized_thermal'][:]
            self.prop_dim = 1
            
        self.property_type = property_type
        self.microstructure_class = microstructure_class
        self.augment = augment
        
        # Precompute property min/max for normalization
        all_props = [self.properties[i][0][0] for i in range(len(self))]
        self.prop_min = min(all_props)
        self.prop_max = max(all_props)
        
        print(f"Loaded {len(self.images)} {microstructure_class} microstructures")
        print(f"Property shape: {self.properties.shape}")
        print(f"Property range: {self.prop_min:.2f} to {self.prop_max:.2f}")
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        # Get and preprocess image
        image = torch.from_numpy(self.images[idx]).float().unsqueeze(0) * 2 - 1
        
        # Data augmentation
        if self.augment:
            k = np.random.randint(0, 4)  # 0: no rotation, 1: 90°, 2: 180°, 3: 270°
            image = torch.rot90(image, k, dims=[1, 2])
            
            # Random horizontal and vertical flips
            if np.random.random() > 0.5:
                image = torch.flip(image, dims=[2])
            if np.random.random() > 0.5:
                image = torch.flip(image, dims=[1])
        
        # Get property value
        prop_value = torch.tensor([self.properties[idx][0][0]], dtype=torch.float32)
        
        # Normalize property to [-1, 1] range
        norm_prop = 2 * (prop_value - self.prop_min) / (self.prop_max - self.prop_min) - 1
        
        # Create property channel
        prop_channel = norm_prop.view(1, 1, 1).repeat(1, 256, 256)
        
        # Combine image with property channel
        combined_input = torch.cat([image, prop_channel], dim=0)
        
        return {
            "pixel_values": image,
            "combined_input": combined_input,
            "properties": prop_value,
            "normalized_prop": norm_prop,
            "original_idx": idx
        }
    
    def denormalize_image(self, image):
        """Convert from [-1, 1] to [0, 1] range"""
        return (image + 1) / 2
    
    def close(self):
        """Close the HDF5 file"""
        self.file.close()

In [4]:
# Function to analyze property distribution
def analyze_property_distribution(file_path, class_name='NBSA', property_type='mechanical'):
    """Analyze the distribution of properties in the dataset"""
    
    # Open the HDF5 file
    with h5py.File(file_path, 'r') as file:
        # Get properties
        if property_type == 'mechanical':
            properties = file[class_name]['homogenized_mechanical'][:]
        else:
            properties = file[class_name]['homogenized_thermal'][:]
        
        # Extract the first property of the first material combination
        primary_props = [properties[i][0][0] for i in range(len(properties))]
        
        # Get statistics
        min_val = min(primary_props)
        max_val = max(primary_props)
        mean_val = sum(primary_props) / len(primary_props)
        median_val = sorted(primary_props)[len(primary_props)//2]
        
        print(f"Property statistics for {class_name} ({property_type}):")
        print(f"  Min: {min_val:.2f}")
        print(f"  Max: {max_val:.2f}")
        print(f"  Mean: {mean_val:.2f}")
        print(f"  Median: {median_val:.2f}")
        
        # Create histogram bins
        hist, bins = np.histogram(primary_props, bins=10)
        
        print(f"  Distribution by bins:")
        for i, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
            print(f"    {start:.2f} - {end:.2f}: {hist[i]} samples")
        
        # Create a plot of the distribution
        plt.figure(figsize=(10, 6))
        plt.hist(primary_props, bins=20)
        plt.xlabel('Property Value')
        plt.ylabel('Count')
        plt.title(f'Distribution of Primary Property for {class_name}')
        plt.savefig(f"{class_name}_property_distribution.png")
        plt.close()
        
        # Suggest values spanning the range
        suggested_values = np.linspace(min_val, max_val, 6).tolist()
        suggested_values = [round(val) for val in suggested_values]
        
        print(f"\nSuggested property values for generation (spanning the dataset range):")
        print(f"  {suggested_values}")
        
        return min_val, max_val, suggested_values

In [None]:
# generator function
def generate_and_save_images(model, scheduler, dataset, epoch, num_samples=4, steps=500, guidance_scale=3.0):
    model.eval()
    
    with torch.no_grad():
        # Select random samples from dataset for conditioning
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        generated_images = []
        property_values = []
        
        for i, idx in enumerate(indices):
            print(f"Generating sample {i+1}/{num_samples}")
            
            batch = dataset[idx]
            properties = batch["properties"].to(device)
            norm_prop = batch["normalized_prop"].to(device)
            property_values.append(properties.item())
            
            img_size = 256
            
            # Create property channel for conditioning
            prop_channel = norm_prop.view(1, 1, 1, 1).repeat(1, 1, img_size, img_size).to(device)
            
            # Start with random noise for the image channel
            noise = torch.randn(1, 1, img_size, img_size).to(device)
            
            # Initialize input with noise for image and property for conditioning
            noisy_image = torch.cat([noise, prop_channel], dim=1)
            
            # Create null property channel for classifier-free guidance
            null_channel = torch.zeros_like(prop_channel)
            
            # Single print for each sample, updated in place
            print(f"Denoising sample {i+1}/{num_samples}: 0%", end="\r")
            
            steps_to_use = min(steps, len(scheduler.timesteps))
            for t_idx, t in enumerate(scheduler.timesteps[:steps_to_use]):
                # Update progress percentage
                progress = (t_idx + 1) / steps_to_use * 100
                print(f"Denoising sample {i+1}/{num_samples}: {progress:.1f}%", end="\r")
                
                # Get model prediction
                noise_pred = model(noisy_image, t).sample
                
                # For classifier-free guidance (if enabled)
                if guidance_scale > 1.0:
                    # Get unconditional prediction
                    unconditional_input = torch.cat([noisy_image[:, :1], null_channel], dim=1)
                    unconditional_pred = model(unconditional_input, t).sample
                    
                    # Apply guidance
                    guided_pred = unconditional_pred + guidance_scale * (noise_pred - unconditional_pred)
                    
                    # Update only the image part, keeping property channel fixed
                    updated = scheduler.step(guided_pred[:, :1], t, noisy_image[:, :1]).prev_sample
                else:
                    # Standard denoising without guidance
                    updated = scheduler.step(noise_pred[:, :1], t, noisy_image[:, :1]).prev_sample
                
                noisy_image = torch.cat([updated, prop_channel], dim=1)
                
                # Clear GPU cache periodically
                if t_idx % 100 == 0:
                    torch.cuda.empty_cache()
            
            # Print newline after completion
            print()
            
            # Get only the image channel and clamp values
            generated_img = torch.clamp(noisy_image[:, :1], -1, 1)
            generated_images.append(generated_img.cpu())
            
            # Clear GPU cache
            torch.cuda.empty_cache()
        
        # Create a grid of ONLY generated images
        fig, axes = plt.subplots(1, num_samples, figsize=(4*num_samples, 4))
        
        # Handle case of single sample
        if num_samples == 1:
            axes = np.array([axes])
        
        for i in range(num_samples):
            # Generated image
            gen_img = dataset.denormalize_image(generated_images[i].squeeze()).numpy()
            
            if num_samples > 1:
                axes[i].imshow(gen_img, cmap='gray')
                axes[i].set_title(f"Prop: {property_values[i]:.3f}")
                axes[i].axis('off')
            else:
                axes.imshow(gen_img, cmap='gray')
                axes.set_title(f"Prop: {property_values[i]:.3f}")
                axes.axis('off')
        
        plt.tight_layout()
        plt.savefig(f"generated_samples/epoch_{epoch}.png", dpi=300)
        plt.close()
        
        print(f"Generated samples saved to generated_samples/epoch_{epoch}.png")

In [None]:
# Function to train the property conditional diffusion model
def train_property_conditional_diffusion():
    """Train the property conditional diffusion model"""
    # Configuration
    file_path = "../MICRO2D_homogenized.h5"  # Update to your file path
    microstructure_class = 'NBSA'
    property_type = 'mechanical'
    batch_size = 4  # Reduced batch size
    num_epochs = 200  # Increased epochs
    learning_rate = 5e-5  # Reduced learning rate
    
    # Create dataset and dataloader
    dataset = MICRO2D_Dataset(file_path, microstructure_class, property_type, augment=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    # Create model with increased capacity
    model = UNet2DModel(
        sample_size=256,
        in_channels=2,    
        out_channels=2,   
        layers_per_block=3,  # Increased from 2
        block_out_channels=(128, 256, 512, 512),  # Increased capacity
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D", 
            "AttnDownBlock2D",  # Added attention to more layers
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D",  # Added attention to more layers
            "UpBlock2D",
            "UpBlock2D", 
        )
    )
    
    # Move model to device
    model.to(device)
    
    # Use more advanced scheduler
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="scaled_linear",
        clip_sample=False,
        prediction_type="epsilon"
    )
    
    # Create optimizer with weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=learning_rate,
        weight_decay=0.01
    )
    
    # Create learning rate scheduler
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=len(dataloader) * num_epochs
    )
    
    # Training loop
    global_step = 0
    
    for epoch in range(num_epochs):
        model.train()
        
        # track epoch progress
        epoch_loss = 0.0
        batch_count = 0
        
        # Simple epoch header
        print(f"Epoch {epoch+1}/{num_epochs} - Training:")
        
        for step, batch in enumerate(dataloader):
            clean_images = batch["pixel_values"].to(device)
            combined_input = batch["combined_input"].to(device)
            
            # Sample noise
            noise = torch.randn(clean_images.shape).to(device)
            batch_size = clean_images.shape[0]
            
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, 
                (batch_size,), device=device
            ).long()
            
            # Add noise to the clean images
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            
            # Create noisy combined input (noise applied only to image channel)
            noisy_combined = torch.cat([noisy_images, combined_input[:, 1:]], dim=1)
            
            # Get model prediction
            noise_pred = model(noisy_combined, timesteps).sample
            
            # Calculate loss only on the image channel
            loss = F.mse_loss(noise_pred[:, :1], noise)
            
            # Update model
            loss.backward()
            
            # Add gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            epoch_loss += loss.item()
            batch_count += 1
            global_step += 1
            
            # Print occasional updates without a progress bar
            if step % 10 == 0:
                print(f"  Batch {step}/{len(dataloader)} - Loss: {loss.item():.6f}", end="\r")
            
            # Periodically clear cache
            if step % 10 == 0:
                torch.cuda.empty_cache()
        
        # Log average loss for the epoch
        avg_epoch_loss = epoch_loss / batch_count
        print(f"\nEpoch {epoch+1} completed - Average Loss: {avg_epoch_loss:.6f}")
        
        # Generate and save images every 10 epochs or on the final epoch
        if (epoch + 1) % 10 == 0 or (epoch + 1) == num_epochs:
            print(f"Generating images for epoch {epoch+1}...")
            generate_and_save_images(
                model, 
                noise_scheduler, 
                dataset, 
                epoch+1, 
                steps=250,  # Reduced steps for faster generation
                guidance_scale=3.0  # Enable classifier-free guidance
            )
    
    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'property_type': property_type,
        'property_min': dataset.prop_min,
        'property_max': dataset.prop_max,
    }, "models/nbsa_diffusion_final.pt")
    
    # Close the dataset HDF5 file
    dataset.close()
    
    print("Training complete!")

In [7]:
# Function to generate custom microstructures
def generate_custom_microstructures(model_path, file_path, property_values, num_samples=4, class_name='NBSA'):
    """
    Generate custom microstructures with specified property values
    
    Args:
        model_path: Path to the saved model
        file_path: Path to the HDF5 file (for dataset info)
        property_values: List of property values to condition generation on
        num_samples: Number of samples to generate per property value
        class_name: Microstructure class to use
    """
    # Load model checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Create model with the improved architecture
    model = UNet2DModel(
        sample_size=256,
        in_channels=2,
        out_channels=2,
        layers_per_block=3,
        block_out_channels=(128, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D", 
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D", 
        )
    )
    
    # Load model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Create scheduler - use DDIM for better quality
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="scaled_linear",
        clip_sample=False,
        prediction_type="epsilon"
    )
    
    # Create a temporary dataset to get the property normalization
    temp_dataset = MICRO2D_Dataset(file_path, class_name)
    
    # If property_min and property_max are in the checkpoint, use those instead
    prop_min = checkpoint.get('property_min', temp_dataset.prop_min)
    prop_max = checkpoint.get('property_max', temp_dataset.prop_max)
    
    # Guidance scale for classifier-free guidance
    guidance_scale = 3.0
    
    # Generate images for each property value
    all_generated = []
    
    with torch.no_grad():
        for prop_val in property_values:
            print(f"Generating samples for property value: {prop_val}")
            
            for i in range(num_samples):
                # Create property tensor and normalize it
                norm_prop = 2 * (prop_val - prop_min) / (prop_max - prop_min) - 1
                prop_tensor = torch.tensor([[norm_prop]], dtype=torch.float32).to(device)
                prop_channel = prop_tensor.view(1, 1, 1, 1).repeat(1, 1, 256, 256)
                
                # Create null property channel for classifier-free guidance
                null_channel = torch.zeros_like(prop_channel)
                
                # Start with random noise
                noise = torch.randn(1, 1, 256, 256).to(device)
                
                # Combine noise with property channel
                noisy_image = torch.cat([noise, prop_channel], dim=1)
                
                # Single print for this sample, updated in place
                print(f"  Sample {i+1}/{num_samples}: 0%", end="\r")
                
                # Sampling loop - more steps for higher quality
                steps_to_use = 500
                for t_idx, t in enumerate(scheduler.timesteps[:steps_to_use]):
                    # Update progress inline
                    progress = (t_idx + 1) / steps_to_use * 100
                    print(f"  Sample {i+1}/{num_samples}: {progress:.1f}%", end="\r")
                    
                    # Get conditional prediction
                    noise_pred = model(noisy_image, t).sample
                    
                    # For classifier-free guidance
                    if guidance_scale > 1.0:
                        # Get unconditional prediction
                        unconditional_input = torch.cat([noisy_image[:, :1], null_channel], dim=1)
                        unconditional_pred = model(unconditional_input, t).sample
                        
                        # Apply guidance
                        guided_pred = unconditional_pred + guidance_scale * (noise_pred - unconditional_pred)
                        
                        # Update only the image part, keeping property channel fixed
                        updated = scheduler.step(guided_pred[:, :1], t, noisy_image[:, :1]).prev_sample
                    else:
                        # Standard denoising without guidance
                        updated = scheduler.step(noise_pred[:, :1], t, noisy_image[:, :1]).prev_sample
                    
                    noisy_image = torch.cat([updated, prop_channel], dim=1)
                
                print()  # New line after completion
                
                # Get only the image channel and clamp values
                generated_img = torch.clamp(noisy_image[:, :1], -1, 1).cpu()
                all_generated.append((prop_val, generated_img))
                
                # Clear GPU cache
                torch.cuda.empty_cache()
        
        # Create a grid of images
        rows = len(property_values)
        cols = num_samples
        fig, axes = plt.subplots(rows, cols, figsize=(3*cols, 3*rows))
        
        # Handle single row/column cases
        if rows == 1 and cols == 1:
            axes = np.array([[axes]])
        elif rows == 1:
            axes = axes.reshape(1, -1)
        elif cols == 1:
            axes = axes.reshape(-1, 1)
        
        for row, prop_val in enumerate(property_values):
            for col in range(num_samples):
                idx = row * num_samples + col
                prop, img = all_generated[idx]
                
                # Display image
                axes[row, col].imshow(temp_dataset.denormalize_image(img.squeeze()).numpy(), cmap='gray')
                axes[row, col].set_title(f"Prop: {prop:.2f}")
                axes[row, col].axis('off')
        
        plt.tight_layout()
        plt.savefig(f"generated_samples/custom_properties.png", dpi=300)
        plt.close()
    
    # Close the dataset
    temp_dataset.close()
    
    print("Custom microstructure generation complete!")

In [None]:
# Main execution
if __name__ == "__main__":
    # Analyze property distribution first
    min_val, max_val, suggested_values = analyze_property_distribution(
        file_path="../MICRO2D_homogenized.h5",
        class_name='NBSA',
        property_type='mechanical'
    )
    
    # Use the suggested values for generation
    property_values = suggested_values
    # Train the model
    train_property_conditional_diffusion()

    

    """
    # Generate custom microstructures
    generate_custom_microstructures(
         model_path="models/nbsa_diffusion_final.pt",
        file_path="../MICRO2D_homogenized.h5",
            property_values=property_values,
            num_samples=4
        )
    """

Property statistics for NBSA (mechanical):
  Min: 1706.90
  Max: 2364.29
  Mean: 1944.66
  Median: 1936.45
  Distribution by bins:
    1706.90 - 1772.64: 88 samples
    1772.64 - 1838.37: 289 samples
    1838.37 - 1904.11: 294 samples
    1904.11 - 1969.85: 288 samples
    1969.85 - 2035.59: 299 samples
    2035.59 - 2101.33: 204 samples
    2101.33 - 2167.07: 113 samples
    2167.07 - 2232.81: 44 samples
    2232.81 - 2298.55: 11 samples
    2298.55 - 2364.29: 4 samples

Suggested property values for generation (spanning the dataset range):
  [1707, 1838, 1970, 2101, 2233, 2364]
Loaded 1634 NBSA microstructures
Property shape: (1634, 6, 5)
Property range: 1706.90 to 2364.29
Epoch 1/200 - Training:
  Batch 400/409 - Loss: 0.083586
Epoch 1 completed - Average Loss: 0.282838
Epoch 2/200 - Training:
  Batch 400/409 - Loss: 0.050484
Epoch 2 completed - Average Loss: 0.093696
Epoch 3/200 - Training:
  Batch 400/409 - Loss: 0.047192
Epoch 3 completed - Average Loss: 0.066010
Epoch 4/200 - Tr