# Super-Resolution Model Using Diffusion Model (Without SWIN Transformer)

This notebook implements SR1, a super-resolution model that upscales grayscale images from 64x64 to 128x128 using a Denoising Diffusion Probabilistic Model (DDPM). At every denoising step, the current noisy 128x128 image is concatenated channel-wise with the bilinearly upsampled low-resolution image, and this conditioned input is passed through a UNet that predicts the noise to remove. Repeating this over all timesteps produces the high-resolution output.

The notebook also includes the training routine, which uses the AdamW optimizer together with a ReduceLROnPlateau learning-rate scheduler to minimize the mean squared error (MSE) between the noise predicted by the network and the noise that was actually added to the target high-resolution image. Model checkpoints are saved periodically, and comparison figures are generated that show the input, interpolated, upscaled, and target images side by side.




##### 1. Importing Libraries

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, models
import numpy as np
import math

##### 2. Applying Transforms and Initializing Dataset Loaders

In [7]:
def _build_grayscale_transform(side):
    """Return a pipeline: single-channel grayscale -> resize to (side, side) -> tensor."""
    steps = [
        transforms.Grayscale(),  # Ensure images are single-channel
        transforms.Resize((side, side)),  # Resize to side x side
        transforms.ToTensor(),  # Convert to PyTorch tensor
    ]
    return transforms.Compose(steps)


# Transformations producing the 64x64 and the 128x128 view of an image
transform_64 = _build_grayscale_transform(64)
transform_128 = _build_grayscale_transform(128)

