# Image Deblurring

#### Authors: Peiyao Tao, Carolina Li
09/22/2025

CS 7180 Advanced Perception

The purpose of this notebook is to compare the performance of three different image-deblurring models: U-Net, ViT, and Swin Transformer.

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel
from transformers import SwinModel, SwinConfig
from PIL import Image
import torchvision.transforms.functional as TF
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import segmentation_models_pytorch as smp
import shutil
from tqdm import tqdm
import cv2
import random
import torchvision.models as models
from torchvision.transforms import Normalize
import requests
import zipfile
from torch.amp import autocast, GradScaler
from torchmetrics import StructuralSimilarityIndexMeasure

# Global Configurations

In [None]:
# Dataset locations: the extracted DIV2K folder and the folder holding the paired crops.
ORIGINAL_DATA_DIR = "./DIV2K_train_HR"
CROPPED_DATA_DIR = "./cropped_dataset"

# Training hyperparameters.
BATCH_SIZE = 64
EPOCHS = 100
LEARNING_RATE = 1e-4

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

# Data Pre-processing

## Download Dataset

In [None]:
def download_and_unzip_div2k():
    """Download the DIV2K training images (HR) and extract them into the working directory.

    Nothing is done when ``./DIV2K_train_HR`` already exists. Otherwise the
    archive is streamed to disk with a progress bar, extracted next to the
    notebook and then deleted.

    Network errors and invalid archives are reported with ``print`` and make
    the function return early; in both cases the incomplete or corrupt zip
    file is removed so it does not linger on disk.
    """

    # Using the DIV2K dataset
    dataset_url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
    download_dir = "./"
    dataset_name = "DIV2K_train_HR"
    zip_path = os.path.join(download_dir, f"{dataset_name}.zip")

    def discard_archive():
        # Remove a partial or corrupt download left behind by a failed attempt.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    # Skip download if dataset already exists
    if os.path.exists(os.path.join(download_dir, dataset_name)):
        print(f"Dataset '{dataset_name}' already exists. Skipping download.")
        return

    # Stream the archive to disk while showing a progress bar
    print(f"Downloading {dataset_name}.zip...")
    try:
        # (connect, read) timeouts keep a stalled server from hanging the notebook.
        with requests.get(dataset_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()

            total_size_in_bytes = int(response.headers.get('content-length', 0))
            block_size = 1024 * 1024  # 1 MiB chunks

            with open(zip_path, 'wb') as file, \
                    tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                bytes_written = progress_bar.n
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        discard_archive()
        return

    # A size mismatch means the transfer was cut short.
    if total_size_in_bytes != 0 and bytes_written != total_size_in_bytes:
        print("ERROR, something went wrong during download.")
        discard_archive()
        return

    # Unzip the dataset
    print(f"\nUnzipping {zip_path}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(download_dir)
        print("Unzipping complete.")
    except zipfile.BadZipFile:
        print("Error: The downloaded file is not a valid zip file or is corrupted.")
        discard_archive()
        return

    # Remove the zip file after extraction
    os.remove(zip_path)
    print(f"Cleaned up {zip_path}.")
    print("--- Dataset setup is complete! ---")

download_and_unzip_div2k()

## Train-Test Split

In [None]:
# Split the dataset into train and test
TRAIN_SHARP_DIR = os.path.join(ORIGINAL_DATA_DIR, "train", "sharp")
TEST_SHARP_DIR = os.path.join(ORIGINAL_DATA_DIR, "test", "sharp")

def organize_images():
    """Move the sharp ``.png`` images in ORIGINAL_DATA_DIR into train/test sub-folders.

    The files are sorted by name; the first 700 go to ``train/sharp`` and the
    remainder to ``test/sharp``. The function does nothing when either split
    directory already exists.

    The split directories are only created once images have been found, so a
    wrong ORIGINAL_DATA_DIR does not leave empty folders behind that would make
    every later call skip the organisation step.
    """

    # Check for existence of train and test directories
    if os.path.exists(TRAIN_SHARP_DIR) or os.path.exists(TEST_SHARP_DIR):
        print("Train/Test directories already exist. Skipping organization.")
        return

    print("Organizing sharp images into train/test splits...")

    # Look for images first; a missing source folder is treated like an empty one.
    if os.path.isdir(ORIGINAL_DATA_DIR):
        image_files = sorted(f for f in os.listdir(ORIGINAL_DATA_DIR) if f.endswith('.png'))
    else:
        image_files = []

    if not image_files:
        print(f"Error: No .png images found in {ORIGINAL_DATA_DIR}. Please check the path.")
        return

    os.makedirs(TRAIN_SHARP_DIR, exist_ok=True)
    os.makedirs(TEST_SHARP_DIR, exist_ok=True)

    # First 700 files for training, everything after that for testing
    # (the last 100 when the folder holds 800 images).
    train_files = image_files[:700]
    test_files = image_files[700:]

    # Move the files
    for files, destination in ((train_files, TRAIN_SHARP_DIR), (test_files, TEST_SHARP_DIR)):
        print(f"Moving {len(files)} images to {destination}...")
        for filename in tqdm(files):
            shutil.move(os.path.join(ORIGINAL_DATA_DIR, filename), os.path.join(destination, filename))

    print("Image organization complete!")

organize_images()

## Apply Blurry Filter

In [None]:
# Linear motion blur
def generate_motion_blur_kernel(size, angle):
    """Build a normalised linear motion-blur kernel of shape ``(size, size)``.

    A one-pixel-thick segment is drawn from the kernel centre towards the
    direction given by ``angle`` (in degrees), and the kernel is divided by
    its sum so that it adds up to 1.
    """
    mid = (size - 1) / 2
    theta = np.deg2rad(angle)

    # End point of the segment; the y offset is subtracted from the centre row.
    end_x = round(mid + mid * np.cos(theta))
    end_y = round(mid - mid * np.sin(theta))

    canvas = np.zeros((size, size))
    start_point = (round(mid), round(mid))
    cv2.line(canvas, start_point, (int(end_x), int(end_y)), 1, thickness=1)
    return canvas / canvas.sum()

# Defocus blur
def generate_defocus_kernel(size, radius):
    """Build a normalised defocus (disc) kernel of shape ``(size, size)``.

    A filled circle of the given ``radius`` is drawn around the centre pixel
    and the kernel is divided by its sum so that it adds up to 1.
    """
    disc = np.zeros((size, size))
    mid = size // 2
    # thickness=-1 asks OpenCV for a filled circle rather than an outline.
    cv2.circle(disc, (mid, mid), radius, 1, thickness=-1)
    total = disc.sum()
    return disc / total

def _sample_blur_kernel():
    """Draw a random blur type (motion, gaussian or defocus) and return its 2-D kernel."""
    blur_type = np.random.choice(['motion', 'gaussian', 'defocus'])

    if blur_type == 'motion':
        kernel_size = np.random.randint(15, 35)
        kernel_angle = np.random.randint(0, 180)
        return generate_motion_blur_kernel(kernel_size, kernel_angle)

    if blur_type == 'defocus':
        kernel_size = np.random.randint(15, 35)
        radius = np.random.randint(3, kernel_size // 2)
        return generate_defocus_kernel(kernel_size, radius)

    # gaussian: np.random.choice yields a NumPy integer; pass OpenCV a plain int
    # so the call does not depend on how the bindings treat NumPy scalars.
    kernel_size = int(np.random.choice([15, 19, 21, 25]))
    sigma = np.random.uniform(3, 8)
    column = cv2.getGaussianKernel(kernel_size, sigma)
    # Outer product of the 1-D kernel with itself gives the 2-D kernel.
    return np.dot(column, column.T)


def apply_blur_to_directory(source_dir, target_dir):
    """Blur every ``.png`` in ``source_dir`` with a random kernel and save it to ``target_dir``.

    Each output keeps the file name of its source image. ``target_dir`` is
    created when missing. Files that OpenCV cannot decode are reported and
    skipped instead of aborting the whole run.
    """
    os.makedirs(target_dir, exist_ok=True)

    image_files = sorted(f for f in os.listdir(source_dir) if f.endswith('.png'))

    print(f"Generating blurry images for {len(image_files)} sharp images...")

    # Apply random blur to each image
    for filename in tqdm(image_files):
        sharp_path = os.path.join(source_dir, filename)
        image = cv2.imread(sharp_path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"Warning: could not read {sharp_path}; skipping.")
            continue

        kernel = _sample_blur_kernel()

        # ddepth=-1 keeps the output depth equal to the input depth.
        blurry_image = cv2.filter2D(image, -1, kernel)
        blur_path = os.path.join(target_dir, filename)
        cv2.imwrite(blur_path, blurry_image)

    print("Blurry images generated successfully.")

# Generate blurry images for both train and test sets if they don't already exist
for split in ('train', 'test'):
    sharp_folder = os.path.join(ORIGINAL_DATA_DIR, split, 'sharp')
    blur_folder = os.path.join(ORIGINAL_DATA_DIR, split, 'blur')

    # Only regenerate when the blur folder is missing or empty.
    already_generated = os.path.exists(blur_folder) and len(os.listdir(blur_folder)) != 0
    if already_generated:
        print(f"Blurry images for '{split}' set already exist. Skipping generation.")
        continue

    print(f"Generating blurry images for '{split}' set...")
    apply_blur_to_directory(sharp_folder, blur_folder)

## Random Image Cropping

In [None]:
def create_crops(source_base, target_base, crop_size=(224, 224), crops_per_image=10):
    """Cut aligned random crops out of every sharp/blurry image pair.

    For the ``train`` and ``test`` splits under ``source_base``, each sharp
    image and the blurry image with the same file name are cropped at the same
    random positions. The crops are written to
    ``target_base/<split>/{sharp,blur}`` as ``<name>_crop_<i><ext>``.

    Args:
        source_base: Folder containing ``<split>/sharp`` and ``<split>/blur``.
        target_base: Output folder; nothing is done when it already exists
            and is not empty.
        crop_size: ``(height, width)`` of each crop in pixels.
        crops_per_image: Number of crops taken from each image pair.

    Images smaller than the crop size are skipped, as are sharp images whose
    blurry counterpart is missing (a warning is printed for the latter).
    """

    # Skips cropping if target directory already exists and is not empty
    if os.path.exists(target_base) and len(os.listdir(target_base)) > 0:
        print("Cropped dataset directory already exists and is not empty. Skipping cropping.")
        return

    print("Starting dataset pre-processing with random crops...")

    crop_h, crop_w = crop_size

    for split in ['train', 'test']:
        sharp_source_dir = os.path.join(source_base, split, 'sharp')
        blur_source_dir = os.path.join(source_base, split, 'blur')

        if not os.path.exists(sharp_source_dir):
            continue

        sharp_target_dir = os.path.join(target_base, split, 'sharp')
        blur_target_dir = os.path.join(target_base, split, 'blur')
        os.makedirs(sharp_target_dir, exist_ok=True)
        os.makedirs(blur_target_dir, exist_ok=True)

        print(f"Processing images in: {sharp_source_dir}")
        # Sorted so the processing order does not depend on the file system.
        for filename in tqdm(sorted(os.listdir(sharp_source_dir))):
            blur_path = os.path.join(blur_source_dir, filename)
            if not os.path.isfile(blur_path):
                print(f"Warning: no blurry counterpart for {filename}; skipping.")
                continue

            # convert() returns an independent image, so the files can be closed right away.
            with Image.open(os.path.join(sharp_source_dir, filename)) as sharp_file, \
                    Image.open(blur_path) as blur_file:
                sharp_img = sharp_file.convert("RGB")
                blur_img = blur_file.convert("RGB")

            img_w, img_h = sharp_img.size
            if img_w < crop_w or img_h < crop_h:
                continue

            base_name, ext = os.path.splitext(filename)
            for i in range(crops_per_image):
                top = random.randint(0, img_h - crop_h)
                left = random.randint(0, img_w - crop_w)
                # PIL boxes are (left, upper, right, lower); the same box keeps the pair aligned.
                box = (left, top, left + crop_w, top + crop_h)

                new_filename = f"{base_name}_crop_{i}{ext}"
                sharp_img.crop(box).save(os.path.join(sharp_target_dir, new_filename))
                blur_img.crop(box).save(os.path.join(blur_target_dir, new_filename))
    print(f"Cropped images saved to: {target_base}")

create_crops(ORIGINAL_DATA_DIR, CROPPED_DATA_DIR)

## Create Validation Set

In [None]:
# Randomly split 10% of the training data into a validation set
def create_validation_split(base_dir, ratio=0.1):
    """Move a random share of the training crops into ``<base_dir>/validation``.

    Crops are grouped by the image they were cut from (the part of the file
    name before ``_crop_``) and whole groups are moved, so no source image
    contributes crops to both the training and the validation set. Splitting
    individual crops would let overlapping crops of one image land on both
    sides and make the validation loss optimistic. Files without ``_crop_``
    in their name each form their own group.

    Args:
        base_dir: Folder containing ``train/sharp`` and ``train/blur``.
        ratio: Fraction of source images (groups) moved to validation.

    Nothing is done when the validation directory already exists.
    """

    train_dir = os.path.join(base_dir, "train")
    val_dir = os.path.join(base_dir, "validation")

    if os.path.exists(val_dir):
        print("Validation directory already exists. Skipping creation.")
        return

    print("Creating validation split...")
    os.makedirs(os.path.join(val_dir, "sharp"), exist_ok=True)
    os.makedirs(os.path.join(val_dir, "blur"), exist_ok=True)

    # Group the crop files by their source image.
    sharp_train_dir = os.path.join(train_dir, "sharp")
    crops_by_source = {}
    for filename in sorted(os.listdir(sharp_train_dir)):
        source_id = filename.rsplit("_crop_", 1)[0]
        crops_by_source.setdefault(source_id, []).append(filename)

    # Pick whole source images at random for validation.
    source_ids = list(crops_by_source)
    random.shuffle(source_ids)
    num_val_sources = int(len(source_ids) * ratio)
    val_images = [
        filename
        for source_id in source_ids[:num_val_sources]
        for filename in crops_by_source[source_id]
    ]
    num_val_images = len(val_images)

    print(f"Moving {num_val_images} images from train to validation set...")
    for filename in tqdm(val_images):
        shutil.move(os.path.join(train_dir, "sharp", filename), os.path.join(val_dir, "sharp", filename))
        shutil.move(os.path.join(train_dir, "blur", filename), os.path.join(val_dir, "blur", filename))
    print("Validation set created successfully.")

create_validation_split(CROPPED_DATA_DIR)

## Data Loader

In [None]:
# Dataset Class
class DeblurDataset(Dataset):
    """Paired blurry/sharp image dataset read from ``<root_dir>/<split>/{blur,sharp}``.

    Pairs are matched by file name; the file list comes from the ``sharp``
    folder and is sorted. Each item is a ``(blur_image, sharp_image)`` tuple of
    tensors produced by ``TF.to_tensor``.

    Args:
        root_dir: Folder containing the split sub-folders.
        split: Name of the split sub-folder, e.g. ``"train"``.
        augment: Whether to apply the random horizontal flip. ``None`` (the
            default) enables it only for the ``"train"`` split, so validation
            and test items are returned unchanged and metrics are repeatable.
    """

    def __init__(self, root_dir, split="train", augment=None):
        self.split_dir = os.path.join(root_dir, split)
        self.sharp_dir = os.path.join(self.split_dir, 'sharp')
        self.blur_dir = os.path.join(self.split_dir, 'blur')
        self.image_files = sorted(os.listdir(self.sharp_dir))
        self.augment = (split == "train") if augment is None else bool(augment)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        sharp_path = os.path.join(self.sharp_dir, img_name)
        blur_path = os.path.join(self.blur_dir, img_name)

        sharp_image = Image.open(sharp_path).convert("RGB")
        blur_image = Image.open(blur_path).convert("RGB")

        # Data augmentation: flip both images together so the pair stays aligned
        if self.augment and random.random() > 0.5:
            sharp_image = TF.hflip(sharp_image)
            blur_image = TF.hflip(blur_image)

        # Convert to tensor
        sharp_image = TF.to_tensor(sharp_image)
        blur_image = TF.to_tensor(blur_image)

        return blur_image, sharp_image
    
# One dataset per split, all read from the cropped-image folder.
train_dataset = DeblurDataset(split='train', root_dir=CROPPED_DATA_DIR)
val_dataset = DeblurDataset(split='validation', root_dir=CROPPED_DATA_DIR)
test_dataset = DeblurDataset(split='test', root_dir=CROPPED_DATA_DIR)

# Loader settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Training dataset loaded with {len(train_dataset)} images.")
print(f"Validation dataset loaded with {len(val_dataset)} images.")
print(f"Test dataset loaded with {len(test_dataset)} images.")

# Model Design & Training

## Custom Convolutional Block

In [None]:
# The ECA module and ConvBlock with ECA for ViT and Swin Transformer
class ECA(nn.Module):
    """Channel attention: pooled channel statistics -> 1-D conv -> sigmoid gate.

    The input is average-pooled to one value per channel, a single-channel
    1-D convolution of width ``k_size`` is slid across the channel axis, and
    the sigmoid of the result rescales each channel of the input.
    ``channels`` is accepted for interface symmetry but is not used.
    """

    def __init__(self, channels, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # "Same"-style padding for odd kernel sizes keeps one weight per channel.
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, chans, _, _ = x.size()
        # (B, C, 1, 1) -> (B, 1, C): the channel axis becomes the conv's sequence axis.
        pooled = self.avg_pool(x).reshape(batch, 1, chans)
        weights = self.sigmoid(self.conv(pooled))
        # (B, 1, C) -> (B, C, 1, 1) so the gate broadcasts over H and W.
        gate = weights.reshape(batch, chans, 1, 1)
        return x * gate.expand_as(x)
    
class ConvBlock(nn.Module):
    """3x3 convolution, batch norm and ReLU, followed by ECA channel attention."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            # padding=1 keeps the spatial size; the bias is dropped ahead of batch norm.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)
        self.eca = ECA(out_channels)

    def forward(self, x):
        features = self.conv(x)
        return self.eca(features)

## ViT Architecture

In [None]:
# ViT Decoder with skip connections
class ViTDecoderWithSkips(nn.Module):
    """Convolutional decoder that turns ViT token sequences into an image.

    The final token sequence (class token dropped) is reshaped into a square
    feature map and upsampled by 2x four times. After each of the first three
    upsampling steps an intermediate encoder state, reshaped and resized the
    same way, is concatenated as a skip connection. The output goes through a
    sigmoid.

    Args:
        in_features: Token embedding width of the encoder. It sets the channel
            count of the reshaped feature maps and of every skip connection.
        num_classes: Number of output channels.
    """

    def __init__(self, in_features=768, num_classes=3):
        super().__init__()
        self.in_features = in_features

        # Upsampling
        self.upconv1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_features, 256, kernel_size=3, padding=1)
        )
        self.conv1 = ConvBlock(256 + in_features, 256)

        self.upconv2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1)
        )
        self.conv2 = ConvBlock(128 + in_features, 128)

        self.upconv3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1)
        )
        self.conv3 = ConvBlock(64 + in_features, 64)

        self.upconv4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1)
        )
        self.conv4 = ConvBlock(32, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def _tokens_to_map(self, tokens, side):
        """Drop the first (class) token and reshape (B, N, C) tokens to (B, C, side, side)."""
        patches = tokens[:, 1:, :]
        return patches.permute(0, 2, 1).contiguous().view(-1, self.in_features, side, side)

    # Forward method with skip connections
    def forward(self, x, skips):
        """Decode the final token sequence ``x`` using three skip sequences.

        ``skips`` must hold at least three token tensors shaped like ``x``;
        they are consumed in order, one per upsampling stage.
        """
        # The patch tokens are assumed to form a square grid.
        side = int((x.shape[1] - 1) ** 0.5)
        x = self._tokens_to_map(x, side)

        stages = (
            (self.upconv1, self.conv1, skips[0]),
            (self.upconv2, self.conv2, skips[1]),
            (self.upconv3, self.conv3, skips[2]),
        )
        for upconv, conv, skip_tokens in stages:
            x = upconv(x)
            skip = self._tokens_to_map(skip_tokens, side)
            # The skip keeps the token-grid resolution, so resize it to the current map.
            skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=True)
            x = conv(torch.cat([x, skip], dim=1))

        x = self.upconv4(x)
        x = self.conv4(x)

        x = self.final_conv(x)
        return torch.sigmoid(x)

# Full ViT Model
class ViTForDeblurring(nn.Module):
    """Pretrained ViT encoder combined with ``ViTDecoderWithSkips``.

    Args:
        freeze_encoder: When True (default) all encoder weights are frozen,
            except for the layers re-enabled through ``num_unfrozen_layers``.
            When False the whole encoder stays trainable. NOTE(review): this
            flag used to be ignored, so ``freeze_encoder=False`` now trains
            more weights (and needs more memory) than it did before.
        num_unfrozen_layers: Number of trailing transformer layers that stay
            trainable while the encoder is frozen; the encoder's final layer
            norm is unfrozen together with them. Ignored when
            ``freeze_encoder`` is False.
    """

    def __init__(self, freeze_encoder=True, num_unfrozen_layers=4):
        super().__init__()
        self.encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k', output_hidden_states=True)
        self.decoder = ViTDecoderWithSkips()

        if not freeze_encoder:
            print("Fine-tuning ViT: all encoder layers are trainable.")
            return

        for param in self.encoder.parameters():
            param.requires_grad = False

        if num_unfrozen_layers > 0:
            print(f"Fine-tuning ViT: Unfreezing the last {num_unfrozen_layers} transformer layers.")
            for layer in self.encoder.encoder.layer[-num_unfrozen_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True

            for param in self.encoder.layernorm.parameters():
                param.requires_grad = True

    # Forward method
    def forward(self, x):
        """Encode ``x`` and decode it from the last hidden state plus three skips."""
        encoder_output = self.encoder(x)
        all_hidden_states = encoder_output.hidden_states
        # Hidden states 3, 6 and 9 serve as skip connections.
        skips = [all_hidden_states[index] for index in (3, 6, 9)]
        final_state = all_hidden_states[-1]
        return self.decoder(final_state, skips)

## Swin Transformer Architecture

In [None]:
# Swin Transformer Model with a U-Net style decoder
class SwinUnetDecoder(nn.Module):
    """U-Net style decoder over four encoder feature maps.

    The deepest map is upsampled stage by stage; after each of the first three
    stages the encoder map of matching resolution is concatenated and fused
    with a ``ConvBlock``. A fourth stage and a final 2x bilinear resize follow
    before the 1x1 output convolution. No output activation is applied.
    """

    @staticmethod
    def _up_block(in_channels, out_channels):
        # 2x bilinear upsampling followed by a 3x3 convolution.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )

    def __init__(self, num_classes=3):
        super().__init__()

        # Upsampling path
        self.upconv1 = self._up_block(768, 384)
        self.upconv2 = self._up_block(384, 192)
        self.upconv3 = self._up_block(192, 96)
        self.upconv4 = self._up_block(96, 64)

        # Fusion blocks; the first three take upsampled features plus a skip.
        self.conv1 = ConvBlock(in_channels=384 * 2, out_channels=384)
        self.conv2 = ConvBlock(in_channels=192 * 2, out_channels=192)
        self.conv3 = ConvBlock(in_channels=96 * 2, out_channels=96)
        self.conv4 = ConvBlock(in_channels=64, out_channels=32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    # Forward method
    def forward(self, encoder_features):
        shallow, middle, deep, bottleneck = encoder_features

        x = self.conv1(torch.cat([self.upconv1(bottleneck), deep], dim=1))
        x = self.conv2(torch.cat([self.upconv2(x), middle], dim=1))
        x = self.conv3(torch.cat([self.upconv3(x), shallow], dim=1))
        x = self.conv4(self.upconv4(x))

        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.final_conv(x)

# Full Swin Model
class SwinForDeblurring(nn.Module):
    """Pretrained Swin encoder with a ``SwinUnetDecoder`` that predicts a residual.

    The decoder output is added to the input image and the sum is clamped to
    the range [0, 1]. All encoder weights are left trainable.
    """

    def __init__(self):
        super().__init__()

        checkpoint = 'microsoft/swin-tiny-patch4-window7-224'
        config = SwinConfig.from_pretrained(checkpoint, output_hidden_states=True)
        self.encoder = SwinModel.from_pretrained(checkpoint, config=config)
        self.decoder = SwinUnetDecoder()
        for param in self.encoder.parameters():
            param.requires_grad_(True)

    @staticmethod
    def _to_feature_map(tokens):
        # (B, N, C) token sequence -> (B, C, H, W) map, with H taken as floor(sqrt(N)).
        batch, num_tokens, channels = tokens.shape
        height = int(num_tokens ** 0.5)
        width = num_tokens // height
        return tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2)

    # Forward method
    def forward(self, x):
        hidden_states = self.encoder(x).hidden_states

        # Only the first four hidden states are handed to the decoder.
        encoder_features = [self._to_feature_map(hs) for hs in hidden_states[:4]]

        # Use residual learning
        correction = self.decoder(encoder_features)
        return torch.clamp(x + correction, 0, 1)

print("\n Swin-Unet model with improved upsampling defined successfully.")

## VGG Loss

In [None]:
# Perceptual Loss using VGG19
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: L1 distance between VGG19 feature maps of two images.

    Both images are normalised with the ImageNet mean and standard deviation
    and passed through the first 18 layers of a pretrained VGG19 whose weights
    are frozen.

    Args:
        device: Device for the VGG layers. Defaults to the module-level
            ``DEVICE`` when omitted, which keeps ``VGGPerceptualLoss()``
            working exactly as before.
    """

    def __init__(self, device=None):
        super(VGGPerceptualLoss, self).__init__()
        target_device = DEVICE if device is None else device

        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:18].eval()
        self.features = nn.Sequential(*vgg).to(target_device)

        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # The feature extractor is fixed; only the network being trained receives gradients.
        for param in self.features.parameters():
            param.requires_grad = False

        self.l1 = nn.L1Loss()

    # Forward method
    def forward(self, input_img, target_img):
        """Return the L1 distance between the VGG features of the two images."""
        input_norm = self.normalize(input_img)
        target_norm = self.normalize(target_img)

        input_features = self.features(input_norm)
        target_features = self.features(target_norm)

        return self.l1(input_features, target_features)

## Training Function

In [None]:
# Early Stopping Class with model checkpointing
class EarlyStopping:
    """Track the best validation loss, checkpoint on improvement, flag when to stop.

    Calling the instance with a loss and a model saves the model's state dict
    to ``best_model_path`` whenever the loss beats the best value so far by
    more than ``min_delta``. Otherwise a counter is incremented, and
    ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=10, min_delta=0, best_model_path='best_model.pth'):
        self.patience, self.min_delta = patience, min_delta
        self.best_model_path = best_model_path
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        improved = (self.best_loss - val_loss) > self.min_delta
        if improved:
            # New best: remember it, reset the patience counter and checkpoint.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.best_model_path)
            return

        self.counter += 1
        print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter < self.patience:
            return
        print("Early stopping!")
        self.early_stop = True

In [None]:
# The main training function
def train_model(model, model_name, train_loader, val_loader, optimizer, epochs, device, lambda_vgg=0.05, lambda_ssim=2.0):
    """Train ``model`` on an L1 + VGG-perceptual + SSIM objective and return it.

    Args:
        model: Network mapping a batch of blurry images to restored images.
        model_name: Label used in log messages; the best checkpoint is written
            to ``<model_name>.pth`` (the file name the evaluation cells load).
        train_loader: Loader yielding ``(blurry, sharp)`` training batches.
        val_loader: Loader yielding ``(blurry, sharp)`` validation batches.
        optimizer: Optimizer already bound to the model parameters.
        epochs: Maximum number of epochs.
        device: Device to train on, e.g. ``"cuda"`` or ``"cpu"``.
        lambda_vgg: Weight of the perceptual loss term.
        lambda_ssim: Weight of the ``1 - SSIM`` loss term.

    Returns:
        The model, with the best checkpoint of this run loaded when one was saved.

    The learning rate is halved when the validation L1 loss plateaus. Early
    stopping and checkpointing monitor the weighted total validation loss,
    i.e. the same combination that is minimised during training. Mixed
    precision is used only when ``device`` is a CUDA device.
    """
    # Loss functions and metrics
    criterion_l1 = nn.L1Loss()
    criterion_vgg = VGGPerceptualLoss().to(device)
    # Named ssim_metric so it does not shadow the module-level skimage `ssim` import.
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    model.to(device)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
    best_model_path = f'{model_name}.pth'
    early_stopper = EarlyStopping(patience=10, best_model_path=best_model_path)

    # Enable autocast / gradient scaling only on CUDA; both become no-ops elsewhere.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    scaler = GradScaler(device=device_type, enabled=use_amp)

    print(f"--- Starting Training for {model_name} (Best model will be saved to {best_model_path}) ---")
    for epoch in range(epochs):
        model.train()
        running_l1_loss, running_vgg_loss, running_ssim_loss = 0.0, 0.0, 0.0
        for blurry_imgs, sharp_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            blurry_imgs, sharp_imgs = blurry_imgs.to(device), sharp_imgs.to(device)
            optimizer.zero_grad()

            with autocast(device_type=device_type, enabled=use_amp):
                outputs = model(blurry_imgs)
                loss_l1 = criterion_l1(outputs, sharp_imgs)
                loss_vgg = criterion_vgg(outputs, sharp_imgs)
                loss_ssim = 1 - ssim_metric(outputs, sharp_imgs)
                total_loss = loss_l1 + (lambda_vgg * loss_vgg) + (lambda_ssim * loss_ssim)

            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_l1_loss += loss_l1.item()
            running_vgg_loss += loss_vgg.item()
            running_ssim_loss += loss_ssim.item()

        avg_train_l1 = running_l1_loss / len(train_loader)
        avg_train_vgg = running_vgg_loss / len(train_loader)
        # Report the SSIM score (not the loss) so it is comparable with "Val SSIM".
        avg_train_ssim_score = 1 - running_ssim_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss_l1 = 0.0
        val_loss_vgg = 0.0
        val_loss_ssim = 0.0
        with torch.no_grad():
            for blurry_imgs, sharp_imgs in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                blurry_imgs, sharp_imgs = blurry_imgs.to(device), sharp_imgs.to(device)
                outputs = model(blurry_imgs)
                val_loss_l1 += criterion_l1(outputs, sharp_imgs).item()
                val_loss_vgg += criterion_vgg(outputs, sharp_imgs).item()
                val_loss_ssim += (1 - ssim_metric(outputs, sharp_imgs)).item()

        avg_val_loss_l1 = val_loss_l1 / len(val_loader)
        avg_val_loss_vgg = val_loss_vgg / len(val_loader)
        avg_val_loss_ssim = val_loss_ssim / len(val_loader)
        avg_val_ssim_score = 1 - avg_val_loss_ssim
        # Same weighting as the training objective.
        avg_val_total = avg_val_loss_l1 + (lambda_vgg * avg_val_loss_vgg) + (lambda_ssim * avg_val_loss_ssim)

        current_learn_rate = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{epochs}] complete. Train L1: {avg_train_l1:.4f}, Train VGG: {avg_train_vgg:.4f}, "
            f"Train SSIM: {avg_train_ssim_score:.4f}, Val L1: {avg_val_loss_l1:.4f}, Val VGG: {avg_val_loss_vgg:.4f}, "
            f"Val SSIM: {avg_val_ssim_score:.4f}, Val Total: {avg_val_total:.4f}, LR: {current_learn_rate:.6f}")

        scheduler.step(avg_val_loss_l1)
        early_stopper(avg_val_total, model)
        if early_stopper.early_stop:
            print("Early stopping triggered. Exiting training loop.")
            break

    print(f"--- Finished Training for {model_name} ---")
    # best_loss stays infinite when no checkpoint was written during this run
    # (e.g. epochs == 0), in which case there is nothing valid to reload.
    if np.isfinite(early_stopper.best_loss):
        print(f"Loading best model weights from {best_model_path} (Val Loss: {early_stopper.best_loss:.6f})")
        model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    else:
        print("No checkpoint was saved during this run; returning the model as trained.")
    return model

## Model Setup & Training

In [None]:
# U-Net
# U-Net: ResNet50 encoder with ImageNet weights and scSE attention in the decoder.
_unet_options = dict(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
    decoder_use_batchnorm=True,
    decoder_attention_type='scse',
)
unet_model = smp.Unet(**_unet_options)

optimizer_unet = optim.AdamW(unet_model.parameters(), lr=LEARNING_RATE)
UNET_MODEL_NAME = "unet_best"
print("U-Net with ResNet50 backbone created successfully.")

# ViT: separate parameter groups so encoder and decoder get different learning rates.
vit_model = ViTForDeblurring(freeze_encoder=False)
_vit_param_groups = [
    {'params': vit_model.encoder.parameters(), 'lr': 1e-6},  # Lower LR for encoder
    {'params': vit_model.decoder.parameters(), 'lr': 1e-4},  # Higher LR for decoder
]
optimizer_vit = optim.Adam(_vit_param_groups)
VIT_MODEL_NAME = "vit_final_best"
print("ViT-based model created successfully.")

# Swin Transformer: a single learning rate for the whole model.
swin_model = SwinForDeblurring()
optimizer_swin = optim.AdamW(swin_model.parameters(), lr=LEARNING_RATE)
SWIN_MODEL_NAME = "swin_final_best"
print("Swin Transformer model created successfully.")

### UNet Training

In [None]:
trained_unet = train_model(unet_model, UNET_MODEL_NAME, train_loader, val_loader, optimizer_unet, EPOCHS, DEVICE)

### ViT Training

In [None]:
trained_vit = train_model(vit_model, VIT_MODEL_NAME, train_loader, val_loader, optimizer_vit, EPOCHS, DEVICE)

### Swin Transformer Training

In [None]:
trained_swin = train_model(swin_model, SWIN_MODEL_NAME, train_loader, val_loader, optimizer_swin, EPOCHS, DEVICE)

# Evaluation

## Evaluation On Cropped Images

In [None]:
def evaluate_on_cropped(unet_arch, vit_arch, swin_arch, unet_path, vit_path, swin_path, test_dataset, device, num_images=10):
    """Compare the three models on random items of the cropped test dataset.

    Loads each checkpoint into its architecture, picks up to ``num_images``
    random dataset items, prints PSNR and SSIM per model and shows a five-panel
    figure (blurry input, the three outputs, ground truth) for every item.

    Args:
        unet_arch, vit_arch, swin_arch: Model instances to load the weights into.
        unet_path, vit_path, swin_path: Paths of the matching state-dict files.
        test_dataset: Dataset yielding ``(blurry, sharp)`` image tensors.
        device: Device used for loading the weights and for inference.
        num_images: Number of items to evaluate; capped at the dataset size.

    A missing checkpoint file is reported and makes the function return early.
    """
    networks = {
        "U-Net": (unet_arch, unet_path),
        "ViT": (vit_arch, vit_path),
        "Swin": (swin_arch, swin_path),
    }
    # Console prefixes, padded so the metric columns line up.
    console_labels = {"U-Net": "U-Net -->", "ViT": "ViT ---->", "Swin": "Swin --->"}

    try:
        for arch, path in networks.values():
            # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
            arch.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    except FileNotFoundError as e:
        print(f"Model weights not found: {e}. Skipping evaluation.")
        return
    for arch, _ in networks.values():
        arch.to(device).eval()

    # random.sample raises if asked for more items than exist, so cap the request.
    indices = random.sample(range(len(test_dataset)), min(num_images, len(test_dataset)))
    for i in indices:
        blurry_img, sharp_img = test_dataset[i]
        blurry_img_batch = blurry_img.unsqueeze(0).to(device)

        # Channel-last arrays for skimage and matplotlib.
        sharp_np = sharp_img.permute(1, 2, 0).numpy()
        blurry_np = blurry_img.permute(1, 2, 0).numpy()

        results = {}
        with torch.no_grad():
            for name, (arch, _) in networks.items():
                # squeeze(0) drops only the batch axis.
                restored = arch(blurry_img_batch).squeeze(0).cpu()
                restored_np = restored.permute(1, 2, 0).numpy().clip(0, 1)
                results[name] = (
                    restored_np,
                    psnr(sharp_np, restored_np, data_range=1.0),
                    ssim(sharp_np, restored_np, multichannel=True, data_range=1.0, channel_axis=-1),
                )

        print(f"--- Image Index {i} ---")
        for name, (_, psnr_value, ssim_value) in results.items():
            print(f"{console_labels[name]} PSNR: {psnr_value:.2f}, SSIM: {ssim_value:.4f}")

        # Plotting the results
        fig, axes = plt.subplots(1, 5, figsize=(25, 5))

        axes[0].imshow(blurry_np)
        axes[0].set_title("Blurry Input")

        for ax, (name, (restored_np, psnr_value, ssim_value)) in zip(axes[1:4], results.items()):
            ax.imshow(restored_np)
            ax.set_title(f"{name}\nPSNR: {psnr_value:.2f} | SSIM: {ssim_value:.4f}")

        axes[4].imshow(sharp_np)
        axes[4].set_title("Ground Truth")

        for ax in axes:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure so repeated evaluations do not accumulate open figures.
        plt.close(fig)

# Fresh model instances that receive the trained checkpoints. The U-Net encoder
# is built without ImageNet weights (encoder_weights=None): the checkpoint
# loaded inside evaluate_on_cropped overwrites every parameter, so fetching
# them would be wasted work.
unet_for_eval = smp.Unet(
    "resnet50",
    encoder_weights=None,
    in_channels=3,
    classes=3,
    decoder_use_batchnorm=True,
    decoder_attention_type='scse'
)
vit_for_eval = ViTForDeblurring(num_unfrozen_layers=4)
swin_for_eval = SwinForDeblurring()
evaluate_on_cropped(
    unet_for_eval, vit_for_eval, swin_for_eval,
    f"{UNET_MODEL_NAME}.pth", f"{VIT_MODEL_NAME}.pth", f"{SWIN_MODEL_NAME}.pth",
    test_dataset, DEVICE
)

## Evaluation On Full Image

In [None]:
def get_pad(h, w, patch_size, stride):
    """Return ``(pad_h, pad_w)``: the padding that aligns each extent to the patch grid.

    Each value is the smallest non-negative amount that makes
    ``extent + pad - patch_size`` a multiple of ``stride``.
    """
    def remainder_to_grid(extent):
        # Negating before the modulo yields the distance up to the next multiple.
        return -(extent - patch_size) % stride

    return remainder_to_grid(h), remainder_to_grid(w)

def predict_on_full_image(model, full_blurry_tensor, patch_size=224, overlap=32, device='cuda'):
    """Run ``model`` over a full image patch by patch and blend the results.

    The image tensor gets a batch axis, is reflect-padded on the bottom and
    right so the patch grid covers it, and is processed in
    ``patch_size`` x ``patch_size`` windows whose stride is
    ``patch_size - overlap``. Overlapping predictions are averaged, the
    padding is cropped away, and the result is returned on the CPU without
    the batch axis.
    """
    model.eval()
    batch = full_blurry_tensor.unsqueeze(0).to(device)
    _, _, height, width = batch.shape

    step = patch_size - overlap
    extra_h, extra_w = get_pad(height, width, patch_size, step)
    # F.pad order for the last two dims is (left, right, top, bottom).
    padded = F.pad(batch, (0, extra_w, 0, extra_h), 'reflect')
    _, _, padded_h, padded_w = padded.shape

    accumulated = torch.zeros_like(padded)  # sum of the patch predictions
    hits = torch.zeros_like(padded)         # number of predictions per pixel

    for top in range(0, padded_h - patch_size + 1, step):
        rows = slice(top, top + patch_size)
        for left in range(0, padded_w - patch_size + 1, step):
            cols = slice(left, left + patch_size)
            with torch.no_grad():
                restored = model(padded[:, :, rows, cols])
            accumulated[:, :, rows, cols] += restored
            hits[:, :, rows, cols] += 1

    averaged = accumulated / hits
    return averaged[:, :, :height, :width].squeeze(0).cpu()

def evaluate_on_full_images(unet_arch, vit_arch, swin_arch, unet_path, vit_path, swin_path, original_test_dir, device, num_images=100):
    """ Evaluate the three models on full-size test images and plot the results.

    Loads each checkpoint into its architecture, randomly samples up to
    `num_images` file names from `<original_test_dir>/test/sharp`, reads the
    same-named blurry image from `<original_test_dir>/test/blur`, deblurs it
    with every model via `predict_on_full_image`, prints PSNR/SSIM per model
    and shows a 2x3 comparison figure.

    Args:
        unet_arch, vit_arch, swin_arch: Model instances to load weights into.
        unet_path, vit_path, swin_path: Paths to the saved state dicts.
        original_test_dir: Directory containing `test/sharp` and `test/blur`.
        device: Device used for loading the weights and for inference.
        num_images: Maximum number of images to sample.

    Returns:
        None. If a checkpoint or the sharp-image directory is missing, a
        message is printed and the function returns early.
    """
    archs = (unet_arch, vit_arch, swin_arch)
    weight_paths = (unet_path, vit_path, swin_path)
    try:
        for arch, weights_path in zip(archs, weight_paths):
            # map_location lets checkpoints saved on a GPU load on a
            # CPU-only machine instead of failing during deserialization.
            state_dict = torch.load(weights_path, weights_only=True, map_location=device)
            arch.load_state_dict(state_dict)
        for arch in archs:
            arch.to(device).eval()
    except FileNotFoundError as e:
        print(f"Model weights not found: {e}. Skipping evaluation.")
        return

    sharp_dir = os.path.join(original_test_dir, 'test', 'sharp')
    blur_dir = os.path.join(original_test_dir, 'test', 'blur')

    try:
        image_names = os.listdir(sharp_dir)
        selected_images = random.sample(image_names, min(num_images, len(image_names)))
    except FileNotFoundError:
        print(f"Test images not found in '{sharp_dir}'. Skipping.")
        return

    # (figure title, console prefix, model) in display order.
    named_models = (
        ("U-Net", "U-Net -->", unet_arch),
        ("ViT", "ViT ---->", vit_arch),
        ("Swin", "Swin --->", swin_arch),
    )

    for image_name in selected_images:
        print(f"--- Processing Full Image: {image_name} ---")
        blurry_tensor = TF.to_tensor(Image.open(os.path.join(blur_dir, image_name)).convert("RGB"))
        sharp_tensor = TF.to_tensor(Image.open(os.path.join(sharp_dir, image_name)).convert("RGB"))

        # Channels-last arrays for skimage metrics and matplotlib.
        sharp_np = sharp_tensor.permute(1, 2, 0).numpy()
        blurry_np = blurry_tensor.permute(1, 2, 0).numpy()

        results = []
        for title, prefix, arch in named_models:
            restored = predict_on_full_image(arch, blurry_tensor, device=device)
            restored_np = restored.permute(1, 2, 0).numpy().clip(0, 1)
            psnr_value = psnr(sharp_np, restored_np, data_range=1.0)
            # NOTE(review): both `multichannel` and `channel_axis` are passed —
            # presumably so the call works across scikit-image versions;
            # confirm before dropping either.
            ssim_value = ssim(sharp_np, restored_np, multichannel=True, data_range=1.0, channel_axis=-1)
            results.append((title, prefix, restored_np, psnr_value, ssim_value))

        for _, prefix, _, psnr_value, ssim_value in results:
            print(f"{prefix} PSNR: {psnr_value:.2f}, SSIM: {ssim_value:.4f}")

        fig, axes = plt.subplots(2, 3, figsize=(24, 16))

        # --- Row 1: Input and Ground Truth (third cell stays empty) ---
        axes[0, 0].imshow(blurry_np)
        axes[0, 0].set_title("Blurry Input")

        axes[0, 1].imshow(sharp_np)
        axes[0, 1].set_title("Ground Truth")

        # --- Row 2: Model Outputs ---
        for ax, (title, _, restored_np, psnr_value, ssim_value) in zip(axes[1], results):
            ax.imshow(restored_np)
            ax.set_title(f"{title}\nPSNR: {psnr_value:.2f} | SSIM: {ssim_value:.4f}")

        # Hide axis lines/ticks everywhere; this also blanks the unused cell.
        for ax in axes.flat:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Architectures only; evaluate_on_full_images loads the saved weights into
# them from the given .pth files before running inference.
unet_for_eval = smp.Unet(
    "resnet50",
    in_channels=3,
    classes=3,
    decoder_use_batchnorm=True,
    decoder_attention_type='scse',
)
vit_for_eval = ViTForDeblurring(num_unfrozen_layers=4)
swin_for_eval = SwinForDeblurring()

evaluate_on_full_images(
    unet_arch=unet_for_eval,
    vit_arch=vit_for_eval,
    swin_arch=swin_for_eval,
    unet_path=f"{UNET_MODEL_NAME}.pth",
    vit_path=f"{VIT_MODEL_NAME}.pth",
    swin_path=f"{SWIN_MODEL_NAME}.pth",
    original_test_dir=ORIGINAL_DATA_DIR,
    device=DEVICE,
)

## Compare With Average Scores

In [None]:
def calculate_average_metrics(unet_arch, vit_arch, swin_arch, unet_path, vit_path, swin_path, test_loader, device):
    """ Compute and print the average PSNR and SSIM of the three models over a test loader.

    Loads each checkpoint into its architecture, runs every batch of
    `test_loader` through all three models, scores each output image against
    its sharp target on the CPU, and prints the per-model means.

    Args:
        unet_arch, vit_arch, swin_arch: Model instances to load weights into.
        unet_path, vit_path, swin_path: Paths to the saved state dicts.
        test_loader: Iterable yielding `(blurry_imgs, sharp_imgs)` batches;
            each image is indexed as (C, H, W) within the batch.
        device: Device used for loading the weights and for inference.

    Returns:
        None. If a checkpoint is missing, or the loader yields no samples,
        a message is printed and the function returns early.
    """
    archs = (unet_arch, vit_arch, swin_arch)
    weight_paths = (unet_path, vit_path, swin_path)
    try:
        for arch, weights_path in zip(archs, weight_paths):
            # map_location lets checkpoints saved on a GPU load on a
            # CPU-only machine instead of failing during deserialization.
            state_dict = torch.load(weights_path, weights_only=True, map_location=device)
            arch.load_state_dict(state_dict)
        for arch in archs:
            arch.to(device).eval()
    except FileNotFoundError as e:
        print(f"Model weights not found: {e}. Skipping final evaluation.")
        return

    # One score list per model, kept in U-Net / ViT / Swin order.
    psnr_scores = [[] for _ in archs]
    ssim_scores = [[] for _ in archs]

    print("Calculating average metrics over the entire test set...")
    with torch.no_grad():
        for blurry_imgs, sharp_imgs in tqdm(test_loader, desc="Evaluating Test Set"):
            blurry_imgs = blurry_imgs.to(device)

            # Move every prediction to the CPU right away; metrics are computed there.
            batch_outputs = [arch(blurry_imgs).cpu() for arch in archs]

            # NOTE(review): presumably here to limit GPU memory pressure
            # between batches — confirm it is still needed.
            torch.cuda.empty_cache()

            # Score each image of the batch on the CPU.
            for idx, sharp_img in enumerate(sharp_imgs):
                sharp_np = sharp_img.permute(1, 2, 0).numpy()
                for outputs, psnrs, ssims in zip(batch_outputs, psnr_scores, ssim_scores):
                    restored_np = outputs[idx].permute(1, 2, 0).numpy().clip(0, 1)
                    psnrs.append(psnr(sharp_np, restored_np, data_range=1.0))
                    # NOTE(review): both `multichannel` and `channel_axis` are
                    # passed — presumably so the call works across
                    # scikit-image versions; confirm before dropping either.
                    ssims.append(ssim(sharp_np, restored_np, multichannel=True, data_range=1.0, channel_axis=-1))

    # Averaging empty lists would only yield NaN plus a NumPy warning.
    if not psnr_scores[0]:
        print("No test samples were evaluated. Skipping average metrics.")
        return

    print("\n--- Final Average Metrics ---")
    for prefix, psnrs, ssims in zip(("U-Net -->", "ViT ---->", "Swin --->"), psnr_scores, ssim_scores):
        print(f"{prefix} Average PSNR: {np.mean(psnrs):.2f}, Average SSIM: {np.mean(ssims):.4f}")


# Deterministic pass over the test set: no shuffling, loading in the main process.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

# Architectures only; calculate_average_metrics loads the saved weights into
# them from the given .pth files.
unet_for_eval = smp.Unet(
    "resnet50",
    in_channels=3,
    classes=3,
    decoder_use_batchnorm=True,
    decoder_attention_type='scse',
)
vit_for_eval = ViTForDeblurring(num_unfrozen_layers=4)
swin_for_eval = SwinForDeblurring()

calculate_average_metrics(
    unet_arch=unet_for_eval,
    vit_arch=vit_for_eval,
    swin_arch=swin_for_eval,
    unet_path=f"{UNET_MODEL_NAME}.pth",
    vit_path=f"{VIT_MODEL_NAME}.pth",
    swin_path=f"{SWIN_MODEL_NAME}.pth",
    test_loader=test_loader,
    device=DEVICE,
)