# Super-Resolution Model Using Diffusion Model (With Swin Transformer in the UNet)

This notebook implements SR2, a super-resolution model that upscales images from 128x128 to 256x256 using a Denoising Diffusion Probabilistic Model (DDPM) whose UNet contains a Swin Transformer block at its bottleneck. At every denoising step the current noisy high-resolution estimate is concatenated with the bilinearly upscaled low-resolution image and passed through the network, which predicts the noise to remove; iterating this reverse process produces the high-resolution output.

The notebook also includes the training routine for the model, which uses the AdamW optimizer together with a ReduceLROnPlateau learning-rate scheduler to minimize the mean squared error (MSE) between the noise predicted by the network and the noise actually added to the target high-resolution image. The code saves model checkpoints periodically and generates comparison images that show the input, interpolated, upscaled, and target images side by side.



##### 1. Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, models
import numpy as np
import math

##### 2. Applying Transforms and Initializing Dataset Loaders

In [2]:


# Preprocessing pipelines: one for the 128x128 low-resolution input and one for
# the 256x256 high-resolution target. Both are built by the same factory so the
# two pipelines cannot drift apart.
def _grayscale_resize_to_tensor(side):
    """Return a pipeline: single-channel conversion -> resize to (side, side) -> tensor."""
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((side, side)),
        transforms.ToTensor(),
    ])


transform_128 = _grayscale_resize_to_tensor(128)
transform_256 = _grayscale_resize_to_tensor(256)

class MRNetUpscaleDataset(Dataset):
    """Dataset of PNG slices served as paired low-/high-resolution views plus a label.

    Args:
        slice_dir: Directory that is scanned (non-recursively) for ``*.png`` files.
        label_files: Iterable of header-less CSV paths with two columns ``id,label``.
        transform_128: Callable applied to the opened image to produce ``data_128``.
            When ``None`` an in-memory copy of the opened image is returned instead.
        transform_256: Same as ``transform_128`` but for ``data_256``.

    Each item is a dict with keys ``data_128``, ``data_256``, ``label`` (float
    tensor of shape ``(1,)``) and ``id`` (the zero-padded id parsed from the
    filename).
    """

    def __init__(self, slice_dir, label_files, transform_128=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_128 = transform_128
        self.transform_256 = transform_256

        # id (zero-padded to 4 characters) -> label.
        # NOTE(review): all CSVs are merged into one dict, so when the same id
        # occurs in several files the label from the LAST file wins. Confirm
        # this is intended before relying on 'label' downstream.
        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: str(i).zfill(4))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same file across runs and machines.
        self.slice_files = sorted(
            os.path.join(slice_dir, fname)
            for fname in os.listdir(slice_dir)
            if fname.endswith('.png')
        )

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]

        # The context manager closes the underlying file handle once both views
        # are materialised. Without a transform we fall back to an in-memory
        # copy so the returned image stays usable after the file is closed.
        with Image.open(slice_path) as image:
            image_128 = self.transform_128(image) if self.transform_128 else image.copy()
            image_256 = self.transform_256(image) if self.transform_256 else image.copy()

        # Extract ID from filename to match with label
        # Filename format: abnormal_0000_slice_0 -> extract "0000" as ID
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = torch.tensor([self.labels_dict[slice_id]], dtype=torch.float32)
        else:
            # Best-effort: a missing label is reported and replaced by 0.
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.tensor([0], dtype=torch.float32)

        return {'data_128': image_128, 'data_256': image_256, 'label': label, 'id': slice_id}


# Dataset locations
root_dir = "Raw_Images"
train_slice_dir = os.path.join(root_dir, "train_slices_raw")
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

# One label CSV per task, listed in the order acl -> abnormal -> meniscus
train_label_files = [os.path.join(root_dir, f"train-{task}.csv") for task in ("acl", "abnormal", "meniscus")]
valid_label_files = [os.path.join(root_dir, f"valid-{task}.csv") for task in ("acl", "abnormal", "meniscus")]

# Datasets yielding paired 128x128 / 256x256 tensors
train_dataset = MRNetUpscaleDataset(
    slice_dir=train_slice_dir,
    label_files=train_label_files,
    transform_128=transform_128,
    transform_256=transform_256,
)
valid_dataset = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_128=transform_128,
    transform_256=transform_256,
)

