# Overview

This notebook presents a denoising diffusion probabilistic model (DDPM) for medical image synthesis. Its noise-prediction network is a UNet combined with a Swin Transformer. The Swin Transformer, a hierarchical Vision Transformer that computes self-attention within (shifted) local windows, sits in the UNet bottleneck to help the model capture both local and global features. This hybrid design is particularly useful for the complex structures and fine details found in medical images.

This notebook guides you through:

- Loading the MRNet dataset
- Implementation of the Swin Transformer-enhanced UNet
- Training and evaluation of the DDPM model

Overall, it demonstrates the potential of transformers in advancing medical image synthesis.

##### 1. Importing Libraries

In [15]:
# Import necessary libraries for building and training the model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, models
import numpy as np
import math


##### 2. Applying Transforms and Initializing Dataset Loaders

In [16]:
# Define transformations to preprocess the MRI images
# Preprocessing pipeline for the MRI slices, applied in order:
#   1. collapse to a single (grayscale) channel,
#   2. resize to 64x64 (the UNet pools three times down to an 8x8 Swin bottleneck),
#   3. convert the PIL image to a PyTorch tensor.
_transform_steps = [
    transforms.Grayscale(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_transform_steps)


class MRNetSliceDataset(Dataset):
    """Dataset of 2-D MRNet slice PNGs paired with per-exam labels read from CSV files.

    Args:
        slice_dir: directory containing the ``.png`` slice files. The exam ID of a slice is the
            second ``_``-separated field of its file name.
        label_files: CSV files with headerless ``id,label`` rows. IDs are zero-padded to 4 digits
            so they match the IDs parsed from file names.
        transform: optional callable applied to the loaded PIL image.

    Each item is a dict ``{'data': image, 'label': FloatTensor([label]), 'id': slice_id}``.
    Slices whose ID has no label get label 0 and a printed warning.

    NOTE(review): all label files are merged into one dict keyed by exam ID, so when an ID appears
    in several files (e.g. acl / abnormal / meniscus) the file listed last wins.
    """

    def __init__(self, slice_dir, label_files, transform=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform = transform

        # Zero-padded exam ID -> label (later files overwrite earlier ones).
        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: str(i).zfill(4))
            self.labels_dict.update(zip(records['id'], records['label']))

        # Sorted so item order is reproducible; os.listdir returns entries in arbitrary order.
        self.slice_files = sorted(
            os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')
        )

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        # Close the file handle promptly; copy() loads the pixels into an independent image first.
        with Image.open(slice_path) as img:
            image = img.copy()

        if self.transform:
            image = self.transform(image)

        # Extract ID from the filename to find the corresponding label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = torch.FloatTensor([self.labels_dict[slice_id]])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  # Default label or handle missing labels as needed

        return {'data': image, 'label': label, 'id': slice_id}

# Initialize training and validation datasets and data loaders
# Locations of the pre-extracted MRNet slices and their per-task label CSVs
root_dir = "Raw_Images"
train_slice_dir = os.path.join(root_dir, "train_slices_raw")
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

# One CSV per MRNet task, e.g. Raw_Images/train-acl.csv
_label_tasks = ("acl", "abnormal", "meniscus")
train_label_files = [os.path.join(root_dir, f"train-{task}.csv") for task in _label_tasks]
valid_label_files = [os.path.join(root_dir, f"valid-{task}.csv") for task in _label_tasks]

train_dataset = MRNetSliceDataset(slice_dir=train_slice_dir, label_files=train_label_files, transform=transform)
valid_dataset = MRNetSliceDataset(slice_dir=valid_slice_dir, label_files=valid_label_files, transform=transform)

BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Pick the GPU when present; later cells read the module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using GPU: {torch.cuda.get_device_name(0)}" if device.type == 'cuda' else "Using CPU")


Using GPU: NVIDIA A100-SXM4-40GB


##### 3. Implementing Swin Transformer and Helper Functions

In [4]:
"""
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>
https://github.com/microsoft/Swin-Transformer
"""

# DropPath (Stochastic Depth) module to implement drop path regularization
class DropPath(nn.Module):
    """Stochastic depth: randomly drop a residual branch for whole samples during training.

    Each sample in the batch is kept with probability ``1 - drop_prob``. Kept samples are rescaled
    by ``1 / (1 - drop_prob)``, so the expected output equals the input. In eval mode, or when
    ``drop_prob`` is ``None``/``0``, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's branch.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `not self.drop_prob` covers both 0.0 and the default None (None used to crash with 1 - None).
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0:
            # Every sample is dropped; avoid dividing by zero.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over all remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor

# Helper functions to handle tuple and truncation
def to_2tuple(x):
    """Return ``x`` unchanged when it is already a tuple or list, otherwise the pair ``(x, x)``."""
    return x if isinstance(x, (tuple, list)) else (x, x)

def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place with approximately truncated-normal values and return it.

    For every element, four standard-normal candidates are drawn. The first one strictly inside
    (-2, 2) is kept; if none qualifies, index 0 is used. The result is then scaled by ``std`` and
    shifted by ``mean``, so values land (almost always) within ``mean +/- 2*std``.
    """
    with torch.no_grad():
        candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
        inside = (candidates > -2) & (candidates < 2)
        # max over booleans returns the index of the first True (or 0 when all are False).
        first_inside = inside.max(dim=-1, keepdim=True).indices
        tensor.data.copy_(candidates.gather(-1, first_inside).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
        return tensor

# MLP module used within Swin Transformer
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer (defaults to ``in_features``).
        out_features: output size (defaults to ``in_features``).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        # One Dropout module is reused at both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

# Functions to partition and reverse windows in the Swin Transformer
def window_partition(x, window_size):
    """Split a channels-first feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, C, H, W); H and W are expected to be multiples of ``window_size``.
        window_size: side length of each window.

    Returns:
        Tensor of shape (B * (H // window_size) * (W // window_size), window_size, window_size, C).
        Windows are ordered by batch, then window row, then window column.
    """
    batch, channels, height, width = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, channels, rows, window_size, cols, window_size)
    # -> (B, rows, cols, window_h, window_w, C)
    grid = grid.permute(0, 2, 4, 3, 5, 1).contiguous()
    return grid.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a channels-first feature map (inverse of ``window_partition``).

    Args:
        windows: tensor of shape (num_windows * B, window_size, window_size, C).
        window_size: side length of each window.
        H, W: spatial size of the reassembled map.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    num_images = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(num_images, H // window_size, W // window_size, window_size, window_size, -1)
    # -> (B, C, rows, window_h, cols, window_w), which flattens into (B, C, H, W)
    grid = grid.permute(0, 5, 1, 3, 2, 4).contiguous()
    return grid.view(num_images, -1, H, W)

# Window-based multi-head self-attention (W-MSA) module
class WindowAttention(nn.Module):
    """Multi-head self-attention inside fixed-size windows, plus a learned relative-position bias.

    Args:
        dim: channel count C of the input tokens.
        window_size: (Wh, Ww) pair (indexed as ``window_size[0]`` / ``window_size[1]``). Each
            input sequence is expected to hold Wh * Ww tokens so the bias broadcasts.
        num_heads: number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: optional override for the query scaling (default ``head_dim ** -0.5``).
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Note: a falsy qk_scale (None or 0) falls back to the standard 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table for all windows: one learnable value per
        # (relative offset, head); offsets span (2*Wh - 1) x (2*Ww - 1) possibilities.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Get relative position index for each window: for every (query, key) token pair in a
        # window, precompute the flat index of their relative offset into the bias table.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        # Pairwise (dy, dx) offsets, shape (2, N, N) with N = Wh * Ww.
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets so they start at 0 ...
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # ... then flatten (dy, dx) row-major into a single table index.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        # Buffer: moves with the module across devices and is saved in the state_dict.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Apply windowed self-attention.

        Args:
            x: tokens of shape (B_, N, C). B_ is presumably (num_windows * batch), and N must
                equal Wh * Ww for the relative-position bias to broadcast.
            mask: optional additive mask of shape (nW, N, N), added to the attention logits
                before the softmax; B_ must be divisible by nW.

        Returns:
            Tensor of shape (B_, N, C).
        """
        B_, N, C = x.shape
        # (B_, N, 3C) -> (3, B_, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Gather the (N, N, heads) bias and add it per head: -> (1, heads, N, N)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        attn = attn + relative_position_bias.unsqueeze(0).to(attn.dtype)

        if mask is not None:
            # Split B_ into (batch, nW) so each window gets its own (N, N) mask.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        # Merge heads back: (B_, heads, N, head_dim) -> (B_, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Swin Transformer block implementing the shifted window-based attention mechanism
class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted-)window attention and an MLP, each with a pre-norm residual.

    Operates on row-major token sequences of shape (B, H*W, C) for a fixed
    ``input_resolution = (H, W)``. Internally the tokens are laid out channels-first as
    (B, C, H, W), which is the convention ``window_partition`` / ``window_reverse`` use, before
    being cut into ``window_size x window_size`` windows.

    Args:
        dim: channel count C.
        input_resolution: (H, W) of the token grid; H and W should be multiples of window_size.
        num_heads: number of attention heads.
        window_size: window side length. If min(H, W) <= window_size, it is clamped to min(H, W)
            and shifting is disabled.
        shift_size: cyclic shift applied before attention (0 = plain window attention).
        mlp_ratio: hidden width of the MLP relative to ``dim``.
        qkv_bias, qk_scale, drop, attn_drop: forwarded to the attention / MLP layers.
        drop_path: stochastic-depth probability for both residual branches.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization class applied before attention and before the MLP.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # The grid fits in one window: use a single window and no shift.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * self.mlp_ratio)
        # act_layer was previously accepted but never forwarded to the MLP.
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim,
                       act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        """Build the (nW, ws*ws, ws*ws) additive mask for shifted-window attention.

        Token pairs in the same window that came from different regions after the cyclic shift
        get -100 (effectively blocked by the softmax); all other pairs get 0.
        """
        H, W = x_size
        # Channels-first (1, 1, H, W), matching window_partition's (B, C, H, W) input layout.
        img_mask = torch.zeros((1, 1, H, W))
        region_slices = (slice(0, -self.window_size),
                         slice(-self.window_size, -self.shift_size),
                         slice(-self.shift_size, None))
        cnt = 0
        for h in region_slices:
            for w in region_slices:
                img_mask[:, :, h, w] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        """Apply the block to row-major tokens of shape (B, H*W, C); returns the same shape."""
        B, L, C = x.shape
        H, W = self.input_resolution
        if L != H * W:
            raise ValueError(f"expected {H * W} tokens for resolution {(H, W)}, got {L}")

        shortcut = x
        x = self.norm1(x)
        # (B, H*W, C) -> (B, C, H, W): the channels-first layout window_partition expects.
        x = x.transpose(1, 2).reshape(B, C, H, W)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        if self.attn_mask is not None:
            attn_windows = self.attn(x_windows, mask=self.attn_mask.to(x.dtype))
        else:
            attn_windows = self.attn(x_windows)

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_windows, self.window_size, H, W)  # (B, C, H, W)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))

        # (B, C, H, W) -> row-major tokens (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