class MRNetUpscaleDataset(Dataset):
    """Dataset of PNG slices served as paired 64x64 / 128x128 views plus a label.

    Every ``.png`` file found directly in ``slice_dir`` is one sample. The sample
    ID is the second ``_``-separated token of the file name and is looked up in
    the labels read from ``label_files``.

    Args:
        slice_dir: Directory containing the ``.png`` slice files.
        label_files: Iterable of header-less CSV paths whose two columns are read
            as ``(id, label)``.
        transform_64: Callable applied to the opened image to build ``data_64``.
            When ``None`` the untransformed image is returned instead.
        transform_128: Callable applied to the opened image to build ``data_128``.
            When ``None`` the untransformed image is returned instead.
    """

    def __init__(self, slice_dir, label_files, transform_64=None, transform_128=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_128 = transform_128

        # NOTE(review): all label files are merged into one dict keyed by ID, so if
        # the same ID occurs in several files only the label from the last file
        # survives -- confirm this is intended.
        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            # IDs are matched as zero-padded 4-character strings (e.g. 7 -> "0007").
            padded_ids = records['id'].map(lambda i: str(i).zfill(4))
            self.labels_dict.update(dict(zip(padded_ids, records['label'])))

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping identical on every run and machine.
        self.slice_files = sorted(
            os.path.join(slice_dir, fname)
            for fname in os.listdir(slice_dir)
            if fname.endswith('.png')
        )

    def __len__(self):
        """Return the number of PNG slices found in ``slice_dir``."""
        return len(self.slice_files)

    def __getitem__(self, index):
        """Return ``{'data_64', 'data_128', 'label', 'id'}`` for the slice at ``index``.

        ``label`` is a float tensor of shape ``(1,)``; it is ``0`` (and a message
        is printed) when the slice ID has no entry in the label files.
        """
        slice_path = self.slice_files[index]

        # The context manager closes the underlying file handle once both views
        # have been produced. ``copy()`` detaches the fallback image from the file
        # so it stays usable after the handle is closed.
        with Image.open(slice_path) as image:
            image_64 = self.transform_64(image) if self.transform_64 else image.copy()
            image_128 = self.transform_128(image) if self.transform_128 else image.copy()

        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label_value = self.labels_dict[slice_id]
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label_value = 0  # or raise an error if preferred
        label = torch.tensor([label_value], dtype=torch.float32)

        return {'data_64': image_64, 'data_128': image_128, 'label': label, 'id': slice_id}


# Locations of the raw slices and of the label CSVs
root_dir = "Raw_Images"
train_slice_dir = os.path.join(root_dir, "train_slices_raw")
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

train_label_files = [
    os.path.join(root_dir, csv_name)
    for csv_name in ("train-acl.csv", "train-abnormal.csv", "train-meniscus.csv")
]
valid_label_files = [
    os.path.join(root_dir, csv_name)
    for csv_name in ("valid-acl.csv", "valid-abnormal.csv", "valid-meniscus.csv")
]

# Both splits use the same pair of transforms
_transform_kwargs = {"transform_64": transform_64, "transform_128": transform_128}
train_dataset = MRNetUpscaleDataset(train_slice_dir, train_label_files, **_transform_kwargs)
valid_dataset = MRNetUpscaleDataset(valid_slice_dir, valid_label_files, **_transform_kwargs)

# Only the training split is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=16, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
if _cuda_available:
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")


Using GPU: NVIDIA A100-SXM4-40GB


##### 3. Implementing the UNet Model

In [9]:
# Define the sinusoidal positional embedding
class SinusoidalPositionalEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of a 1-D tensor of timesteps.

    The first half of the output columns holds ``sin(t * f_i)`` and the second
    half ``cos(t * f_i)``, with frequencies ``f_i`` spaced geometrically from
    ``1`` down to ``1 / max_len``. An odd ``embedding_dim`` is padded with one
    zero column.

    Args:
        embedding_dim: Width of the returned embedding.
        max_len: Controls the lowest frequency (``1 / max_len``).
    """

    def __init__(self, embedding_dim, max_len=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_len = max_len

    def forward(self, timesteps):
        """Embed ``timesteps`` of shape ``(N,)`` into shape ``(N, embedding_dim)``."""
        half_dim = self.embedding_dim // 2
        # With fewer than two frequencies ``half_dim - 1`` would be 0 (or -1);
        # clamp the divisor so embedding_dim in {1, 2, 3} does not divide by zero.
        log_step = math.log(self.max_len) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
        )
        angles = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.embedding_dim % 2 == 1:
            # Build the padding column explicitly: slicing ``emb[:, :1]`` would be
            # empty when there are no sin/cos columns at all (embedding_dim == 1).
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

# Define the UNet model without the Swin Transformer
class SuperResUNet(nn.Module):
    """Three-level UNet whose input is augmented with a timestep embedding.

    The sinusoidal embedding of ``t`` is broadcast over the spatial grid and
    concatenated to the input channels, so the first encoder block receives
    ``in_channels + emb_dim`` channels. Each of the three encoder levels halves
    the spatial size with max pooling; each decoder level doubles it with a
    transposed convolution and concatenates the matching encoder output.

    Args:
        in_channels: Channels of the tensor passed to ``forward`` (before the
            timestep embedding is appended).
        out_channels: Channels of the returned tensor.
        emb_dim: Width of the timestep embedding.
    """

    def __init__(self, in_channels, out_channels, emb_dim=128):
        super().__init__()

        # Contracting path
        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)

        # Expanding path: each decoder block sees upsampled + skip channels
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        # 1x1 projection to the requested number of output channels
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by a ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Run the UNet on ``x`` conditioned on the timesteps ``t``.

        ``t`` is embedded, reshaped to ``(t.size(0), emb_dim, 1, 1)``, tiled over
        the spatial size of ``x`` and concatenated to ``x`` along the channel axis.
        """
        height, width = x.size(2), x.size(3)
        time_map = self.timestep_embedding_layer(t).view(t.size(0), -1, 1, 1)
        time_map = time_map.repeat(1, 1, height, width)
        features = torch.cat((x, time_map), dim=1)

        # Encoder: remember each level's pre-pooling activations for the skips
        skips = []
        encoder_levels = (
            (self.encoder1, self.pool1),
            (self.encoder2, self.pool2),
            (self.encoder3, self.pool3),
        )
        for encode, pool in encoder_levels:
            features = encode(features)
            skips.append(features)
            features = pool(features)

        features = self.bottleneck(features)

        # Decoder: consume the skips deepest-first
        decoder_levels = (
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for upsample, decode in decoder_levels:
            merged = torch.cat((upsample(features), skips.pop()), dim=1)
            features = decode(merged)

        return self.final_conv(features)


##### 4. Implementing the DDPM Model

In [10]:
"""
DDPM implementation adapted from:
https://github.com/hojonathanho/diffusion/tree/master
"""

class SuperResDDPM(nn.Module):
    """Conditional DDPM wrapper that denoises a 2x-upscaled image.

    The wrapped ``model`` is called as ``model(x, t)`` where ``x`` is the noisy
    high-resolution image concatenated channel-wise with the bilinearly 2x
    upsampled low-resolution image, and is trained to predict the added noise.

    Args:
        model: Noise-prediction network taking ``(x, t)``.
        num_timesteps: Number of diffusion steps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, model, num_timesteps, beta_start=0.00085, beta_end=0.0120):
        super().__init__()
        self.model = model
        self.num_timesteps = num_timesteps

        # Buffer names and order are kept stable: they are part of the state dict.
        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))

    @staticmethod
    def _per_sample(schedule, t):
        """Look up ``schedule[t]`` and reshape it to broadcast over ``(B, C, H, W)``."""
        return schedule[t].reshape(-1, 1, 1, 1)

    @staticmethod
    def _upsample(low_res_image):
        """Bilinearly upsample ``low_res_image`` by a factor of two."""
        return F.interpolate(low_res_image, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, z_t, t, low_res_image):
        """Predict the noise in ``z_t`` at timesteps ``t`` given the low-res image."""
        # Concatenate the upsampled low_res_image with z_t to condition the model
        low_res_upsampled = self._upsample(low_res_image)
        return self.model(torch.cat([z_t, low_res_upsampled], dim=1), t)

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` uniform integer timesteps in ``[0, num_timesteps)``.

        The tensor is created on the device of this module's buffers, so the
        method does not depend on any global ``device`` variable.
        """
        return torch.randint(0, self.num_timesteps, (batch_size,), device=self.betas.device)

    def forward_diffusion(self, target_high_res_img, t, noise):
        """Return the noised image ``sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise``."""
        signal_scale = self._per_sample(self.sqrt_alphas_cumprod, t)
        noise_scale = self._per_sample(self.sqrt_one_minus_alphas_cumprod, t)
        return signal_scale * target_high_res_img + noise_scale * noise

    def p_losses(self, input_low_res_img, target_high_res_img, t):
        """Return the MSE between the predicted noise and the noise actually added."""
        noise = torch.randn_like(target_high_res_img)
        z_t = self.forward_diffusion(target_high_res_img, t, noise)
        predicted_noise = self.forward(z_t, t, input_low_res_img)
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph over all steps
    def sample(self, low_res_image):
        """Generate upscaled images for a batch of low-resolution images.

        Starts from Gaussian noise with the shape of the 2x-upsampled input and
        applies ``p_sample`` for every timestep from ``num_timesteps - 1`` down
        to ``0``, adding noise with standard deviation ``sqrt(beta_t)`` after
        every step except the last.
        """
        batch_size = low_res_image.size(0)
        z_t = torch.randn_like(self._upsample(low_res_image))

        for step in reversed(range(self.num_timesteps)):
            # One timestep entry per sample so the model's embedding matches the batch.
            t_batch = torch.full((batch_size,), step, device=z_t.device, dtype=torch.long)
            z_t = self.p_sample(z_t, t_batch, low_res_image)
            if step > 0:
                z_t = z_t + torch.randn_like(z_t) * torch.sqrt(self.betas[step])

        return z_t

    def p_sample(self, z, t, low_res_image):
        """Return the deterministic part of one reverse step (no noise is added).

        Computes ``(z - beta_t / sqrt(1 - abar_t) * predicted_noise) / sqrt(alpha_t)``
        with the coefficients looked up per sample from ``t``.
        """
        beta_t = self._per_sample(self.betas, t)
        sqrt_alpha_t = torch.sqrt(self._per_sample(self.alphas, t))
        sqrt_one_minus_alpha_bar_t = self._per_sample(self.sqrt_one_minus_alphas_cumprod, t)
        predicted_noise = self.forward(z, t, low_res_image)
        return (z - beta_t / sqrt_one_minus_alpha_bar_t * predicted_noise) / sqrt_alpha_t

##### 5. Functions for Training and Saving Models

In [13]:

def save_model(ddpm, epoch, checkpoint_path):
    """Write ``{'epoch', 'model_state_dict'}`` for ``ddpm`` to ``checkpoint_path``.

    Args:
        ddpm: Module whose ``state_dict()`` is stored.
        epoch: Epoch value stored alongside the weights.
        checkpoint_path: Destination file; missing parent directories are created.
    """
    # torch.save does not create directories, so a path such as
    # 'Model_Savepoints/x.pth' would otherwise fail on a fresh checkout.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': ddpm.state_dict()
    }, checkpoint_path)
    print(f'Model checkpoint saved at {checkpoint_path}')