# Only the training split is shuffled
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

# Pick the GPU when one is available and report the choice
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using GPU: {torch.cuda.get_device_name(0)}" if device.type == 'cuda' else "Using CPU")


Using GPU: NVIDIA A100-SXM4-40GB


##### 3. Implementing Swin Transformer and Helper Functions

In [3]:
"""
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>
https://github.com/microsoft/Swin-Transformer
"""

# DropPath (Stochastic Depth) module to implement drop path regularization
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample in the batch is dropped with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is ``None`` or ``0``, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample. ``None`` is treated as 0.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # The default drop_prob is None; treat it as "never drop" instead of
        # failing on `1 - None` below.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        if drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - drop_prob
        if keep_prob <= 0.:
            # Everything is dropped; avoid the 0/0 that would produce NaNs.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor

# Helper functions to handle tuple and truncation
def to_2tuple(x):
    """Return ``x`` unchanged if it is already a tuple or list, else the pair ``(x, x)``."""
    return x if isinstance(x, (tuple, list)) else (x, x)

def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place with approximately truncated-normal values.

    Four standard-normal candidates are drawn per element and one lying strictly
    inside (-2, 2) is selected (if none qualifies, whichever candidate the
    arg-max falls back to is used). The result is then scaled by ``std`` and
    shifted by ``mean``.

    Returns:
        The same ``tensor`` object, modified in place.
    """
    with torch.no_grad():
        candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
        inside = (candidates < 2) & (candidates > -2)
        # arg-max over the boolean mask picks a candidate flagged as inside
        chosen_idx = inside.max(-1, keepdim=True)[1]
        chosen = candidates.gather(-1, chosen_idx).squeeze(-1)
        tensor.data.copy_(chosen)
        tensor.data.mul_(std).add_(mean)
        return tensor

# MLP module used within Swin Transformer
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to ``in_features``.
        out_features: Size of the last output dimension; falls back to ``in_features``.
        act_layer: Activation module class, instantiated without arguments.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

# Functions to partition and reverse windows in the Swin Transformer
def window_partition(x, window_size):
    """Split a channel-first feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape ``(B, C, H, W)``; it is reshaped with ``view``.
        window_size: Side length of each window; ``H`` and ``W`` are divided by it.

    Returns:
        Tensor of shape ``(B * H/window_size * W/window_size, window_size, window_size, C)``.
    """
    batch, channels, height, width = x.shape
    tiled = x.view(
        batch,
        channels,
        height // window_size, window_size,
        width // window_size, window_size,
    )
    # (B, C, nH, ws, nW, ws) -> (B, nH, nW, ws, ws, C)
    tiled = tiled.permute(0, 2, 4, 3, 5, 1).contiguous()
    return tiled.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a channel-first feature map.

    Args:
        windows: Tensor of shape ``(num_windows * B, window_size, window_size, C)``.
        window_size: Side length of each window.
        H: Height of the reassembled map.
        W: Width of the reassembled map.

    Returns:
        Tensor of shape ``(B, C, H, W)``.
    """
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # (B, nH, nW, ws, ws, C) -> (B, C, nH, ws, nW, ws)
    grid = grid.permute(0, 5, 1, 3, 2, 4).contiguous()
    return grid.view(B, -1, H, W)

