In [None]:
# Import necessary libraries
import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import nibabel as nib
from tqdm.notebook import tqdm  # Use notebook version for better display
import random
import copy
import math
from pathlib import Path
from contextlib import nullcontext

# Optional imports with fallbacks
try:
    from skimage import measure, metrics
    HAVE_SKIMAGE = True
except ImportError:
    HAVE_SKIMAGE = False
    print("Warning: scikit-image not installed. Some visualization features will be limited.")

try:
    import plotly.graph_objects as go
    HAVE_PLOTLY = True
except ImportError:
    HAVE_PLOTLY = False
    print("Warning: plotly not installed. 3D visualizations will be limited.")

# Configure logging: INFO level and above, each record prefixed with a timestamp.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
# Module-level logger used by the classes and functions defined in this file.
logger = logging.getLogger(__name__)

# Set reproducibility
def set_seeds(seed=42):
    """Seed every random number generator used in this file.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), then puts cuDNN into a reproducible configuration.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic kernels may impact performance, but ensure reproducibility.
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can pick a different algorithm on each run, so it
    # must be off as well for results to repeat.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, using the default seed, before anything random happens.
set_seeds()

# Set device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Memory information (only queried when CUDA is available; reports device 0)
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 for display in GB
    logger.info(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
class BrainDiffusionConfig:
    """Configuration class for BrainDiffusion model.

    This class centralizes all hyperparameters and settings in one place.
    Every setting is a plain instance attribute created in ``__init__``;
    use :meth:`update` to override a subset of them with validation.
    """

    def __init__(self):
        # Data Parameters
        self.volume_size = (64, 64, 64)  # 3D volume dimensions (D, H, W)
        self.in_channels = 1             # Single channel for T1/T2 MRI
        self.patch_size = None           # Set to use patch-based training (e.g., (32, 32, 32))

        # Model Parameters
        self.base_channels = 32          # Base channel count (smaller for 3D due to memory constraints)
        self.channel_mults = (1, 2, 4, 8)  # Channel multipliers at each resolution
        self.attention_resolutions = (8,) # At which resolutions to apply attention
        self.num_res_blocks = 2          # Number of residual blocks per resolution
        self.dropout = 0.1               # Dropout rate
        self.time_dim = 256              # Dimension for time embedding

        # Diffusion Parameters
        self.num_diffusion_steps = 1000  # Number of diffusion timesteps
        self.beta_start = 1e-4           # Starting noise schedule value
        self.beta_end = 0.02             # Ending noise schedule value

        # Training Parameters
        self.batch_size = 2              # Small batch size for 3D (memory constraints)
        self.learning_rate = 1e-4        # Learning rate
        self.weight_decay = 1e-5         # Weight decay for regularization
        self.epochs = 100                # Number of training epochs
        self.ema_decay = 0.9999          # Exponential moving average decay
        self.train_val_split = 0.8       # Train/validation split ratio

        # Memory Optimization
        self.use_checkpointing = True    # Use gradient checkpointing to save memory
        self.mixed_precision = True      # Use mixed precision training

        # Sampling Parameters
        self.sampling_steps = 50         # Default sampling steps (for DDIM)
        self.use_ddim = True             # Use DDIM for faster sampling (vs. DDPM)
        self.ddim_sampling_eta = 0.0     # DDIM sampling parameter (0 = deterministic)
        self.guidance_scale = 2.0        # Classifier-free guidance scale

    def __str__(self):
        """Return one ``key = value`` line per attribute, in definition order."""
        return "\n".join(f"{key} = {value}" for key, value in vars(self).items())

    def update(self, **kwargs):
        """Update configuration parameters.

        The update is all-or-nothing: every key is validated before any
        attribute is changed, so a misspelled key cannot leave the
        configuration half-modified.

        Args:
            **kwargs: Attribute names mapped to their new values.

        Returns:
            This configuration object, to allow call chaining.

        Raises:
            ValueError: If a key does not name an existing attribute. The
                message names the first unknown key.
        """
        unknown = [key for key in kwargs if not hasattr(self, key)]
        if unknown:
            raise ValueError(f"Unknown parameter: {unknown[0]}")

        for key, value in kwargs.items():
            setattr(self, key, value)
        return self

# Create default configuration (module-level instance with the default values)
config = BrainDiffusionConfig()
# Print every setting, one "key = value" line each, as a record of the run.
print("Default configuration:")
print(config)

Default configuration:
volume_size = (64, 64, 64)
in_channels = 1
patch_size = None
base_channels = 32
channel_mults = (1, 2, 4, 8)
attention_resolutions = (8,)
num_res_blocks = 2
dropout = 0.1
time_dim = 256
num_diffusion_steps = 1000
beta_start = 0.0001
beta_end = 0.02
batch_size = 2
learning_rate = 0.0001
weight_decay = 1e-05
epochs = 100
ema_decay = 0.9999
train_val_split = 0.8
use_checkpointing = True
mixed_precision = True
sampling_steps = 50
use_ddim = True
ddim_sampling_eta = 0.0
guidance_scale = 2.0


In [None]:
class BrainVolumeLoader:
    """Handles loading and preprocessing of 3D brain MRI volumes.

    Volumes are read from NIfTI files, resampled to a fixed size with
    trilinear interpolation, optionally min-max normalized, and returned as
    tensors with a leading channel dimension.
    """

    # Ordered (keyword, label) pairs; the first keyword found in the
    # lower-cased path wins. Placeholder mapping - adapt to your dataset.
    _REGION_KEYWORDS = (
        ("frontal", 0),    # Frontal lobe
        ("parietal", 1),   # Parietal lobe
        ("temporal", 2),   # Temporal lobe
        ("occipital", 3),  # Occipital lobe
    )
    _UNKNOWN_REGION = 4    # Other/unknown

    def __init__(self, data_dir, target_size=(64, 64, 64), normalize=True):
        """Initialize the volume loader.

        Args:
            data_dir: Path to directory containing NIfTI files
            target_size: Target dimensions for volumes (D, H, W)
            normalize: Whether to normalize intensity values to [0, 1]
        """
        self.data_dir = Path(data_dir)
        self.target_size = target_size
        self.normalize = normalize
        logger.info(f"Initialized volume loader with target size: {target_size}")

    def load_nifti_volume(self, file_path):
        """Load a NIfTI file and return as numpy array.

        Args:
            file_path: Path to NIfTI file

        Returns:
            volume: Numpy array of volume data
            affine: Affine transformation matrix
        """
        img = nib.load(file_path)
        return img.get_fdata(), img.affine

    def resample_volume(self, volume, target_shape):
        """Resample volume to target shape.

        Args:
            volume: Input volume as a 3D numpy array
            target_shape: Target shape (D, H, W)

        Returns:
            Resampled volume as a float32 numpy array
        """
        # F.interpolate expects [N, C, D, H, W], so add batch and channel dims.
        tensor = torch.from_numpy(volume).float().unsqueeze(0).unsqueeze(0)

        # Resample using trilinear interpolation
        resampled = F.interpolate(
            tensor,
            size=target_shape,
            mode='trilinear',
            align_corners=False
        )

        return resampled.squeeze(0).squeeze(0).numpy()

    def normalize_volume(self, volume):
        """Normalize volume intensities to [0, 1] (min-max scaling).

        Args:
            volume: Input volume

        Returns:
            Normalized volume; all zeros when the volume is constant
        """
        min_val = np.min(volume)
        max_val = np.max(volume)

        if max_val > min_val:
            return (volume - min_val) / (max_val - min_val)

        # Constant volume: there is no range to scale by.
        return volume * 0

    def preprocess_volume(self, volume):
        """Preprocess a 3D volume.

        Args:
            volume: Input volume as a numpy array

        Returns:
            Preprocessed volume as a float32 tensor of shape [1, D, H, W]
        """
        # Compare as tuples so a list-typed target_size does not force a
        # pointless resample of a volume that already has the right shape.
        if tuple(volume.shape) != tuple(self.target_size):
            volume = self.resample_volume(volume, self.target_size)

        # Normalize intensities if requested
        if self.normalize:
            volume = self.normalize_volume(volume)

        # Convert to tensor and add the channel dimension
        return torch.from_numpy(volume).float().unsqueeze(0)

    def extract_region_from_path(self, file_path):
        """Extract brain region from filepath or use a default.

        Args:
            file_path: Path to NIfTI file

        Returns:
            Region identifier (integer): 0 frontal, 1 parietal, 2 temporal,
            3 occipital, 4 when no keyword appears in the path
        """
        # This is a placeholder - adapt based on your dataset organization
        file_path_str = str(file_path).lower()

        for keyword, label in self._REGION_KEYWORDS:
            if keyword in file_path_str:
                return label
        return self._UNKNOWN_REGION

    def _find_nifti_files(self, limit=None):
        """Return the NIfTI files under ``data_dir`` in sorted order.

        Sorting makes the selection made by ``limit`` (and any later seeded
        shuffle of the loaded volumes) independent of filesystem order.

        Args:
            limit: Optional maximum number of files; ``None`` means no limit
                and ``0`` yields an empty list.

        Returns:
            List of paths ending in ``.nii.gz`` or ``.nii``
        """
        patterns = ('**/*.nii.gz', '**/*.nii')
        found = sorted(
            path for pattern in patterns for path in self.data_dir.glob(pattern)
        )
        if limit is not None:
            found = found[:limit]
        return found

    def load_dataset(self, limit=None):
        """Load all NIfTI files in the data directory.

        Files that fail to load or preprocess are logged as warnings and
        skipped, so the returned lists may be shorter than the file count.

        Args:
            limit: Optional limit on number of files to load

        Returns:
            volumes: List of preprocessed volumes as tensors
            region_labels: List of region labels, aligned with ``volumes``
        """
        nifti_files = self._find_nifti_files(limit)
        logger.info(f"Found {len(nifti_files)} NIfTI files")

        volumes = []
        region_labels = []

        for file_path in tqdm(nifti_files, desc="Loading 3D volumes"):
            try:
                # Get brain region
                region = self.extract_region_from_path(file_path)

                # Load and preprocess volume
                volume_data, _ = self.load_nifti_volume(file_path)
                processed_volume = self.preprocess_volume(volume_data)

                # Append both together so the two lists stay aligned.
                volumes.append(processed_volume)
                region_labels.append(region)

            except Exception as e:
                # Deliberate best effort: one bad file must not abort the load.
                logger.warning(f"Error processing {file_path}: {e}")

        logger.info(f"Successfully loaded {len(volumes)} volumes")
        return volumes, region_labels

In [None]:
class BrainVolumeAugmentation:
    """3D data augmentation techniques for brain volumes.

    All methods are static, take a tensor laid out as [C, D, H, W], and
    return either a new tensor or the input object itself when the random
    draw decides not to augment. The intensity augmentations clamp their
    result to [0, 1].
    """

    @staticmethod
    def random_flip(volume, axis=0, p=0.5):
        """Randomly flip volume along specified axis.

        Args:
            volume: Input volume [C, D, H, W]
            axis: Axis to flip (0=depth, 1=height, 2=width)
            p: Probability of applying flip

        Returns:
            Augmented volume
        """
        if random.random() < p:
            # Add 1 to axis because the first dimension is the channel
            return torch.flip(volume, dims=[axis+1])
        return volume

    @staticmethod
    def random_rotate_90(volume, axes=(1, 2), p=0.5):
        """Randomly rotate volume by a multiple of 90 degrees.

        The output always has the same shape as the input. A quarter or
        three-quarter turn swaps the two plane dimensions, so those are only
        drawn when both dimensions have equal size; otherwise a half turn is
        used, which keeps samples collatable into a batch.

        Args:
            volume: Input volume [C, D, H, W]
            axes: Pair of spatial axes (0=depth, 1=height, 2=width) that
                define the plane of rotation
            p: Probability of applying rotation

        Returns:
            Augmented volume
        """
        if random.random() < p:
            # Add 1 to axes because the first dimension is the channel
            plane = (axes[0]+1, axes[1]+1)
            if volume.shape[plane[0]] == volume.shape[plane[1]]:
                k = random.randint(1, 3)  # 1, 2, or 3 times 90 degrees
            else:
                k = 2  # only 180 degrees preserves a non-square plane's shape
            return torch.rot90(volume, k=k, dims=plane)
        return volume

    @staticmethod
    def random_intensity_shift(volume, max_offset=0.1, p=0.5):
        """Randomly shift intensity values by a single uniform offset.

        Args:
            volume: Input volume [C, D, H, W]
            max_offset: Maximum absolute intensity offset
            p: Probability of applying shift

        Returns:
            Augmented volume, clamped to [0, 1]
        """
        if random.random() < p:
            offset = random.uniform(-max_offset, max_offset)
            return torch.clamp(volume + offset, 0, 1)
        return volume

    @staticmethod
    def random_intensity_scale(volume, min_scale=0.9, max_scale=1.1, p=0.5):
        """Randomly scale intensity values by a single uniform factor.

        Args:
            volume: Input volume [C, D, H, W]
            min_scale: Minimum scaling factor
            max_scale: Maximum scaling factor
            p: Probability of applying scaling

        Returns:
            Augmented volume, clamped to [0, 1]
        """
        if random.random() < p:
            scale = random.uniform(min_scale, max_scale)
            return torch.clamp(volume * scale, 0, 1)
        return volume

    @staticmethod
    def random_gaussian_noise(volume, std=0.01, p=0.5):
        """Add random Gaussian noise.

        Args:
            volume: Input volume [C, D, H, W]
            std: Standard deviation of noise
            p: Probability of adding noise

        Returns:
            Augmented volume, clamped to [0, 1]
        """
        if random.random() < p:
            noise = torch.randn_like(volume) * std
            return torch.clamp(volume + noise, 0, 1)
        return volume

    @staticmethod
    def apply_augmentations(volume, p=0.8):
        """Apply a series of random augmentations with probability p.

        When triggered, 1-3 distinct augmentations are drawn from the pool
        and applied in the drawn order. Each one still applies its own
        internal probability, so a selected augmentation may be a no-op.

        Args:
            volume: Input volume [C, D, H, W]
            p: Overall probability of applying augmentation

        Returns:
            Augmented volume
        """
        if random.random() < p:
            # Pool of candidate augmentations
            augs = [
                lambda v: BrainVolumeAugmentation.random_flip(v, axis=0),
                lambda v: BrainVolumeAugmentation.random_flip(v, axis=1),
                lambda v: BrainVolumeAugmentation.random_flip(v, axis=2),
                lambda v: BrainVolumeAugmentation.random_rotate_90(v, axes=(0, 1)),
                lambda v: BrainVolumeAugmentation.random_rotate_90(v, axes=(0, 2)),
                lambda v: BrainVolumeAugmentation.random_rotate_90(v, axes=(1, 2)),
                lambda v: BrainVolumeAugmentation.random_intensity_shift(v),
                lambda v: BrainVolumeAugmentation.random_intensity_scale(v),
                lambda v: BrainVolumeAugmentation.random_gaussian_noise(v)
            ]

            # Apply 1-3 random augmentations
            num_augs = random.randint(1, 3)
            selected_augs = random.sample(augs, num_augs)

            for aug in selected_augs:
                volume = aug(volume)

        return volume

In [None]:
class BrainVolumeDataset(Dataset):
    """PyTorch Dataset for 3D brain volumes with conditional region labels."""

    def __init__(self, volumes, region_labels=None, transform=None, patch_size=None):
        """Initialize the dataset.

        Args:
            volumes: List of volume tensors [C, D, H, W]
            region_labels: Optional list of integer region labels, aligned
                with ``volumes``
            transform: Optional transform function applied to each volume
                (after patch extraction)
            patch_size: Optional patch size (D, H, W) for patch-based extraction
        """
        self.volumes = volumes
        self.region_labels = region_labels
        self.transform = transform
        self.patch_size = patch_size

        # Number of classes needed to index an embedding table with these
        # labels (0 when the dataset is unconditional).
        self.num_regions = self._count_region_classes(region_labels)
        if region_labels is not None:
            logger.info(f"Dataset has {len(volumes)} volumes with {self.num_regions} regions")

    @staticmethod
    def _count_region_classes(region_labels):
        """Return the smallest class count that covers every label.

        Labels are used as class indices, so the count must be
        ``max(label) + 1`` rather than the number of distinct labels: a
        dataset containing only label 4 still needs 5 classes.

        Args:
            region_labels: List of integer labels, or None

        Returns:
            ``max(label) + 1``, or 0 when there are no labels
        """
        if not region_labels:
            return 0
        return max(int(max(region_labels)) + 1, 0)

    def __len__(self):
        """Return the number of volumes in the dataset."""
        return len(self.volumes)

    def extract_random_patch(self, volume):
        """Extract a random patch from the volume.

        Along any axis where the volume is not larger than the patch, the
        patch starts at 0, so the result can be smaller than ``patch_size``
        on that axis.

        Args:
            volume: Input volume [C, D, H, W]

        Returns:
            Extracted patch, or the volume itself when no patch size is set
        """
        if self.patch_size is None:
            return volume

        c, d, h, w = volume.shape
        pd, ph, pw = self.patch_size

        # Randomly select patch origin
        d_start = random.randint(0, d - pd) if d > pd else 0
        h_start = random.randint(0, h - ph) if h > ph else 0
        w_start = random.randint(0, w - pw) if w > pw else 0

        # Extract patch (all channels are kept)
        return volume[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]

    def __getitem__(self, idx):
        """Get a volume and its label.

        Args:
            idx: Index of the volume to retrieve

        Returns:
            Dictionary with 'volume' ([C, D, H, W] tensor) and 'region'
            (scalar long tensor; -1 when the dataset has no labels)
        """
        volume = self.volumes[idx]

        # Extract patch if needed
        if self.patch_size is not None:
            volume = self.extract_random_patch(volume)

        # Apply transforms if provided
        if self.transform:
            volume = self.transform(volume)

        # Get region label if available
        region = -1  # Default for unconditional
        if self.region_labels is not None:
            region = self.region_labels[idx]

        return {
            'volume': volume,         # [C, D, H, W] tensor
            'region': torch.tensor(region, dtype=torch.long)  # Class index
        }


def create_dataloaders(volumes, region_labels=None, config=None, num_workers=2):
    """Create training and validation dataloaders for 3D volumes.

    The volumes are shuffled with Python's ``random`` module and split by
    ``config.train_val_split``. Only the training dataset is augmented.

    Args:
        volumes: List of volume tensors
        region_labels: Optional list of integer region labels, aligned with
            ``volumes``
        config: Configuration object (defaults to ``BrainDiffusionConfig()``)
        num_workers: Worker processes per dataloader

    Returns:
        train_loader: Training data loader
        val_loader: Validation data loader
        num_regions: Number of region classes, ``max(label) + 1`` over all
            labels (0 without labels)

    Raises:
        ValueError: If ``volumes`` is empty.
    """
    if config is None:
        config = BrainDiffusionConfig()

    if len(volumes) == 0:
        raise ValueError("Cannot create dataloaders from an empty list of volumes")

    # Create indices for train/val split
    indices = list(range(len(volumes)))
    random.shuffle(indices)
    split = int(len(indices) * config.train_val_split)

    train_indices = indices[:split]
    val_indices = indices[split:]

    # Split data
    train_volumes = [volumes[i] for i in train_indices]
    train_regions = None if region_labels is None else [region_labels[i] for i in train_indices]

    val_volumes = [volumes[i] for i in val_indices]
    val_regions = None if region_labels is None else [region_labels[i] for i in val_indices]

    # Create datasets
    train_dataset = BrainVolumeDataset(
        train_volumes,
        train_regions,
        transform=BrainVolumeAugmentation.apply_augmentations,
        patch_size=config.patch_size
    )

    val_dataset = BrainVolumeDataset(
        val_volumes,
        val_regions,
        patch_size=config.patch_size
    )

    # Count classes over ALL labels, not just the training split: a label
    # that only lands in validation must still be a valid class index, and
    # labels are indices, so the count is max + 1 rather than the number of
    # distinct values.
    if region_labels:
        num_regions = max(int(max(region_labels)) + 1, 0)
    else:
        num_regions = 0

    # Pinned memory only helps host-to-GPU copies, so skip it without CUDA.
    pin_memory = torch.cuda.is_available()

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    logger.info(f"Created dataloaders with {len(train_loader)} training batches and {len(val_loader)} validation batches")

    return train_loader, val_loader, num_regions

In [None]:
class DiffusionModel3D:
    """Core diffusion model implementation for 3D volumes."""

    def __init__(self, config=None, device=device):
        """Initialize the diffusion process parameters.

        Args:
            config: Configuration object
            device: Device to use for computations
        """
        # Use default config if not provided
        if config is None:
            config = BrainDiffusionConfig()

        self.num_diffusion_steps = config.num_diffusion_steps
        self.beta_start = config.beta_start
        self.beta_end = config.beta_end
        self.device = device

        # Define noise schedule (linear or cosine)
        self.betas = self._linear_beta_schedule(
            self.beta_start,
            self.beta_end,
            self.num_diffusion_steps
        ).to(device)

        # Precompute values for forward and reverse processes
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Calculations for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def _linear_beta_schedule(self, beta_start, beta_end, num_diffusion_steps):
        """Linear noise schedule.

        Args:
            beta_start: Starting noise schedule value
            beta_end: Ending noise schedule value
            num_diffusion_steps: Number of diffusion steps

        Returns:
            Tensor of beta values
        """
        return torch.linspace(beta_start, beta_end, num_diffusion_steps)

    def _cosine_beta_schedule(self, num_diffusion_steps, s=0.008):
        """Cosine noise schedule as proposed in improved DDPM papers.

        Args:
            num_diffusion_steps: Number of diffusion steps
            s: Offset parameter

        Returns:
            Tensor of beta values
        """
        steps = num_diffusion_steps + 1
        t = torch.linspace(0, num_diffusion_steps, steps) / num_diffusion_steps
        alphas_cumprod = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)

    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion process: q(x_t | x_0)

        Args:
            x_0: Original clean volumes [B, C, D, H, W]
            t: Diffusion timesteps [B]
            noise: Optional pre-generated noise

        Returns:
            x_t: Noised volumes at timestep t
            noise: The noise used
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        # Get appropriate values for timestep t
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1, 1)

        # Forward process formula: q(x_t | x_0) = sqrt(ɑt)x_0 + sqrt(1-ɑt)ε
        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

        return x_t, noise

    def compute_loss(self, model, x_0, condition, t=None):
        """Compute training loss for the denoising model.

        Args:
            model: Denoising U-Net model
            x_0: Original clean volumes [B, C, D, H, W]
            condition: Conditioning information
            t: Optional specific timesteps, otherwise random

        Returns:
            Loss value
        """
        B = x_0.shape[0]

        # Sample random timesteps if not provided
        if t is None:
            t = torch.randint(0, self.num_diffusion_steps, (B,), device=self.device)

        # Forward process to get noisy volumes x_t
        x_t, noise = self.q_sample(x_0, t)

        # Predict the noise using the model
        noise_pred = model(x_t, t, condition)

        # Simple MSE loss between actual and predicted noise
        loss = F.mse_loss(noise_pred, noise)

        return loss

In [None]:
# Extend DiffusionModel3D with sampling methods

def p_sample(self, model, x_t, t, condition, guidance_scale=1.0):
    """Single step of the reverse diffusion sampling process.

    Written as a free function taking ``self``; it reads the schedule
    tensors ``betas``, ``sqrt_one_minus_alphas_cumprod``,
    ``sqrt_recip_alphas`` and ``posterior_variance`` from it.

    Args:
        model: Denoising model, called as ``model(x_t, t, condition)``
        x_t: Noisy volume at timestep t [B, C, D, H, W]
        t: Current timesteps [B]; entries may differ within the batch
        condition: Conditioning information
        guidance_scale: Scale for classifier-free guidance; only values
            above 1.0 (with a non-None condition) enable guidance

    Returns:
        Denoised sample for timestep t-1
    """
    def _per_sample(schedule):
        # Pick each sample's value and shape it to broadcast over [C, D, H, W].
        return schedule[t].reshape(-1, 1, 1, 1, 1)

    # Model prediction with guidance if requested
    if guidance_scale > 1.0 and condition is not None:
        # Predict with conditioning
        noise_pred_cond = model(x_t, t, condition)

        # Predict without conditioning (unconditional)
        noise_pred_uncond = model(x_t, t, None)

        # Apply classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    else:
        # Regular conditional or unconditional generation
        noise_pred = model(x_t, t, condition)

    # Algorithm 2 from DDPM paper for p(x_{t-1} | x_t)
    model_mean = _per_sample(self.sqrt_recip_alphas) * (
        x_t - _per_sample(self.betas) * noise_pred /
        _per_sample(self.sqrt_one_minus_alphas_cumprod)
    )

    # Decide per sample whether noise is added, instead of looking only at
    # the first batch element.
    needs_noise = t > 0
    if not bool(needs_noise.any()):
        # Every sample is at t=0: the final step adds no noise
        return model_mean

    noise = torch.randn_like(x_t)
    noise_mask = needs_noise.reshape(-1, 1, 1, 1, 1).to(x_t.dtype)
    return model_mean + noise_mask * torch.sqrt(_per_sample(self.posterior_variance)) * noise

def p_sample_loop(self, model, shape, condition, n_steps=None, guidance_scale=1.0):
    """Full reverse diffusion sampling loop to generate new volumes.

    Runs without gradient tracking: sampling is inference, and recording an
    autograd graph across every step would grow memory with each step.

    Args:
        model: Denoising model
        shape: Shape of volumes to generate [B, C, D, H, W]
        condition: Conditioning information
        n_steps: Optional number of steps (defaults to full process). The
            loop visits timesteps ``n_steps - 1`` down to 0.
            NOTE(review): a value below ``num_diffusion_steps`` starts from
            pure noise at an intermediate timestep - confirm this is intended.
        guidance_scale: Scale for classifier-free guidance

    Returns:
        Generated volumes and a list of intermediate steps (CPU tensors)
    """
    B = shape[0]
    if n_steps is None:
        n_steps = self.num_diffusion_steps

    # Start from pure noise
    x = torch.randn(shape, device=self.device)

    # Track intermediate generations for visualization
    intermediates = []

    # Snapshot roughly ten times per run; at least 1 so that fewer than ten
    # steps does not cause a modulo by zero.
    save_every = max(n_steps // 10, 1)

    # Progress bar for sampling
    progress_bar = tqdm(reversed(range(0, n_steps)), desc='Sampling', total=n_steps)

    with torch.no_grad():
        for i in progress_bar:
            # For each batch element, use same timestep
            t = torch.full((B,), i, device=self.device, dtype=torch.long)

            # Single step of denoising
            x = self.p_sample(model, x, t, condition, guidance_scale=guidance_scale)

            # Save intermediate results every few steps, plus the first step
            if i % save_every == 0 or i == n_steps - 1:
                intermediates.append(x.detach().cpu())

    return x, intermediates

def ddim_sample(self, model, shape, condition, n_steps=50, guidance_scale=1.0, eta=0.0):
    """DDIM sampling for faster and deterministic generation.

    Implementation follows the DDIM paper: https://arxiv.org/abs/2010.02502
    Runs without gradient tracking, since sampling is inference only.

    Args:
        model: Denoising model, called as ``model(x, t, condition)``
        shape: Shape of volumes to generate [B, C, D, H, W]
        condition: Conditioning information
        n_steps: Number of DDIM steps (typically much less than DDPM)
        guidance_scale: Scale for classifier-free guidance; only values
            above 1.0 (with a non-None condition) enable guidance
        eta: DDIM stochasticity parameter (0 = deterministic, 1 = DDPM-like)

    Returns:
        Generated volumes and a list of intermediate steps (CPU tensors)
    """
    B = shape[0]

    # Subsample original diffusion steps for DDIM
    step_indices = torch.linspace(0, self.num_diffusion_steps - 1, n_steps,
                                 dtype=torch.long, device=self.device)
    alphas_cumprod_sub = self.alphas_cumprod[step_indices]

    # Start from pure noise
    x = torch.randn(shape, device=self.device)

    intermediates = []

    # Snapshot roughly five times per run; at least 1 so that fewer than
    # five steps does not cause a modulo by zero.
    save_every = max(n_steps // 5, 1)

    # DDIM sampling loop
    progress_bar = tqdm(reversed(range(0, n_steps)), desc='DDIM Sampling', total=n_steps)

    with torch.no_grad():
        for i in progress_bar:
            # Current timestep, as a plain int so torch.full gets a scalar
            t = torch.full((B,), int(step_indices[i]), device=self.device, dtype=torch.long)

            # Current cumulative alpha (0-dim tensor, broadcasts over x)
            alpha_cumprod_t = alphas_cumprod_sub[i]

            # Model prediction with guidance if requested
            if guidance_scale > 1.0 and condition is not None:
                # Predict with and without conditioning
                noise_pred_cond = model(x, t, condition)
                noise_pred_uncond = model(x, t, None)

                # Apply classifier-free guidance
                predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            else:
                predicted_noise = model(x, t, condition)

            # Predict x_0 from x_t and noise
            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * predicted_noise) / \
                torch.sqrt(alpha_cumprod_t)

            if i == 0:
                # For final step, just return the predicted x_0
                x = x_0_pred
            else:
                # Cumulative alpha of the next (less noisy) subsampled step
                alpha_cumprod_prev = alphas_cumprod_sub[i - 1]

                # Noise scale controlled by eta (DDIM Eq. 16)
                sigma_t = eta * torch.sqrt(
                    (1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t) *
                    (1 - alpha_cumprod_t / alpha_cumprod_prev)
                )

                # DDIM Eq. 12: the direction term carries
                # sqrt(1 - alpha_prev - sigma^2), so that adding sigma-scaled
                # noise keeps the total variance at 1 - alpha_prev. With
                # eta = 0 this reduces to sqrt(1 - alpha_prev).
                direction_scale = torch.sqrt(
                    torch.clamp(1 - alpha_cumprod_prev - sigma_t ** 2, min=0.0)
                )

                x = torch.sqrt(alpha_cumprod_prev) * x_0_pred + direction_scale * predicted_noise

                # Fresh noise is only drawn for stochastic sampling
                if eta > 0:
                    x = x + sigma_t * torch.randn_like(x)

            # Save intermediate results, plus the first step
            if i % save_every == 0 or i == n_steps - 1:
                intermediates.append(x.detach().cpu())

    return x, intermediates

# Add method implementations to the class
# The three functions above take `self` as their first parameter; assigning
# them as class attributes makes them callable as ordinary methods on every
# DiffusionModel3D instance.
DiffusionModel3D.p_sample = p_sample
DiffusionModel3D.p_sample_loop = p_sample_loop
DiffusionModel3D.ddim_sample = ddim_sample

In [None]:
class TimeEmbedding(nn.Module):
    """Embed diffusion timesteps: sinusoidal features followed by a small MLP."""

    def __init__(self, dim):
        """Build the projection layers.

        Args:
            dim: Size of the sinusoidal feature vector; the output has
                ``dim * 4`` features
        """
        super().__init__()
        self.dim = dim

        # MLP applied to the sinusoidal features: dim -> 4*dim -> 4*dim
        self.linear_1 = nn.Linear(dim, dim * 4)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(dim * 4, dim * 4)

    def forward(self, t):
        """Compute the embedding for a batch of timesteps.

        Args:
            t: Timesteps [B] or [B, 1]

        Returns:
            Time embeddings [B, dim * 4]
        """
        # Work on a column vector so it broadcasts against the frequencies
        steps = t.unsqueeze(-1) if t.dim() == 1 else t

        # Geometric frequency ladder from 1 down to 1/10000
        n_freqs = self.dim // 2
        log_base = torch.log(torch.tensor(10000.0, device=steps.device))
        decay = log_base / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=steps.device) * -decay)

        # Sine features first, cosine features second
        angles = steps * freqs
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Up-project through the MLP
        return self.linear_2(self.act(self.linear_1(features)))


class ConditionalEmbedding(nn.Module):
    """Learned embedding for a discrete conditioning label (e.g., brain region).

    A class index is looked up in an embedding table of width ``dim`` and then
    widened to ``dim * 4`` by a Linear -> SiLU -> Linear stack.
    """

    def __init__(self, num_classes, dim):
        """Build the lookup table and the projection layers.

        Args:
            num_classes: Size of the label vocabulary
            dim: Width of the raw class embedding
        """
        super().__init__()
        self.dim = dim

        # Lookup table: one learned vector per class
        self.embedding = nn.Embedding(num_classes, dim)

        # MLP that widens the class vector to dim * 4
        self.linear_1 = nn.Linear(dim, dim * 4)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(dim * 4, dim * 4)

    def forward(self, condition):
        """Embed class indices, or produce a zero vector when unconditional.

        Args:
            condition: Class indices [B], or None for unconditional use

        Returns:
            Embeddings [B, dim * 4]; a single all-zero row [1, dim * 4]
            when condition is None
        """
        if condition is not None:
            hidden = self.linear_1(self.embedding(condition))
            return self.linear_2(self.act(hidden))

        # Unconditional path (e.g. classifier-free guidance): one zero row
        # placed on the same device as the embedding table.
        target_device = self.embedding.weight.device
        return torch.zeros(1, self.dim * 4, device=target_device)


class SinusoidalPositionEmbeddings(nn.Module):
    """Parameter-free sinusoidal embeddings of scalar time values.

    Each value is multiplied by ``dim // 2`` log-spaced frequencies and the
    sines and cosines of the products are concatenated.
    """

    def __init__(self, dim):
        """Initialize sinusoidal embedding module.

        Args:
            dim: Embedding dimension; the output has 2 * (dim // 2) features,
                i.e. dim for even values and dim - 1 for odd values
        """
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Forward pass through sinusoidal embedding.

        Args:
            time: Time values [B]

        Returns:
            Time embeddings [B, 2 * (dim // 2)]
        """
        half_dim = self.dim // 2

        # Frequencies are spaced by log(10000) / (half_dim - 1). Clamp the
        # denominator so dim of 2 or 3 (half_dim == 1) yields a single
        # frequency of 1 instead of raising ZeroDivisionError.
        spacing = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -spacing)

        # Outer product: one row per time value, one column per frequency
        angles = time[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
class ResidualBlock3D(nn.Module):
    """Residual block for 3D U-Net with conditioning and time embedding.

    Layout: GroupNorm -> SiLU -> Conv3d, add the projected time (and optional
    condition) embedding per channel, GroupNorm -> SiLU -> Dropout3d -> Conv3d,
    optional self-attention, then add the (1x1-projected) input.

    Both channel counts must be divisible by 8 because the GroupNorm layers
    use 8 groups.
    """

    def __init__(self, in_channels, out_channels, time_dim, use_attention=False, dropout=0.1, use_checkpoint=True):
        """Initialize residual block.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            time_dim: Time embedding dimension
            use_attention: Whether to use self-attention
            dropout: Dropout probability
            use_checkpoint: Whether to use gradient checkpointing
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # First conv block
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

        # Time and condition projection (one layer shared by both embeddings)
        self.time_proj = nn.Linear(time_dim, out_channels)

        # Second conv block
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act2 = nn.SiLU()
        self.dropout = nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        # Residual connection if channel dimensions don't match
        self.residual_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

        # Optional attention layer
        self.use_attention = use_attention
        if use_attention:
            self.attention = SelfAttention3D(out_channels)

    def _forward_impl(self, x, time_emb, cond_emb=None):
        """Forward implementation without checkpointing.

        Args:
            x: Input feature maps [B, C, D, H, W]
            time_emb: Time embedding [B, time_dim*4]
            cond_emb: Conditional embedding [B, time_dim*4]

        Returns:
            Output feature maps
        """
        # Residual branch
        residual = self.residual_conv(x)

        # Main branch
        h = self.norm1(x)
        h = self.act1(h)
        h = self.conv1(h)

        # Add time embedding. Only the first quarter of the embedding
        # (time_dim of the time_dim*4 features) is fed to time_proj, which
        # is a Linear(time_dim, out_channels).
        time_emb = self.time_proj(time_emb[:, :time_emb.shape[1]//4])
        h = h + time_emb.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        # Add conditional embedding if provided
        if cond_emb is not None:
            # Same slicing and the same projection layer as the time embedding
            cond_emb = self.time_proj(cond_emb[:, :cond_emb.shape[1]//4])
            h = h + cond_emb.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        h = self.norm2(h)
        h = self.act2(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Apply attention if needed
        if self.use_attention:
            h = self.attention(h)

        # Add residual connection
        return h + residual

    def forward(self, x, time_emb, cond_emb=None):
        """Forward pass with optional gradient checkpointing.

        Checkpointing is only active in training mode; in eval mode the
        block always runs `_forward_impl` directly.

        Args:
            x: Input feature maps [B, C, D, H, W]
            time_emb: Time embedding [B, time_dim*4]
            cond_emb: Conditional embedding [B, time_dim*4]

        Returns:
            Output feature maps
        """
        if self.use_checkpoint and self.training:
            # Use gradient checkpointing to save memory
            from torch.utils.checkpoint import checkpoint

            # use_reentrant is passed explicitly: relying on the implicit
            # default is deprecated, and the reentrant variant silently
            # produces no parameter gradients when none of the tensor inputs
            # requires grad. The non-reentrant variant has no such limit.
            return checkpoint(
                self._forward_impl,
                x, time_emb, cond_emb,
                use_reentrant=False
            )

        return self._forward_impl(x, time_emb, cond_emb)


class SelfAttention3D(nn.Module):
    """Multi-head self-attention over all voxels of a 3D feature map."""

    def __init__(self, channels, attention_heads=4):
        """Set up normalization, the fused q/k/v projection and the output projection.

        Args:
            channels: Number of input (and output) channels
            attention_heads: Number of heads the channels are split into
        """
        super().__init__()
        self.channels = channels
        self.attention_heads = attention_heads

        # GroupNorm followed by 1x1x1 convolutions acting as per-voxel linear maps
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv3d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv3d(channels, channels, kernel_size=1)

        # Per-head width and the 1 / sqrt(head_dim) score scaling
        self.head_dim = channels // attention_heads
        self.scale = (self.head_dim) ** -0.5

    def forward(self, x):
        """Attend across every spatial position and add the result to the input.

        Args:
            x: Input feature maps [B, C, D, H, W]

        Returns:
            Feature maps of the same shape as x
        """
        batch, channels, depth, height, width = x.shape
        heads, head_dim = self.attention_heads, self.head_dim

        # One convolution yields q, k and v stacked along the channel axis
        query, key, value = self.qkv(self.norm(x)).chunk(3, dim=1)

        # Flatten D*H*W into a token axis: q and v become [B, heads, N, head_dim],
        # k stays [B, heads, head_dim, N] so that q @ k gives token-by-token scores
        query = query.reshape(batch, heads, head_dim, -1).transpose(-2, -1)
        key = key.reshape(batch, heads, head_dim, -1)
        value = value.reshape(batch, heads, head_dim, -1).transpose(-2, -1)

        # Scaled dot-product scores, normalized over the key axis
        weights = F.softmax(torch.matmul(query, key) * self.scale, dim=-1)  # [B, heads, DHW, DHW]

        # Weighted sum of values, folded back into a volume
        mixed = torch.matmul(weights, value).transpose(-2, -1)
        mixed = mixed.reshape(batch, channels, depth, height, width)

        # Output projection plus skip connection
        return self.proj(mixed) + x


class DownsampleBlock3D(nn.Module):
    """Encoder stage of the 3D U-Net: two residual blocks, then a stride-2 conv."""

    def __init__(self, in_channels, out_channels, time_dim, use_attention=False, dropout=0.1, use_checkpoint=True):
        """Build the two residual blocks and the strided convolution.

        Args:
            in_channels: Channels of the incoming feature maps
            out_channels: Channels produced by this stage
            time_dim: Time embedding dimension
            use_attention: Whether the residual blocks use self-attention
            dropout: Dropout probability
            use_checkpoint: Whether to use gradient checkpointing
        """
        super().__init__()

        # Only the first residual block changes the channel count
        self.res1 = ResidualBlock3D(
            in_channels, out_channels, time_dim,
            use_attention=use_attention, dropout=dropout, use_checkpoint=use_checkpoint,
        )
        self.res2 = ResidualBlock3D(
            out_channels, out_channels, time_dim,
            use_attention=use_attention, dropout=dropout, use_checkpoint=use_checkpoint,
        )

        # Stride-2 convolution reduces each spatial dimension
        self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, time_emb, cond_emb=None):
        """Run the stage and return both resolutions.

        Args:
            x: Input feature maps
            time_emb: Time embedding
            cond_emb: Conditional embedding

        Returns:
            Tuple of (downsampled features, full-resolution features kept
            as the skip connection)
        """
        features = self.res1(x, time_emb, cond_emb)
        features = self.res2(features, time_emb, cond_emb)
        reduced = self.downsample(features)
        return reduced, features


class UpsampleBlock3D(nn.Module):
    """Upsample block for 3D U-Net.

    The input is upsampled by a transposed convolution (in_channels ->
    out_channels), concatenated with the encoder skip connection and refined
    by two residual blocks.
    """

    def __init__(self, in_channels, out_channels, time_dim, use_attention=False, dropout=0.1, use_checkpoint=True):
        """Initialize upsample block.

        Args:
            in_channels: Number of channels of the incoming feature maps;
                the skip connection must carry the same number of channels
            out_channels: Number of output channels
            time_dim: Time embedding dimension
            use_attention: Whether to use self-attention
            dropout: Dropout probability
            use_checkpoint: Whether to use gradient checkpointing
        """
        super().__init__()

        # Two residual blocks; the first one consumes the concatenation of
        # the upsampled input (out_channels) and the skip (in_channels)
        self.res1 = ResidualBlock3D(in_channels + out_channels, out_channels, time_dim, use_attention, dropout, use_checkpoint)
        self.res2 = ResidualBlock3D(out_channels, out_channels, time_dim, use_attention, dropout, use_checkpoint)

        # Upsample using transposed convolution. It runs first in forward(),
        # on a tensor that still has in_channels channels, so it must map
        # in_channels -> out_channels; with (out_channels, out_channels) the
        # block fails whenever the two counts differ.
        self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x, skip, time_emb, cond_emb=None):
        """Forward pass through upsample block.

        Args:
            x: Input feature maps [B, in_channels, D, H, W]
            skip: Skip connection from encoder [B, in_channels, D', H', W']
            time_emb: Time embedding
            cond_emb: Conditional embedding

        Returns:
            Upsampled feature maps [B, out_channels, D', H', W']
        """
        x = self.upsample(x)

        # Resample to the skip connection's spatial size if the doubled
        # resolution does not match it (e.g. odd-sized encoder features)
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)

        # Concatenate with skip connection
        x = torch.cat([x, skip], dim=1)

        x = self.res1(x, time_emb, cond_emb)
        x = self.res2(x, time_emb, cond_emb)
        return x

In [None]:
class UNet3D(nn.Module):
    """3D U-Net model for denoising diffusion process.

    The network is an initial convolution, one DownsampleBlock3D per entry of
    ``dim_mults``, a two-block bottleneck with attention, one UpsampleBlock3D
    per encoder stage (consuming the skip connections in reverse order) and a
    final GroupNorm -> SiLU -> Conv3d projection.
    """

    def __init__(
        self,
        config=None,
        in_channels=1,
        out_channels=1,
        time_dim=256,
        num_classes=None,
        base_dim=32,
        dim_mults=(1, 2, 4, 8),
        attention_resolutions=(8,),
        dropout=0.1,
        use_checkpoint=True
    ):
        """Initialize 3D U-Net model.

        Args:
            config: Optional configuration object; when given, its
                base_channels, channel_mults, dropout, use_checkpointing and
                time_dim override the corresponding keyword arguments
            in_channels: Number of input channels
            out_channels: Number of output channels
            time_dim: Time embedding dimension
            num_classes: Number of condition classes
            base_dim: Base channel dimension
            dim_mults: Channel multipliers at each resolution
            attention_resolutions: Downsampling factors (1, 2, 4, ...) at
                which the encoder/decoder blocks use self-attention
            dropout: Dropout probability
            use_checkpoint: Whether to use gradient checkpointing
        """
        super().__init__()

        # Use config if provided
        if config is not None:
            base_dim = config.base_channels
            dim_mults = config.channel_mults
            dropout = config.dropout
            use_checkpoint = config.use_checkpointing
            time_dim = config.time_dim

        # Dimensions at each resolution
        dims = [base_dim * m for m in dim_mults]

        # Initial projection
        self.init_conv = nn.Conv3d(in_channels, base_dim, kernel_size=3, padding=1)

        # Time embedding
        self.time_embedding = TimeEmbedding(time_dim)

        # Optional class conditioning
        self.has_class_conditioning = num_classes is not None and num_classes > 0
        if self.has_class_conditioning:
            self.class_embedding = ConditionalEmbedding(num_classes, time_dim)

        # Encoder part of U-Net (downsampling)
        self.downs = nn.ModuleList([])
        in_dim = base_dim

        # Current downsampling factor relative to the input (1, 2, 4, ...)
        current_res = 1
        resolutions = [current_res]

        for dim in dims:
            # Determine if we use attention at this resolution
            use_attention = current_res in attention_resolutions

            # Add downsample block
            self.downs.append(
                DownsampleBlock3D(
                    in_dim, dim, time_dim,
                    use_attention=use_attention,
                    dropout=dropout,
                    use_checkpoint=use_checkpoint
                )
            )

            in_dim = dim
            current_res *= 2
            resolutions.append(current_res)

        # Middle part of U-Net (bottleneck)
        self.mid = nn.ModuleList([
            ResidualBlock3D(dims[-1], dims[-1], time_dim,
                           use_attention=True, dropout=dropout,
                           use_checkpoint=use_checkpoint),
            ResidualBlock3D(dims[-1], dims[-1], time_dim,
                           use_attention=True, dropout=dropout,
                           use_checkpoint=use_checkpoint)
        ])

        # Decoder part of U-Net (upsampling)
        self.ups = nn.ModuleList([])

        for i, dim in enumerate(reversed(dims)):
            # Mirror the encoder: decoder stage i works at the resolution of
            # encoder stage len(dims) - 1 - i
            use_attention = resolutions[-(i+2)] in attention_resolutions

            # Each stage narrows to the width of the next-shallower encoder
            # stage; the last stage keeps dims[0]
            self.ups.append(
                UpsampleBlock3D(
                    dim, dims[max(0, len(dims)-i-2)], time_dim,
                    use_attention=use_attention,
                    dropout=dropout,
                    use_checkpoint=use_checkpoint
                )
            )

        # Final output projection. The last decoder stage emits dims[0]
        # channels, which equals base_dim only when dim_mults[0] == 1, so the
        # head is sized from dims[0] rather than base_dim.
        self.final_conv = nn.Sequential(
            nn.GroupNorm(8, dims[0]),
            nn.SiLU(),
            nn.Conv3d(dims[0], out_channels, kernel_size=3, padding=1)
        )

        # Initialize weights, then log the parameter count
        self._initialize_weights()
        self._print_model_size()

    def _initialize_weights(self):
        """Initialize weights for better training stability.

        Applies Xavier-uniform weights and zero biases to every Conv3d and
        Linear layer; other layer types keep their default initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def _print_model_size(self):
        """Log the total and trainable parameter counts."""
        n_params = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info(f"UNet3D model has {n_params:,} parameters ({n_trainable:,} trainable)")

    def forward(self, x, t, condition=None):
        """Forward pass through U-Net.

        Args:
            x: Input noisy volumes [B, C, D, H, W]
            t: Noise timesteps [B]
            condition: Optional conditioning information; only used when the
                model was built with num_classes

        Returns:
            Predicted noise
        """
        # Initial feature extraction
        x = self.init_conv(x)

        # Time embedding
        t_emb = self.time_embedding(t)

        # Class embedding (if using)
        c_emb = None
        if self.has_class_conditioning:
            c_emb = self.class_embedding(condition)

        # Store skip connections
        skips = []

        # Encoder/Downsampling path
        for down in self.downs:
            x, skip = down(x, t_emb, c_emb)
            skips.append(skip)

        # Middle/Bottleneck
        for mid_block in self.mid:
            x = mid_block(x, t_emb, c_emb)

        # Decoder/Upsampling path; skips are consumed deepest-first
        for up in self.ups:
            skip = skips.pop()
            x = up(x, skip, t_emb, c_emb)

        # Final output
        return self.final_conv(x)

In [None]:
class MemoryEfficientWrapper:
    """Memory optimization utilities for 3D diffusion models."""

    @staticmethod
    def enable_gradient_checkpointing(model):
        """Enable gradient checkpointing for a model to save memory.

        Sets ``use_checkpoint = True`` on every submodule that exposes that
        attribute and logs the current CUDA memory usage.

        Args:
            model: PyTorch model
        """
        # Set all residual blocks to use checkpointing
        for module in model.modules():
            if hasattr(module, 'use_checkpoint'):
                module.use_checkpoint = True

        # Report memory usage
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            logger.info(f"CUDA memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
            logger.info(f"CUDA memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

    @staticmethod
    def setup_mixed_precision():
        """Set up mixed precision training using PyTorch AMP.

        Returns:
            GradScaler for use in training loop, or None when CUDA is not
            available
        """
        if torch.cuda.is_available():
            return torch.cuda.amp.GradScaler()

        logger.warning("CUDA not available, mixed precision not supported")
        return None

    @staticmethod
    def _axis_windows(size, patch, overlap):
        """Return the (start, end) windows that tile one axis of length `size`."""
        # Same patch count as a floor division of the usable length...
        n_patches = max(1, (size - overlap) // (patch - overlap))
        # ...but the stride is rounded UP, so n_patches windows always reach
        # the end of the axis (a floor here leaves trailing voxels uncovered).
        step = -(-(size - overlap) // n_patches)
        extent = step + overlap

        windows = []
        for idx in range(n_patches):
            end = min(idx * step + extent, size)
            # Keep every window `extent` long by shifting the last one back
            start = max(0, end - extent)
            windows.append((start, end))
        return windows

    @staticmethod
    def patch_based_forward(model, x, patch_size=(32, 32, 32), overlap=4):
        """Process a large volume using patch-based approach with overlap.

        Each axis is divided evenly into windows that share about `overlap`
        voxels with their neighbours, so the processed window can be
        somewhat larger than `patch_size`. Overlapping outputs are averaged
        with uniform weights.

        Args:
            model: Callable taking a single tensor [B, C, d, h, w] and
                returning a tensor of the same shape
            x: Input volume [B, C, D, H, W]
            patch_size: Size of patches to process
            overlap: Overlap between patches

        Returns:
            Processed volume

        Raises:
            ValueError: If tiling is needed and a patch dimension is not
                larger than `overlap`
        """
        # Compare per axis: comparing the shape with the tuple directly is
        # lexicographic, so e.g. (16, 64, 64) would count as "smaller" than
        # (32, 32, 32).
        if all(dim <= patch for dim, patch in zip(x.shape[2:], patch_size)):
            # If volume is smaller than patch size, process directly
            return model(x)

        if any(patch <= overlap for patch in patch_size):
            raise ValueError(
                f"patch_size {tuple(patch_size)} must be larger than overlap {overlap} in every dimension"
            )

        # Get volume dimensions
        B, C, D, H, W = x.shape
        pD, pH, pW = patch_size

        # Window boundaries along each axis
        d_windows = MemoryEfficientWrapper._axis_windows(D, pD, overlap)
        h_windows = MemoryEfficientWrapper._axis_windows(H, pH, overlap)
        w_windows = MemoryEfficientWrapper._axis_windows(W, pW, overlap)

        # Running sum of outputs and number of contributions per voxel
        output = torch.zeros_like(x)
        count = torch.zeros_like(x)

        # Process each patch
        for d_start, d_end in d_windows:
            for h_start, h_end in h_windows:
                for w_start, w_end in w_windows:
                    # Extract patch
                    patch = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]

                    # Process patch
                    patch_output = model(patch)

                    # Uniform blending weights: overlaps become a plain mean
                    weight = torch.ones_like(patch)

                    # Update output and count
                    output[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += patch_output * weight
                    count[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += weight

        # Average overlapping regions
        output = output / count.clamp(min=1.0)

        return output

    @staticmethod
    def monitor_memory_usage():
        """Log current CUDA memory usage statistics.

        Also resets the CUDA peak-memory counters.

        Returns:
            Allocated and reserved memory in GB; (0, 0) when CUDA is not
            available
        """
        if not torch.cuda.is_available():
            logger.info("CUDA not available, skipping memory monitoring")
            return 0, 0

        # Reset peak stats
        torch.cuda.reset_peak_memory_stats()

        # Current memory usage
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9

        logger.info(f"CUDA memory allocated: {allocated:.2f} GB")
        logger.info(f"CUDA memory reserved: {reserved:.2f} GB")

        # Return values in case needed
        return allocated, reserved

In [None]:
def train_diffusion_model(
    diffusion,
    model,
    train_loader,
    val_loader,
    config=None,
    save_dir='checkpoints',
    resume_from=None
):
    """Train the 3D diffusion model.

    Each epoch runs one pass over `train_loader` (with gradient clipping to
    norm 1.0 and an EMA update after every optimizer step), one pass over
    `val_loader`, steps the cosine learning-rate schedule and writes
    checkpoints to `save_dir`: `latest_checkpoint.pt` every epoch,
    `best_model.pt` whenever the validation loss improves, and a numbered
    checkpoint plus sample images every 10 epochs and after the last epoch.

    Args:
        diffusion: DiffusionModel3D instance
        model: UNet3D model
        train_loader: Training data loader; batches are dicts with a
            'volume' entry and an optional 'region' entry
        val_loader: Validation data loader with the same batch layout
        config: Configuration object
        save_dir: Directory to save checkpoints
        resume_from: Optional path to resume training from checkpoint

    Returns:
        Trained model and EMA model
    """
    # Use default config if not provided
    if config is None:
        config = BrainDiffusionConfig()

    # Extract training parameters from config
    num_epochs = config.epochs
    learning_rate = config.learning_rate
    weight_decay = config.weight_decay
    mixed_precision = config.mixed_precision
    ema_decay = config.ema_decay

    # Create optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate/10
    )

    # Set up mixed precision training if requested
    scaler = MemoryEfficientWrapper.setup_mixed_precision() if mixed_precision else None
    use_autocast = mixed_precision and torch.cuda.is_available()
    use_scaler = mixed_precision and scaler is not None

    # Enable gradient checkpointing to save memory
    if config.use_checkpointing:
        MemoryEfficientWrapper.enable_gradient_checkpointing(model)

    # EMA model for better generation quality
    ema_model = copy.deepcopy(model).eval()

    # Track best validation loss
    best_val_loss = float('inf')
    start_epoch = 0

    # Resume from checkpoint if specified
    if resume_from:
        if os.path.exists(resume_from):
            logger.info(f"Loading checkpoint from {resume_from}")
            # NOTE: torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(resume_from, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])

            # Without stored EMA weights, restart the average from the
            # loaded model instead of from the random initial copy.
            ema_model.load_state_dict(
                checkpoint.get('ema_model_state_dict', checkpoint['model_state_dict'])
            )

            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            if scaler is not None and 'scaler_state_dict' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler_state_dict'])

            start_epoch = checkpoint['epoch'] + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            logger.info(f"Resuming from epoch {start_epoch}")
        else:
            logger.warning(f"Checkpoint {resume_from} not found, training from scratch")

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0

        # Progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        # Monitor memory before training
        logger.info(f"Memory before epoch {epoch+1}:")
        MemoryEfficientWrapper.monitor_memory_usage()

        for batch in progress_bar:
            optimizer.zero_grad()

            # Get batch data
            x = batch['volume'].to(device)
            condition = batch['region'].to(device) if 'region' in batch else None

            # Forward pass and loss calculation (autocast only on CUDA)
            with (torch.cuda.amp.autocast() if use_autocast else nullcontext()):
                loss = diffusion.compute_loss(model, x, condition)

            if use_scaler:
                # Scale the loss, then unscale the gradients before clipping
                # so the clip threshold applies to the true gradient norm
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                # Regular backprop and optimization
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            # Update EMA model: ema = decay * ema + (1 - decay) * param
            with torch.no_grad():
                for param, ema_param in zip(model.parameters(), ema_model.parameters()):
                    ema_param.mul_(ema_decay).add_(param, alpha=1 - ema_decay)

            # Track loss
            batch_loss = loss.item()
            train_loss += batch_loss
            progress_bar.set_postfix({"train_loss": batch_loss})

        # Calculate average training loss (guarded like the validation
        # average so an empty loader does not divide by zero)
        train_loss /= max(1, len(train_loader))

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                x = batch['volume'].to(device)
                condition = batch['region'].to(device) if 'region' in batch else None

                # Use mixed precision if enabled
                with (torch.cuda.amp.autocast() if use_autocast else nullcontext()):
                    loss = diffusion.compute_loss(model, x, condition)

                val_loss += loss.item()

        val_loss /= max(1, len(val_loader))

        # Update learning rate
        scheduler.step()

        # Log results
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

        # Update the running best BEFORE building the checkpoint so the
        # stored 'best_val_loss' already includes this epoch; otherwise a
        # resumed run starts from a stale value.
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'ema_model_state_dict': ema_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'best_val_loss': best_val_loss
        }

        # Add scaler state if using mixed precision
        if scaler is not None:
            checkpoint['scaler_state_dict'] = scaler.state_dict()

        # Save latest checkpoint
        torch.save(checkpoint, os.path.join(save_dir, 'latest_checkpoint.pt'))

        # Save best model
        if is_best:
            torch.save(checkpoint, os.path.join(save_dir, 'best_model.pt'))
            logger.info(f"Saved best model with val loss: {val_loss:.6f}")

        # Every 10 epochs and at the end: numbered checkpoint plus samples
        # generated with the EMA weights
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            torch.save(checkpoint, os.path.join(save_dir, f'epoch_{epoch+1}_checkpoint.pt'))
            generate_and_visualize_samples(diffusion, ema_model, val_loader, save_dir, epoch)

        # Release cached GPU memory between epochs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logger.info("Training completed!")
    return model, ema_model


def generate_and_visualize_samples(diffusion, model, val_loader, save_dir, epoch, num_samples=2):
    """Generate and save sample volumes during training.

    Takes the first validation batch, generates volumes of the same shape
    (and with the same conditions, if present) via DDIM sampling, and writes
    one PNG per sample to `<save_dir>/samples` showing the middle axial,
    coronal and sagittal slices of the real and the generated volume.

    Args:
        diffusion: DiffusionModel3D instance
        model: UNet3D model
        val_loader: Validation data loader
        save_dir: Directory to save samples
        epoch: Current epoch (zero-based; file names use epoch + 1)
        num_samples: Maximum number of samples to generate; fewer are
            produced when the validation batch is smaller
    """
    model.eval()

    # Create samples directory
    samples_dir = os.path.join(save_dir, 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    # Get a batch of validation data
    batch = next(iter(val_loader))
    x = batch['volume'].to(device)[:num_samples]
    condition = batch['region'].to(device)[:num_samples] if 'region' in batch else None

    # Generate samples using DDIM for efficiency
    logger.info("Generating samples...")
    with torch.no_grad():
        # Generate with the same conditions
        generated, _ = diffusion.ddim_sample(
            model,
            x.shape,
            condition=condition,
            n_steps=50,
            guidance_scale=2.0
        )

    # Convert to numpy for visualization
    x_cpu = x.cpu().numpy()
    generated_cpu = generated.cpu().numpy()

    # The batch may hold fewer volumes than requested; indexing up to
    # num_samples would run past the end of the arrays.
    n_plots = min(num_samples, len(x_cpu), len(generated_cpu))

    # Visualize middle slices from each volume
    for i in range(n_plots):
        real_vol = x_cpu[i, 0]
        gen_vol = generated_cpu[i, 0]

        # Get middle slices in each dimension
        d_mid, h_mid, w_mid = [s // 2 for s in real_vol.shape]

        # Index expression selecting the central slice of each plane
        planes = (
            ("Axial", (d_mid, slice(None), slice(None))),
            ("Coronal", (slice(None), h_mid, slice(None))),
            ("Sagittal", (slice(None), slice(None), w_mid)),
        )

        # Top row: real volume, bottom row: generated volume
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        for row, (label, volume) in enumerate((("Real", real_vol), ("Generated", gen_vol))):
            for col, (plane, index) in enumerate(planes):
                ax = axes[row, col]
                ax.imshow(volume[index], cmap='gray')
                ax.set_title(f"{label} - {plane}")
                ax.axis('off')

        # Save figure
        region_label = condition[i].item() if condition is not None else "unknown"
        fig.suptitle(f"Sample {i+1}, Region: {region_label}")
        fig.tight_layout()
        fig.savefig(os.path.join(samples_dir, f'sample_{i+1}_epoch_{epoch+1}.png'))
        # Close this figure explicitly so figures do not accumulate
        plt.close(fig)

    logger.info(f"Samples saved to {samples_dir}")

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 63)