##### 4. Implementing the Attention UNet Model

In [6]:
# Self-attention block for feature refinement
class SelfAttentionBlock(nn.Module):
    """Spatial self-attention across all H*W positions of a feature map, added residually.

    Query and key projections reduce the channels to ``in_channels // 8``. The attended values
    are scaled by a learnable ``gamma``, which starts at 0 so the block is initially an identity,
    and added to the input. The attention matrix is (H*W) x (H*W) per sample.
    """

    def __init__(self, in_channels):
        super(SelfAttentionBlock, self).__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, d2, d3 = x.size()
        positions = d2 * d3
        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)  # (N, HW, C//8)
        keys = self.key_conv(x).view(n, -1, positions)                       # (N, C//8, HW)
        attention = F.softmax(torch.bmm(queries, keys), dim=-1)              # (N, HW, HW)
        values = self.value_conv(x).view(n, -1, positions)                   # (N, C, HW)
        attended = torch.bmm(values, attention.transpose(1, 2)).view(n, c, d2, d3)
        return self.gamma * attended + x

# Cross-attention block for the decoder to focus on relevant features
class CrossAttentionBlock(nn.Module):
    """Additive attention gate for a UNet skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to ``F_int`` channels
    (1x1 conv + BatchNorm), summed and passed through ReLU. The result is reduced to a
    single-channel sigmoid map ``psi``, which rescales ``x`` pixel-wise.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: intermediate channel count.
    """

    def __init__(self, F_g, F_l, F_int):
        super(CrossAttentionBlock, self).__init__()

        def conv_bn(in_ch, out_ch):
            # 1x1 convolution followed by batch norm (Sequential indices 0 and 1).
            return [nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm2d(out_ch)]

        self.W_g = nn.Sequential(*conv_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)