# Window-based multi-head self-attention (W-MSA) module
class WindowAttention(nn.Module):
    """Multi-head self-attention inside a window, with a learned relative position bias.

    Args:
        dim: Number of input/output channels per token.
        window_size: Pair ``(Wh, Ww)`` giving the window height and width.
        num_heads: Number of attention heads; ``dim`` is split evenly across them.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Optional override for the default ``head_dim ** -0.5`` scaling.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        win_h, win_w = self.window_size[0], self.window_size[1]

        # One learnable bias per (relative offset, head): (2*Wh-1)*(2*Ww-1) offsets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

        # Lookup table mapping every (query token, key token) pair to a row of
        # the bias table above.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing='ij'))
        flat = torch.flatten(grid, 1)                                # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]                # 2, N, N
        offsets = offsets.permute(1, 2, 0).contiguous()              # N, N, 2
        offsets[:, :, 0] += win_h - 1                                # shift to start at 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1                            # row-major flattening
        self.register_buffer("relative_position_index", offsets.sum(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: Tensor of shape ``(num_windows * B, N, C)`` with ``N = Wh * Ww``.
            mask: Optional additive mask of shape ``(num_windows, N, N)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B_, N, C = x.shape
        heads = self.num_heads

        fused = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1)       # heads, N, N
        attn = attn + bias.unsqueeze(0).to(attn.dtype)

        if mask is not None:
            # Add the per-window mask, broadcasting over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, heads, N, N)
        attn = self.attn_drop(self.softmax(attn))

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))

# Swin Transformer block implementing the shifted window-based attention mechanism
class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted) window attention followed by an MLP.

    Tokens enter and leave as ``(B, H*W, C)``. Internally the tokens are laid
    out as ``(B, H, W, C)`` for the cyclic shift and converted to channel-first
    ``(B, C, H, W)`` around ``window_partition`` / ``window_reverse``, because
    those helpers operate on channel-first tensors.

    Args:
        dim: Number of channels per token.
        input_resolution: Pair ``(H, W)`` of the token grid.
        num_heads: Number of attention heads.
        window_size: Side length of the attention windows.
        shift_size: Cyclic shift applied before windowing (0 disables shifting).
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention's qkv projection has a bias.
        qk_scale: Optional override of the attention scaling factor.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Stochastic-depth probability for both residual branches.
        act_layer: Accepted for interface compatibility; not used by this block.
        norm_layer: Normalisation layer class applied before attention and MLP.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # The window would cover the whole grid: no partitioning or shift needed.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * self.mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        """Build the additive attention mask used with shifted windows.

        Args:
            x_size: Pair ``(H, W)`` of the token grid.

        Returns:
            Tensor of shape ``(num_windows, N, N)`` with ``N = window_size ** 2``,
            holding 0 where two tokens share a region id and -100 otherwise.
        """
        H, W = x_size
        # Channel-first (1, 1, H, W) so the layout matches window_partition.
        img_mask = torch.zeros((1, 1, H, W))
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        # Label each of the 3x3 regions created by the shift with its own id.
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, :, h, w] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        """Apply the block to tokens of shape ``(B, H*W, C)`` and return the same shape."""
        B, L, C = x.shape
        H, W = self.input_resolution
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift over the two spatial dims of the (B, H, W, C) layout.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # window_partition expects (B, C, H, W) and reshapes with .view, so the
        # permuted tensor has to be made contiguous first.
        shifted_x = shifted_x.permute(0, 3, 1, 2).contiguous()
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        if self.attn_mask is not None:
            attn_windows = self.attn(x_windows, mask=self.attn_mask.to(x.dtype))
        else:
            attn_windows = self.attn(x_windows)

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # window_reverse returns (B, C, H, W); go back to (B, H, W, C).
        shifted_x = window_reverse(attn_windows, self.window_size, H, W).permute(0, 2, 3, 1)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.reshape(B, H * W, C)

        # Residual connections around attention and MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


##### 4. Implementing the UNet Model

In [4]:
# Define the sinusoidal positional embedding
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        embedding_dim: Size of the returned embedding vector.
        max_len: Base of the geometric frequency progression.
    """

    def __init__(self, embedding_dim, max_len=10000):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_len = max_len

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps.

        Args:
            timesteps: Tensor of shape ``(B,)``.

        Returns:
            Tensor of shape ``(B, embedding_dim)``: sines in the first half,
            cosines in the second, and one zero column when ``embedding_dim`` is odd.
        """
        half_dim = self.embedding_dim // 2
        # Guard the denominator: with embedding_dim 2 or 3, half_dim - 1 is 0
        # and the plain division raised ZeroDivisionError. Unchanged for
        # every embedding_dim >= 4.
        step = math.log(self.max_len) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -step)
        emb = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embedding_dim % 2 == 1:
            # Pad odd sizes with one zero column so the width is embedding_dim.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb

# Define the UNet model with attention
class SuperResUNet(nn.Module):
    """Three-level UNet with a Swin Transformer block at the bottleneck.

    The timestep is embedded, broadcast over the spatial grid and concatenated
    to the input as extra channels. The Swin block is built for a 32x32 token
    grid, i.e. the bottleneck of a 256x256 input after three 2x poolings.

    Args:
        in_channels: Number of image channels in ``x`` (before the timestep
            embedding channels are appended).
        out_channels: Number of output channels.
        emb_dim: Size of the timestep embedding.
    """

    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(SuperResUNet, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        self.swin_block = SwinTransformerBlock(dim=512, input_resolution=(32, 32), num_heads=8, window_size=4)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``.
            t: Timesteps of shape ``(B,)``; a single timestep of shape ``(1,)``
                is broadcast over the batch.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.
        """
        # Timestep embedding as constant feature maps. expand() creates a
        # broadcast view instead of materialising copies like repeat() does.
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.expand(x.size(0), -1, x.size(2), x.size(3))

        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)
        B, C, H, W = bottleneck.shape
        # (B, C, H, W) -> (B, H*W, C). This must be a transpose: a plain
        # .view(B, H*W, C) only reinterprets memory and mixes channel and
        # spatial positions inside every token.
        tokens = bottleneck.flatten(2).transpose(1, 2).contiguous()
        tokens = self.swin_block(tokens)
        # (B, H*W, C) -> (B, C, H, W), again via transpose.
        bottleneck = tokens.transpose(1, 2).reshape(B, C, H, W)

        # Decoder with skip connections
        upconv3 = self.upconv3(bottleneck)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### 5. Implementing the DDPM Model