def load_model(ddpm, checkpoint_path):
    """Restore ``ddpm``'s weights from ``checkpoint_path``.

    Args:
        ddpm: Module to load the stored ``model_state_dict`` into (modified in place).
        checkpoint_path: File containing ``'epoch'`` and ``'model_state_dict'`` entries.

    Returns:
        Tuple ``(ddpm, epoch)`` where ``epoch`` is the value stored in the checkpoint.
    """
    # Map the stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can also be restored on a CPU-only machine.
    first_param = next(ddpm.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    return ddpm, start_epoch

def compare_and_save_images(ddpm, valid_loader, epoch, save_dir='generated_images/training'):
    """Save a side-by-side comparison figure for the first validation batch.

    For up to five images of the first batch yielded by ``valid_loader`` the
    figure shows one row each with: the low-resolution input, its bilinear 2x
    interpolation, the output of ``ddpm.sample`` and the high-resolution target.
    The figure is written to ``<save_dir>/comparison_epoch_<epoch+1>.png``.

    Args:
        ddpm: Model providing ``sample``; it is switched to eval mode and left there.
        valid_loader: Iterable of batches with ``'data_64'`` and ``'data_128'`` tensors.
        epoch: Zero-based epoch index used in the file name.
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    ddpm.eval()

    # Only the first batch is visualised.
    first_batch = next(iter(valid_loader), None)
    if first_batch is None:
        return

    # Run on whatever device the model lives on instead of relying on a global.
    first_param = next(ddpm.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    inputs = first_batch['data_64'].to(model_device)
    targets = first_batch['data_128'].to(model_device)

    # Sampling runs the full reverse process, so only sample the images that
    # are actually plotted (at most five) rather than the whole batch.
    num_images = min(5, inputs.size(0))
    if num_images == 0:
        return

    def _as_image(tensor):
        """Move to CPU and drop singleton dimensions for imshow."""
        return tensor.cpu().numpy().squeeze()

    rows = []
    with torch.no_grad():
        for j in range(num_images):
            # Process each image individually
            input_image = inputs[j].unsqueeze(0)
            target_image = targets[j].unsqueeze(0)

            # Interpolate the low-resolution image
            interpolated_image = F.interpolate(input_image, scale_factor=2, mode='bilinear', align_corners=False)

            # Upscale the image using the DDPM model
            upscaled_image = ddpm.sample(input_image)

            rows.append((
                _as_image(input_image),
                _as_image(interpolated_image),
                _as_image(upscaled_image),
                _as_image(target_image),
            ))

    column_titles = (
        'Input 64x64 Image',
        'Interpolated 128x128 Image',
        'Upscaled 128x128 Image',
        'Target 128x128 Image',
    )

    # squeeze=False keeps ``axes`` two-dimensional even when only one row is plotted.
    fig, axes = plt.subplots(num_images, 4, figsize=(20, 5 * num_images), squeeze=False)
    for j, row_images in enumerate(rows):
        for col, (title, image) in enumerate(zip(column_titles, row_images)):
            ax = axes[j, col]
            ax.imshow(image, cmap='gray')
            ax.set_title(f'{title} {j+1}')
            ax.axis('off')

    save_path = os.path.join(save_dir, f'comparison_epoch_{epoch+1}.png')
    fig.savefig(save_path)
    plt.close(fig)
    print(f'Comparison images saved at {save_path}')




##### 6. Training routine for the Diffusion Model

In [None]:
def _grad_l2_norm(parameters):
    """Return the global L2 norm over the gradients of ``parameters`` that have one."""
    squared_sum = 0.0
    for p in parameters:
        if p.grad is not None:
            squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5


def train_diffusion_model(ddpm, train_loader, valid_loader, epochs=10, save_interval=10, checkpoint_path='cascadedddpm_checkpoint.pth'):
    """Train ``ddpm`` on the noise-prediction loss and validate after every epoch.

    Uses AdamW with a ReduceLROnPlateau scheduler driven by the validation loss.
    Every ``save_interval`` epochs a checkpoint is written to ``checkpoint_path``
    and a comparison figure is generated. If ``checkpoint_path`` already exists,
    the weights are restored and training continues with the epoch after the one
    stored in the checkpoint.

    Args:
        ddpm: Model providing ``sample_timesteps`` and ``p_losses``.
        train_loader: Iterable of batches with ``'data_64'`` and ``'data_128'`` tensors.
        valid_loader: Same structure as ``train_loader``; used for validation.
        epochs: Total number of epochs (an upper bound on the epoch index, not
            the number of additional epochs when resuming).
        save_interval: Checkpoint / figure frequency in epochs.
        checkpoint_path: File to resume from and to save to.

    Returns:
        Dict with the per-epoch averages: ``{'train_loss': [...], 'valid_loss': [...]}``.
    """
    optimizer = optim.AdamW(ddpm.parameters(), lr=1e-6, weight_decay=1e-2)
    # ``verbose`` is deprecated in recent PyTorch releases; learning-rate
    # reductions are reported explicitly after each scheduler step instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    start_epoch = 0

    if os.path.exists(checkpoint_path):
        # The checkpoint stores the index of the last *completed* epoch, so
        # resume with the following one instead of repeating it.
        # Only model weights are checkpointed; the optimizer starts fresh.
        ddpm, last_completed_epoch = load_model(ddpm, checkpoint_path)
        start_epoch = last_completed_epoch + 1
        print(f"Resuming training from epoch {start_epoch+1}")
    else:
        print(f"No checkpoint found at {checkpoint_path}, starting from scratch.")

    ddpm.to(device)

    train_losses = []
    valid_losses = []

    for epoch in range(start_epoch, epochs):
        ddpm.train()
        train_loss = 0.0
        applied_steps = 0
        total_norm = float('nan')  # stays NaN if the loader yields no batch
        for batch in train_loader:
            inputs = batch['data_64'].to(device)
            targets = batch['data_128'].to(device)
            optimizer.zero_grad()

            t = ddpm.sample_timesteps(inputs.size(0))
            loss = ddpm.p_losses(inputs, targets, t)

            loss.backward()

            # Gradient checking: skip the update when the gradient looks unusable.
            total_norm = _grad_l2_norm(ddpm.parameters())

            # A NaN/inf norm fails every ``>``/``<`` comparison, so it must be
            # tested explicitly or the optimizer would apply corrupted gradients.
            if not math.isfinite(total_norm) or total_norm > 1e3:  # Threshold for exploding gradients
                print(f"Warning: Exploding gradients detected at epoch {epoch+1}, total norm: {total_norm:.2f}")
                continue
            if total_norm < 1e-3:  # Threshold for vanishing gradients
                # Scientific notation: values below 1e-3 would always print as 0.00.
                print(f"Warning: Vanishing gradients detected at epoch {epoch+1}, total norm: {total_norm:.2e}")
                continue

            optimizer.step()
            train_loss += loss.item()
            applied_steps += 1

        # Average over the steps that contributed a loss, not over all batches,
        # so skipped batches do not drag the reported loss towards zero.
        avg_train_loss = train_loss / applied_steps if applied_steps else float('nan')
        train_losses.append(avg_train_loss)
        # The reported gradient norm is that of the last batch of the epoch.
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Gradient Norm: {total_norm:.2f}')

        ddpm.eval()
        valid_loss = 0.0
        valid_batches = 0
        with torch.no_grad():
            for batch in valid_loader:
                inputs = batch['data_64'].to(device)
                targets = batch['data_128'].to(device)
                t = ddpm.sample_timesteps(inputs.size(0))
                loss = ddpm.p_losses(inputs, targets, t)

                valid_loss += loss.item()
                valid_batches += 1

        avg_valid_loss = valid_loss / valid_batches if valid_batches else float('nan')
        valid_losses.append(avg_valid_loss)
        print(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_valid_loss:.4f}')

        # Step the scheduler with the validation loss (only when one was measured)
        if valid_batches:
            previous_lr = optimizer.param_groups[0]['lr']
            scheduler.step(avg_valid_loss)
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr != previous_lr:
                print(f'Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}')

        if (epoch + 1) % save_interval == 0:
            save_model(ddpm, epoch, checkpoint_path)
            compare_and_save_images(ddpm, valid_loader, epoch)
            print(f'Model saved at epoch {epoch+1}')

    return {'train_loss': train_losses, 'valid_loss': valid_losses}


##### 7. Running the Training 

In [12]:
# Initialize the UNet model and DDPM
in_channels = 2  # noisy high-res image + bilinearly upsampled low-res image (one grayscale channel each)
out_channels = 1  # predicted noise for the single grayscale channel
emb_dim = 128
num_timesteps = 1000

unet = SuperResUNet(in_channels, out_channels, emb_dim).to(device)
ddpm = SuperResDDPM(unet, num_timesteps).to(device)

# Make sure the checkpoint directory exists up front, so the first save
# (after several epochs of training) cannot fail on a missing folder.
checkpoint_path = 'Model_Savepoints/cascadedddpm64NOSWIN_checkpoint.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Train the model; keep whatever the routine returns instead of echoing it as cell output
training_result = train_diffusion_model(ddpm, train_loader, valid_loader, epochs=100, save_interval=10, checkpoint_path=checkpoint_path)

No checkpoint found at cascadedddpm64NOSWIN_checkpoint.pth, starting from scratch.
Epoch [1/100], Train Loss: 0.2075, Gradient Norm: 1.32
Epoch [1/100], Validation Loss: 0.0295
Epoch [2/100], Train Loss: 0.0174, Gradient Norm: 0.23
Epoch [2/100], Validation Loss: 0.0100
Epoch [3/100], Train Loss: 0.0103, Gradient Norm: 1.04
Epoch [3/100], Validation Loss: 0.0108
Epoch [4/100], Train Loss: 0.0079, Gradient Norm: 0.64
Epoch [4/100], Validation Loss: 0.0080
Epoch [5/100], Train Loss: 0.0075, Gradient Norm: 0.55
Epoch [5/100], Validation Loss: 0.0067
Epoch [6/100], Train Loss: 0.0064, Gradient Norm: 0.15
Epoch [6/100], Validation Loss: 0.0080
Epoch [7/100], Train Loss: 0.0069, Gradient Norm: 0.05
Epoch [7/100], Validation Loss: 0.0047
Epoch [8/100], Train Loss: 0.0057, Gradient Norm: 0.62
Epoch [8/100], Validation Loss: 0.0047
Epoch [9/100], Train Loss: 0.0061, Gradient Norm: 0.19
Epoch [9/100], Validation Loss: 0.0053
Epoch [10/100], Train Loss: 0.0056, Gradient Norm: 0.50
Epoch [10/100],

##### 8. Testing the model out

In [27]:
# Rebuild the same UNet + DDPM architecture so stored weights can be loaded into it
in_channels, out_channels = 2, 1  # input: noisy image + upsampled low-res image; output: one grayscale channel
emb_dim = 128
num_timesteps = 1000

unet = SuperResUNet(in_channels=in_channels, out_channels=out_channels, emb_dim=emb_dim)
unet = unet.to(device)
ddpm = SuperResDDPM(model=unet, num_timesteps=num_timesteps)
ddpm = ddpm.to(device)
def load_model_and_upscale_from_validation(ddpm, checkpoint_path, valid_loader, save_dir='generated_images/upscaled_from_validation'):
    """Load a checkpoint and save a comparison figure for one validation image.

    The first image of the first batch yielded by ``valid_loader`` is upscaled
    with ``ddpm.sample`` and plotted next to the input, its bilinear 2x
    interpolation and the target. The figure is written to
    ``<save_dir>/upscaled_comparison.png``.

    Args:
        ddpm: Model to load the stored ``model_state_dict`` into; it is moved to
            the global ``device`` and switched to eval mode.
        checkpoint_path: File containing a ``'model_state_dict'`` entry.
        valid_loader: Iterable of batches with ``'data_64'`` and ``'data_128'`` tensors.
        save_dir: Output directory, created if missing.
    """
    # Load the checkpoint. map_location lets a checkpoint written on a GPU be
    # restored on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    ddpm.to(device)
    ddpm.eval()

    os.makedirs(save_dir, exist_ok=True)

    # Only the first batch is used.
    batch = next(iter(valid_loader), None)
    if batch is None:
        print('Validation loader is empty; nothing to upscale.')
        return

    with torch.no_grad():
        inputs = batch['data_64'].to(device)
        targets = batch['data_128'].to(device)
        print(inputs.shape)
        # Select the first image in the batch
        input_image = inputs[0].unsqueeze(0)
        target_image = targets[0].unsqueeze(0)
        print(input_image.shape)
        # Interpolate the low-resolution image
        interpolated_image = F.interpolate(input_image, scale_factor=2, mode='bilinear', align_corners=False)
        print(interpolated_image.shape)
        # Upscale the image using the DDPM model
        upscaled_image = ddpm.sample(input_image)
        print(upscaled_image.shape)

    # (title, tensor) per panel, left to right
    panels = (
        ('Input 64x64 Image', input_image),
        ('Interpolated 128x128 Image', interpolated_image),
        ('Upscaled 128x128 Image', upscaled_image),
        ('Target 128x128 Image', target_image),
    )

    # Plot and save the images for comparison
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (title, tensor) in zip(axes, panels):
        # Convert to a numpy array without singleton dimensions for visualization
        ax.imshow(tensor.cpu().numpy().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    save_path = os.path.join(save_dir, 'upscaled_comparison.png')
    fig.savefig(save_path)
    plt.close(fig)
    print(f'Upscaled comparison image saved at {save_path}')


# Upscale one validation image with a previously trained checkpoint.
# NOTE(review): this checkpoint name differs from the path used by the training
# cell ('Model_Savepoints/cascadedddpm64NOSWIN_checkpoint.pth') -- confirm it is
# the intended file.
load_model_and_upscale_from_validation(
    ddpm,
    checkpoint_path='cascadedddpm64(2)100epochs_checkpoint.pth',
    valid_loader=valid_loader,
)


torch.Size([16, 1, 64, 64])
torch.Size([1, 1, 64, 64])
torch.Size([1, 1, 128, 128])
torch.Size([1, 1, 128, 128])
Upscaled comparison image saved at generated_images/upscaled_from_validation/upscaled_comparison.png