# Sinusoidal positional embedding for timestep encoding in DDPM
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal encoding of diffusion timesteps.

    With ``half_dim = embedding_dim // 2`` and frequencies
    ``f_i = max_len ** (-i / (half_dim - 1))`` for ``i in [0, half_dim)``, each timestep ``t`` maps
    to ``[sin(t * f), cos(t * f)]``. One zero column is appended when ``embedding_dim`` is odd.

    Args:
        embedding_dim: output width.
        max_len: base controlling the lowest frequency.
    """

    def __init__(self, embedding_dim, max_len=10000):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_len = max_len

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps (integer or float) into a (B, embedding_dim) float tensor."""
        half_dim = self.embedding_dim // 2
        # max(..., 1) guards half_dim == 1 (embedding_dim 2 or 3), which used to divide by zero.
        exponent = math.log(self.max_len) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -exponent)
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

# UNet with self-attention, cross-attention mechanisms and SWIN 
class AttentionUNet(nn.Module):
    """Timestep-conditioned UNet used as the DDPM noise predictor.

    The sinusoidal timestep embedding is broadcast over the image and concatenated to the input
    as ``emb_dim`` extra channels. The network then runs:
    - three encoder stages, each a conv block, spatial self-attention and a 2x max-pool;
    - a 512-channel conv bottleneck refined by a Swin Transformer block;
    - three decoder stages, each of which upsamples, gates the matching skip connection with a
      ``CrossAttentionBlock`` and fuses the two with a conv block.

    Args:
        in_channels: channels of the noisy input image.
        out_channels: channels of the prediction.
        emb_dim: width of the timestep embedding.
        img_size: input height/width (must be divisible by 8). The Swin bottleneck is built for an
            (img_size // 8) x (img_size // 8) token grid; 64 matches the 64x64 preprocessing.
    """

    def __init__(self, in_channels, out_channels, emb_dim=128, img_size=64):
        super(AttentionUNet, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.self_attention1 = SelfAttentionBlock(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.self_attention2 = SelfAttentionBlock(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.self_attention3 = SelfAttentionBlock(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        # Three 2x poolings shrink img_size by 8 before the bottleneck.
        bottleneck_res = img_size // 8
        self.swin_block = SwinTransformerBlock(dim=512, input_resolution=(bottleneck_res, bottleneck_res),
                                               num_heads=8, window_size=4)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.cross_attention3 = CrossAttentionBlock(256, 256, 128)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.cross_attention2 = CrossAttentionBlock(128, 128, 64)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.cross_attention1 = CrossAttentionBlock(64, 64, 32)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        """Predict the output for noisy images ``x`` at timesteps ``t``.

        Args:
            x: (B, in_channels, img_size, img_size) input.
            t: timesteps of shape (B,), or a single timestep (shape (1,) or scalar) shared by the batch.

        Returns:
            Tensor of shape (B, out_channels, img_size, img_size).
        """
        t = t.reshape(-1)
        t_embed = self.timestep_embedding_layer(t)
        # (T, emb_dim) -> (B, emb_dim, H, W); expand also broadcasts a single timestep over the batch.
        t_embed = t_embed.view(t.size(0), -1, 1, 1).expand(x.size(0), -1, x.size(2), x.size(3))
        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1 = self.self_attention1(enc1)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2 = self.self_attention2(enc2)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3 = self.self_attention3(enc3)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck: the Swin block expects row-major per-pixel tokens (B, H*W, C).
        # A plain .view(B, H*W, C) on an NCHW tensor would mix channels and pixels.
        bottleneck = self.bottleneck(enc3_pooled)
        B, C, H, W = bottleneck.shape
        tokens = bottleneck.flatten(2).transpose(1, 2)
        tokens = self.swin_block(tokens)
        bottleneck = tokens.transpose(1, 2).reshape(B, C, H, W)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        enc3 = self.cross_attention3(upconv3, enc3)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        enc2 = self.cross_attention2(upconv2, enc2)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        enc1 = self.cross_attention1(upconv1, enc1)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### 5. Implementing the DDPM Model

In [7]:
"""
DDPM implementation adapted from:
https://github.com/hojonathanho/diffusion/tree/master
"""

class DDPM(nn.Module):
    """Denoising diffusion probabilistic model around a noise-prediction network.

    Uses a linear beta schedule from ``beta_start`` to ``beta_end`` over ``num_timesteps`` steps.
    ``model(z_t, t)`` is trained to predict the Gaussian noise mixed into ``z_0`` by the forward
    process.

    Args:
        model: network called as ``model(z_t, t)``, returning a tensor shaped like ``z_t``.
        num_timesteps: number of diffusion steps T.
        latent_dim: stored as ``self.latent_dim``; not used by this class.
        beta_start, beta_end: endpoints of the linear beta schedule.
    """

    def __init__(self, model, num_timesteps, latent_dim, beta_start=0.00085, beta_end=0.0120):
        super(DDPM, self).__init__()
        self.model = model
        self.num_timesteps = num_timesteps

        # Schedule tensors are buffers so they follow the module across devices.
        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1 - betas, dim=0))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - self.alphas_cumprod))

        self.latent_dim = latent_dim

    @staticmethod
    def _extract(values, t, ndim):
        """Index schedule ``values`` at timesteps ``t`` and reshape to (B, 1, ..., 1) for broadcasting."""
        return values[t].reshape(-1, *([1] * (ndim - 1)))

    def forward(self, z_t, t):
        return self.model(z_t, t)

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` timesteps uniformly from [0, num_timesteps) on the schedule's device."""
        return torch.randint(0, self.num_timesteps, (batch_size,), device=self.betas.device)

    def forward_diffusion(self, z_0, t, noise=None):
        """Sample q(z_t | z_0) = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * noise."""
        if noise is None:
            noise = torch.randn_like(z_0)
        return (self._extract(self.sqrt_alphas_cumprod, t, z_0.ndim) * z_0
                + self._extract(self.sqrt_one_minus_alphas_cumprod, t, z_0.ndim) * noise)

    def p_losses(self, z_0, t, noise=None):
        """Mean-squared error between the injected noise and the model's prediction of it."""
        if noise is None:
            noise = torch.randn_like(z_0)
        z_t = self.forward_diffusion(z_0, t, noise)
        predicted_noise = self.forward(z_t, t)
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def sample(self, shape):
        """Run the full reverse process from Gaussian noise and return samples of ``shape``.

        ``shape[0]`` is the batch size; every sample in the batch is denoised with the same
        timestep at each step. Noise with variance beta_t is added at all steps except the last.
        """
        dev = self.betas.device
        z_t = torch.randn(shape, device=dev)
        for t in reversed(range(self.num_timesteps)):
            # One timestep per sample, so batch sizes > 1 are supported.
            t_batch = torch.full((shape[0],), t, dtype=torch.long, device=dev)
            z_t = self.p_sample(z_t, t_batch)
            if t > 0:
                z_t = z_t + torch.randn_like(z_t) * torch.sqrt(self.betas[t])
        return z_t

    def p_sample(self, z, t):
        """Mean of the reverse step p(z_{t-1} | z_t); no noise is added.

        Computes ``(z - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(z, t)) / sqrt(alpha_t)``.

        Args:
            z: noisy batch z_t of shape (B, ...).
            t: long tensor of timesteps, shape (B,) or a single value shared by the batch.
        """
        predicted_noise = self.forward(z, t)
        beta_t = self._extract(self.betas, t, z.ndim)
        # The noise coefficient uses the *cumulative* alpha_bar_t (previously 1 - alpha_t was used).
        sqrt_one_minus_alpha_bar_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, z.ndim)
        sqrt_alpha_t = torch.sqrt(self._extract(self.alphas, t, z.ndim))
        return (z - beta_t / sqrt_one_minus_alpha_bar_t * predicted_noise) / sqrt_alpha_t


##### 6. Functions for Training and Saving Models

In [None]:
# Save the model checkpoint
def save_model(ddpm, epoch, checkpoint_path):
    """Save ``ddpm``'s state_dict together with ``epoch`` to ``checkpoint_path``.

    The parent directory is created if it does not exist. Otherwise a path such as
    ``Model_Savepoints/ddpm_checkpoint.pth`` would make ``torch.save`` fail mid-training.
    """
    directory = os.path.dirname(checkpoint_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': ddpm.state_dict()
    }, checkpoint_path)
    print(f'Model checkpoint saved at {checkpoint_path}')

# Load the model checkpoint
def load_model(ddpm, checkpoint_path, map_location='cpu'):
    """Load weights from a checkpoint written by ``save_model`` into ``ddpm``.

    Args:
        ddpm: module whose ``load_state_dict`` receives ``checkpoint['model_state_dict']``.
        checkpoint_path: path to the checkpoint file.
        map_location: where tensors are materialised while loading. The default 'cpu' lets
            GPU-saved checkpoints load on CPU-only machines; ``load_state_dict`` then copies
            the values onto whatever device ``ddpm`` already lives on.

    Returns:
        ``(ddpm, epoch)``, where ``epoch`` is the value stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    return ddpm, start_epoch

# Function to generate and save images
def generate_and_save_images(ddpm, epoch, save_dir='generated_images/training', image_size=64):
    """Sample one single-channel image from ``ddpm`` and save it as a PNG.

    The file is ``<save_dir>/generated_image_epoch_{epoch+1}.png``. Samples are clipped to
    [0, 1] before being scaled to 8-bit: the reverse process can overshoot that range, and a raw
    ``astype(np.uint8)`` would wrap out-of-range values around. The model's previous train/eval
    mode is restored afterwards.

    Args:
        ddpm: model exposing ``sample(shape)``.
        epoch: 0-based epoch index used in the file name.
        save_dir: output directory (created if missing).
        image_size: height/width of the sampled image.
    """
    os.makedirs(save_dir, exist_ok=True)
    shape = (1, 1, image_size, image_size)  # one grayscale image

    was_training = ddpm.training
    ddpm.eval()
    try:
        with torch.no_grad():
            samples = ddpm.sample(shape)
    finally:
        ddpm.train(was_training)

    pixels = np.clip(samples.squeeze().cpu().numpy(), 0.0, 1.0)
    pixels = (pixels * 255).astype(np.uint8)
    save_path = os.path.join(save_dir, f'generated_image_epoch_{epoch+1}.png')
    # A 2-D uint8 array is saved as an 8-bit grayscale ('L') image.
    Image.fromarray(pixels).save(save_path)
    print(f'Image saved at {save_path}')


##### 7. Training Routine for the Diffusion Model

In [11]:
# Train the model
def train_diffusion_model(ddpm, train_loader, valid_loader, epochs=10, save_interval=10, checkpoint_path='ddpm_checkpoint.pth'):
    """Train ``ddpm`` on the noise-prediction objective, validating and checkpointing along the way.

    Args:
        ddpm: ``DDPM`` instance; it is moved to the module-level ``device``.
        train_loader, valid_loader: loaders yielding dicts with a ``'data'`` image tensor.
        epochs: total number of epochs (a resumed run continues up to this count).
        save_interval: save a checkpoint and a sample image every ``save_interval`` epochs.
        checkpoint_path: rolling checkpoint file. If it exists, training resumes after the epoch
            stored in it. Every 200 epochs an extra copy ``<name>_<epoch>.pth`` is also written.

    Batches whose global gradient norm is > 1e3 or < 1e-3 are skipped (no optimizer step) and
    excluded from the reported training loss.

    NOTE: only model weights are checkpointed, so AdamW moments and the LR scheduler restart
    from scratch on resume.
    """
    optimizer = optim.AdamW(ddpm.parameters(), lr=1e-4, weight_decay=1e-2)
    # LR reductions are reported manually below; the scheduler's `verbose` flag is deprecated.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15)
    start_epoch = 0

    # Check if a checkpoint exists
    if os.path.exists(checkpoint_path):
        ddpm, saved_epoch = load_model(ddpm, checkpoint_path)
        # save_model below stores the 0-based index of the epoch that just finished, so resume
        # with the next one instead of training that epoch a second time.
        start_epoch = saved_epoch + 1
        print(f"Resuming training from epoch {start_epoch+1}")
    else:
        print(f"No checkpoint found at {checkpoint_path}, starting from scratch.")

    ddpm.to(device)

    for epoch in range(start_epoch, epochs):
        ddpm.train()
        train_loss = 0.0
        stepped_batches = 0
        total_norm = float('nan')
        for batch in train_loader:
            inputs = batch['data'].to(device)
            optimizer.zero_grad()

            noise = torch.randn_like(inputs)
            t = ddpm.sample_timesteps(inputs.size(0))
            loss = ddpm.p_losses(inputs, t, noise)

            loss.backward()

            # Gradient checking: global L2 norm with one host sync instead of one .item() per parameter.
            grad_norms = [p.grad.detach().norm(2) for p in ddpm.parameters() if p.grad is not None]
            total_norm = torch.stack(grad_norms).norm(2).item() if grad_norms else 0.0

            if total_norm > 1e3:  # Threshold for exploding gradients
                print(f"Warning: Exploding gradients detected at epoch {epoch+1}, total norm: {total_norm:.2f}")
                continue
            if total_norm < 1e-3:  # Threshold for vanishing gradients
                print(f"Warning: Vanishing gradients detected at epoch {epoch+1}, total norm: {total_norm:.2e}")
                continue

            optimizer.step()
            train_loss += loss.item()
            stepped_batches += 1

        # Average only over batches that actually took an optimizer step.
        avg_train_loss = train_loss / stepped_batches if stepped_batches else float('nan')
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Gradient Norm: {total_norm:.2f}')

        ddpm.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for batch in valid_loader:
                inputs = batch['data'].to(device)
                noise = torch.randn_like(inputs)
                t = ddpm.sample_timesteps(inputs.size(0))
                valid_loss += ddpm.p_losses(inputs, t, noise).item()

        avg_valid_loss = valid_loss / len(valid_loader) if len(valid_loader) else float('nan')
        print(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_valid_loss:.4f}')

        # Learning Rate Scheduler
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_valid_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f'Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}')

        if (epoch + 1) % save_interval == 0:
            save_model(ddpm, epoch, checkpoint_path)
            generate_and_save_images(ddpm, epoch)
            print(f'Model saved at epoch {epoch+1}')

        # Special checkpoint save
        if (epoch + 1) % 200 == 0:
            special_checkpoint_path = checkpoint_path.replace(".pth", f"_{epoch+1}.pth")
            save_model(ddpm, epoch, special_checkpoint_path)
            print(f'Special checkpoint saved at epoch {epoch+1}')

    print("Training completed.")


##### 8. Training the model

In [12]:
# Initialize the DDPM model
# Model / diffusion hyper-parameters
in_channels = 1  # For grayscale images
out_channels = 1  # For grayscale images
emb_dim = 128  # width of the timestep embedding concatenated to the UNet input
num_timesteps = 2000  # diffusion steps T
latent_dim = 128  # stored on the DDPM instance; not used by it
SEED = 42

# Seed torch's RNGs (all devices) so weight init, timestep/noise draws and loader shuffling
# repeat across runs. Some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)

# DDPM registers the UNet as a submodule, so a single .to(device) moves both.
unet = AttentionUNet(in_channels, out_channels, emb_dim)
ddpm = DDPM(unet, num_timesteps, latent_dim).to(device)

# Train the model
train_diffusion_model(ddpm, train_loader, valid_loader, epochs=600, save_interval=10, checkpoint_path="Model_Savepoints/ddpm_checkpoint.pth")


No checkpoint found at swinddpm64MRnet_checkpoint.pth, starting from scratch.
Epoch [1/600], Train Loss: 0.1768, Gradient Norm: 0.90
Epoch [1/600], Validation Loss: 0.0393
Epoch [2/600], Train Loss: 0.0284, Gradient Norm: 0.52
Epoch [2/600], Validation Loss: 0.0266
Epoch [3/600], Train Loss: 0.0217, Gradient Norm: 0.68
Epoch [3/600], Validation Loss: 0.0515
Epoch [4/600], Train Loss: 0.0178, Gradient Norm: 1.47
Epoch [4/600], Validation Loss: 0.0167
Epoch [5/600], Train Loss: 0.0157, Gradient Norm: 0.38
Epoch [5/600], Validation Loss: 0.0153
Epoch [6/600], Train Loss: 0.0154, Gradient Norm: 1.06
Epoch [6/600], Validation Loss: 0.0172
Epoch [7/600], Train Loss: 0.0142, Gradient Norm: 0.96
Epoch [7/600], Validation Loss: 0.0145
Epoch [8/600], Train Loss: 0.0127, Gradient Norm: 1.09
Epoch [8/600], Validation Loss: 0.0125
Epoch [9/600], Train Loss: 0.0139, Gradient Norm: 0.95
Epoch [9/600], Validation Loss: 0.0101
Epoch [10/600], Train Loss: 0.0120, Gradient Norm: 0.31
Epoch [10/600], Vali

##### 9. Generating Images

In [14]:
# Function to generate and save images after training
def generate_and_save_images_post_training(ddpm, num_images=10, save_dir='generated_images/DDPM_images', ncols=5):
    """Sample ``num_images`` images from ``ddpm`` and save them together as one grid figure.

    Images are drawn one at a time with shape (1, 1, 64, 64) and laid out ``ncols`` per row. The
    grid has ``ceil(num_images / ncols)`` rows, and unused cells are left blank. The figure is
    written to ``<save_dir>/generated_images_table.png``. The model's previous train/eval mode
    is restored afterwards.
    """
    os.makedirs(save_dir, exist_ok=True)
    sample_shape = (1, 1, 64, 64)  # Generate 1 image at a time, 1 channel, 64x64 images

    # Size the grid to the request; the former fixed 10x5 grid failed beyond 50 images.
    nrows = max(1, math.ceil(num_images / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    was_training = ddpm.training
    ddpm.eval()
    try:
        with torch.no_grad():
            for i in range(num_images):
                sample = ddpm.sample(sample_shape).cpu().numpy().squeeze()
                axes[i // ncols, i % ncols].imshow(sample, cmap='gray')
    finally:
        ddpm.train(was_training)

    # Save the figure
    save_path = os.path.join(save_dir, 'generated_images_table.png')
    fig.savefig(save_path)
    plt.close(fig)

    print(f'{num_images} images saved in a table format at {save_path}')

# Load the trained model
# Reload the weights from the special checkpoint the training loop writes at epoch 600
checkpoint_path = "Model_Savepoints/ddpm_checkpoint_600.pth"
ddpm, start_epoch = load_model(ddpm, checkpoint_path)

# Render 50 samples into a single grid image
generate_and_save_images_post_training(
    ddpm,
    num_images=50,
    save_dir='generated_images/epoch600',
)

50 images saved in a table format at generated_images/epoch600/generated_images_table.png