In [5]:
"""
DDPM implementation adapted from:
https://github.com/hojonathanho/diffusion/tree/master
"""

class SuperResDDPM(nn.Module):
    """DDPM wrapper that conditions a noise-prediction network on a low-res image.

    Args:
        model: Module called as ``model(x, t)`` where ``x`` is the noisy image
            concatenated channel-wise with the upsampled low-res image.
        num_timesteps: Number of diffusion steps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, model, num_timesteps, beta_start=0.00085, beta_end=0.0120):
        super(SuperResDDPM, self).__init__()
        self.model = model
        self.num_timesteps = num_timesteps

        # Schedule tensors are buffers so they follow the module across devices.
        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1 - betas, dim=0))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - self.alphas_cumprod))

    def _extract(self, values, t):
        """Index a 1-D schedule with timesteps ``t`` and shape the result ``(B, 1, 1, 1)``."""
        return values[t].view(-1, 1, 1, 1)

    def forward(self, z_t, t, low_res_image):
        """Predict the noise in ``z_t`` given timesteps ``t`` and the low-res condition."""
        # Concatenate low_res_image with z_t to condition the model
        low_res_upsampled = F.interpolate(low_res_image, scale_factor=2, mode='bilinear', align_corners=False)
        return self.model(torch.cat([z_t, low_res_upsampled], dim=1), t)

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` uniform random timesteps on the module's own device."""
        # Use the buffers' device rather than a notebook-level global so the
        # method works wherever the module lives.
        return torch.randint(0, self.num_timesteps, (batch_size,), device=self.betas.device)

    def forward_diffusion(self, target_high_res_img, t, noise):
        """Return the noised image ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``."""
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t)
        return sqrt_alphas_cumprod_t * target_high_res_img + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, input_low_res_img, target_high_res_img, t):
        """MSE between the noise added at timesteps ``t`` and the model's prediction."""
        noise = torch.randn_like(target_high_res_img)
        z_t = self.forward_diffusion(target_high_res_img, t, noise)

        predicted_noise = self.forward(z_t, t, input_low_res_img)

        return nn.MSELoss()(predicted_noise, noise)

    def sample(self, low_res_image):
        """Run the full reverse process and return the generated high-res batch.

        Starts from Gaussian noise shaped like the 2x-upsampled input and, for
        every timestep from last to first, removes the predicted noise and
        (except at the final step) adds fresh noise scaled by ``sqrt(beta_t)``.
        Gradients are not disabled here; wrap the call in ``torch.no_grad()``
        when sampling for evaluation.
        """
        z_t = torch.randn_like(F.interpolate(low_res_image, scale_factor=2, mode='bilinear', align_corners=False))
        batch_size = z_t.shape[0]

        for t in reversed(range(self.num_timesteps)):
            # One timestep entry per sample so batches larger than 1 work.
            t_tensor = torch.full((batch_size,), t, device=z_t.device, dtype=torch.long)
            beta_t = self.betas[t]
            sqrt_alpha_t = torch.sqrt(self.alphas[t])
            sqrt_one_minus_alpha_t = torch.sqrt(1 - self.alphas_cumprod[t])

            predicted_noise = self.forward(z_t, t_tensor, low_res_image)

            z_t = (z_t - beta_t * predicted_noise / sqrt_one_minus_alpha_t) / sqrt_alpha_t

            if t > 0:
                z_t += torch.randn_like(z_t) * torch.sqrt(beta_t)

        return z_t

    def p_sample(self, z, t, low_res_image):
        """Single deterministic reverse step (no noise is added).

        Args:
            z: Current noisy batch of shape ``(B, C, H, W)``.
            t: Timestep as an int or a tensor with one entry per sample
                (a single entry is broadcast over the batch).
            low_res_image: Conditioning low-resolution batch.
        """
        if not torch.is_tensor(t):
            t = torch.full((z.shape[0],), int(t), device=z.device, dtype=torch.long)
        t = t.reshape(-1)

        # Per-sample coefficients shaped (B, 1, 1, 1) so they broadcast over
        # channels and pixels instead of the last image axis.
        sqrt_alpha_t = torch.sqrt(self._extract(self.alphas, t))
        sqrt_one_minus_alpha_t = torch.sqrt(1 - self._extract(self.alphas_cumprod, t))
        beta_t = self._extract(self.betas, t)
        predicted_noise = self.forward(z, t, low_res_image)

        z = (z - beta_t / sqrt_one_minus_alpha_t * predicted_noise) / sqrt_alpha_t
        return z

##### 6. Functions for Training and Saving Models

In [8]:
# Save the model checkpoint
def save_model(ddpm, epoch, checkpoint_path):
    """Write ``epoch`` and the model's ``state_dict`` to ``checkpoint_path``.

    The parent directory is created when it does not exist, so a path such as
    ``Model_Savepoints/run.pth`` works on a fresh checkout.

    Args:
        ddpm: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        checkpoint_path: Destination file.
    """
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': ddpm.state_dict()
    }, checkpoint_path)
    print(f'Model checkpoint saved at {checkpoint_path}')

# Load the model checkpoint
def load_model(ddpm, checkpoint_path):
    """Restore model weights from a checkpoint containing ``'epoch'`` and ``'model_state_dict'``.

    Tensors are first mapped to the CPU so a checkpoint written on a GPU machine
    can be read on a CPU-only one; ``load_state_dict`` then copies the values
    into the model's existing parameters on whatever device they live.

    Args:
        ddpm: Module that receives the stored ``state_dict``.
        checkpoint_path: File to read.

    Returns:
        Tuple ``(ddpm, start_epoch)`` where ``start_epoch`` is the stored
        ``'epoch'`` value, returned unchanged.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    return ddpm, start_epoch

# Function to generate and save comparison images
def compare_and_save_images(ddpm, valid_loader, epoch, save_dir='generated_images/training'):
    """Save a side-by-side figure for up to 5 images of the first validation batch.

    Each row shows the 128x128 input, its bilinear interpolation, the DDPM
    sample and the 256x256 target. The figure is written to
    ``<save_dir>/comparison_epoch_<epoch+1>.png``.

    Args:
        ddpm: Model exposing ``sample(low_res_image)``.
        valid_loader: Iterable of batches with keys ``'data_128'`` and ``'data_256'``.
        epoch: Zero-based epoch index, used only for the file name.
        save_dir: Output directory; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    titles = (
        'Input 128x128 Image',
        'Interpolated 256x256 Image',
        'Upscaled 256x256 Image',
        'Target 256x256 Image',
    )

    # Remember the current mode so calling this helper does not silently leave
    # the model in eval mode.
    was_training = ddpm.training
    ddpm.eval()
    try:
        with torch.no_grad():
            for batch in valid_loader:
                inputs = batch['data_128'].to(device)
                targets = batch['data_256'].to(device)

                # Only the images that are plotted are sampled: every sample
                # runs the full reverse diffusion, so sampling the rest of the
                # batch would be wasted work.
                num_images = min(5, inputs.size(0))
                rows = []
                for j in range(num_images):
                    input_image = inputs[j].unsqueeze(0)
                    target_image = targets[j].unsqueeze(0)

                    # Interpolate the low-resolution image
                    interpolated_image = F.interpolate(input_image, scale_factor=2, mode='bilinear', align_corners=False)

                    # Upscale the image using the DDPM model
                    upscaled_image = ddpm.sample(input_image)

                    rows.append(tuple(
                        img.cpu().numpy().squeeze()
                        for img in (input_image, interpolated_image, upscaled_image, target_image)
                    ))

                # squeeze=False keeps `axes` two-dimensional even for one row.
                fig, axes = plt.subplots(num_images, 4, figsize=(20, 5 * num_images), squeeze=False)
                for j, row_images in enumerate(rows):
                    for k, (image, title) in enumerate(zip(row_images, titles)):
                        ax = axes[j, k]
                        ax.imshow(image, cmap='gray')
                        ax.set_title(f'{title} {j+1}')
                        ax.axis('off')

                save_path = os.path.join(save_dir, f'comparison_epoch_{epoch+1}.png')
                fig.savefig(save_path)
                plt.close(fig)
                print(f'Comparison images saved at {save_path}')
                break  # Only process the first batch
    finally:
        ddpm.train(was_training)



##### 7. Training routine for the Diffusion Model

In [None]:
def _total_grad_norm(parameters):
    """Return the global L2 norm over all existing gradients of ``parameters``."""
    squared_sum = 0.0
    for p in parameters:
        if p.grad is not None:
            squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def train_diffusion_model(ddpm, train_loader, valid_loader, epochs=10, save_interval=10, checkpoint_path='cascadedddpm_checkpoint.pth'):
    """Train ``ddpm`` with AdamW and a ReduceLROnPlateau schedule on the validation loss.

    Args:
        ddpm: Model exposing ``sample_timesteps`` and ``p_losses``.
        train_loader: Iterable of training batches with keys ``'data_128'`` / ``'data_256'``.
        valid_loader: Iterable of validation batches with the same keys.
        epochs: Total number of epochs (an upper bound when resuming).
        save_interval: Save a checkpoint and comparison images every this many epochs.
        checkpoint_path: Checkpoint file; if it exists, training resumes from it.

    Returns:
        Tuple ``(train_losses, valid_losses)`` with one average loss per epoch run.

    Notes:
        Only model weights are checkpointed, so the optimizer and scheduler start
        fresh after a resume.
    """
    optimizer = optim.AdamW(ddpm.parameters(), lr=5e-5, weight_decay=1e-2)
    # `verbose` is deprecated on LR schedulers; LR changes are reported manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15)
    start_epoch = 0

    # Create the checkpoint directory up front so a missing folder surfaces
    # immediately rather than at the first save, many epochs in.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    if os.path.exists(checkpoint_path):
        ddpm, last_epoch = load_model(ddpm, checkpoint_path)
        # The checkpoint stores the index of the epoch that already finished,
        # so continue with the next one instead of repeating it.
        start_epoch = last_epoch + 1
        print(f"Resuming training from epoch {start_epoch+1}")
    else:
        print(f"No checkpoint found at {checkpoint_path}, starting from scratch.")

    ddpm.to(device)

    train_losses = []
    valid_losses = []

    for epoch in range(start_epoch, epochs):
        ddpm.train()
        train_loss = 0.0
        steps_taken = 0
        total_norm = float('nan')  # stays NaN if the loader yields no batches
        for batch in train_loader:
            inputs = batch['data_128'].to(device)
            targets = batch['data_256'].to(device)
            optimizer.zero_grad()

            t = ddpm.sample_timesteps(inputs.size(0))
            loss = ddpm.p_losses(inputs, targets, t)

            loss.backward()

            # Gradient checking: skip the update when the norm is out of range.
            total_norm = _total_grad_norm(ddpm.parameters())

            if total_norm > 1e3:  # Threshold for exploding gradients
                print(f"Warning: Exploding gradients detected at epoch {epoch+1}, total norm: {total_norm:.2f}")
                continue
            if total_norm < 1e-3:  # Threshold for vanishing gradients
                print(f"Warning: Vanishing gradients detected at epoch {epoch+1}, total norm: {total_norm:.2f}")
                continue

            optimizer.step()
            train_loss += loss.item()
            steps_taken += 1

        # Average over the steps that contributed to the sum; skipped batches
        # would otherwise pull the reported loss towards zero.
        avg_train_loss = train_loss / steps_taken if steps_taken else float('nan')
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Gradient Norm: {total_norm:.2f}')

        ddpm.eval()
        valid_loss = 0.0
        valid_batches = 0
        with torch.no_grad():
            for batch in valid_loader:
                inputs = batch['data_128'].to(device)
                targets = batch['data_256'].to(device)
                t = ddpm.sample_timesteps(inputs.size(0))
                loss = ddpm.p_losses(inputs, targets, t)

                valid_loss += loss.item()
                valid_batches += 1

        avg_valid_loss = valid_loss / valid_batches if valid_batches else float('nan')
        valid_losses.append(avg_valid_loss)
        print(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_valid_loss:.4f}')

        # Step the scheduler with the validation loss (only when one exists)
        if valid_batches:
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(avg_valid_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f'Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}')

        if (epoch + 1) % save_interval == 0:
            save_model(ddpm, epoch, checkpoint_path)
            compare_and_save_images(ddpm, valid_loader, epoch)
            print(f'Model saved at epoch {epoch+1}')

    print("Training completed.")
    return train_losses, valid_losses

##### 8. Training the model

In [9]:

# Hyperparameters for the UNet and the diffusion process
in_channels = 2       # UNet input: noisy high-res estimate + upsampled low-res image (1 channel each)
out_channels = 1      # predicted noise for a single-channel (grayscale) image
emb_dim = 128         # width of the timestep embedding
num_timesteps = 2000  # number of diffusion steps


# Build the noise-prediction network and wrap it in the DDPM
unet = SuperResUNet(in_channels=in_channels, out_channels=out_channels, emb_dim=emb_dim).to(device)
ddpm = SuperResDDPM(model=unet, num_timesteps=num_timesteps).to(device)

# Train the model
train_diffusion_model(
    ddpm,
    train_loader,
    valid_loader,
    epochs=300,
    save_interval=10,
    checkpoint_path='Model_Savepoints/cascadedddpm128added4_checkpoint.pth',
)



No checkpoint found at cascadedddpm128added4_checkpoint.pth, starting from scratch.
Epoch [1/300], Train Loss: 0.3114, Gradient Norm: 1.28
Epoch [1/300], Validation Loss: 0.0543
Epoch [2/300], Train Loss: 0.0355, Gradient Norm: 1.64
Epoch [2/300], Validation Loss: 0.0301
Epoch [3/300], Train Loss: 0.0217, Gradient Norm: 1.64
Epoch [3/300], Validation Loss: 0.0145
Epoch [4/300], Train Loss: 0.0142, Gradient Norm: 1.39
Epoch [4/300], Validation Loss: 0.0121
Epoch [5/300], Train Loss: 0.0110, Gradient Norm: 0.34
Epoch [5/300], Validation Loss: 0.0085
Epoch [6/300], Train Loss: 0.0090, Gradient Norm: 0.65
Epoch [6/300], Validation Loss: 0.0056
Epoch [7/300], Train Loss: 0.0103, Gradient Norm: 0.56
Epoch [7/300], Validation Loss: 0.0065
Epoch [8/300], Train Loss: 0.0064, Gradient Norm: 0.47
Epoch [8/300], Validation Loss: 0.0092
Epoch [9/300], Train Loss: 0.0053, Gradient Norm: 0.12
Epoch [9/300], Validation Loss: 0.0059
Epoch [10/300], Train Loss: 0.0062, Gradient Norm: 0.11
Epoch [10/300]