In [None]:
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# %pip install transformers timm einops opencv-python matplotlib scikit-image --upgrade
# %pip install segmentation-models-pytorch --upgrade

In [1]:
import os
import random
import shutil
import zipfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel
from PIL import Image
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import segmentation_models_pytorch as smp
from tqdm import tqdm
import cv2
import torchvision.models as models
from torchvision.transforms import Normalize
import requests

In [2]:
ORIGINAL_DATA_DIR = "./DIV2K_train_HR"
CROPPED_DATA_DIR = "./cropped_dataset"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 100
LEARNING_RATE = 1e-4

# Seed every RNG this notebook draws from so that Restart & Run All reproduces
# the same data and training run: stdlib `random` (crop windows, validation
# split, flips), NumPy (blur type / kernel sampling) and torch (weight init,
# DataLoader shuffling).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
def download_and_unzip_div2k(
    dataset_url="http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
    download_dir="./",
    dataset_name="DIV2K_train_HR",
    chunk_size=1024 * 1024,
    timeout=60,
):
    """
    Download and unzip the DIV2K high-resolution training dataset.

    Does nothing if ``<download_dir>/<dataset_name>`` already exists. Problems are
    reported with ``print`` and the function returns early instead of raising.
    Any partial or corrupt zip is deleted so a later run starts cleanly.

    Args:
        dataset_url: URL of the zip archive.
        download_dir: Directory the archive is downloaded to and extracted into.
        dataset_name: Name of the extracted top-level folder (also the zip stem).
        chunk_size: Bytes per streamed chunk. Larger chunks cut per-chunk
            Python overhead compared with the previous 1 KiB.
        timeout: Seconds ``requests`` waits to connect / between received bytes.
            Without it, a stalled server blocks the notebook forever.
    """
    zip_path = os.path.join(download_dir, f"{dataset_name}.zip")

    # --- Check if dataset already exists ---
    if os.path.exists(os.path.join(download_dir, dataset_name)):
        print(f"Dataset '{dataset_name}' already exists. Skipping download.")
        return

    # --- 1. Download the file with a progress bar ---
    print(f"Downloading {dataset_name}.zip... (This may take a while)")
    try:
        # `with` releases the streamed connection and closes the bar even on error.
        with requests.get(dataset_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for bad status codes
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with open(zip_path, 'wb') as file, \
                    tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size):
                    progress_bar.update(len(data))
                    file.write(data)
                downloaded = progress_bar.n
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        if os.path.exists(zip_path):
            os.remove(zip_path)
        return

    if total_size_in_bytes != 0 and downloaded != total_size_in_bytes:
        print("ERROR, something went wrong during download.")
        if os.path.exists(zip_path):
            os.remove(zip_path)
        return

    # --- 2. Unzip the file ---
    print(f"\nUnzipping {zip_path}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(download_dir)
        print("Unzipping complete.")
    except zipfile.BadZipFile:
        print("Error: The downloaded file is not a valid zip file or is corrupted.")
        os.remove(zip_path)
        return

    # --- 3. Clean up the zip file ---
    os.remove(zip_path)
    print(f"Cleaned up {zip_path}.")
    print("--- Dataset setup is complete! ---")

# --- Run the setup function ---
download_and_unzip_div2k()

Dataset 'DIV2K_train_HR' already exists. Skipping download.


In [4]:
# --- Define the new directories ---
TRAIN_SHARP_DIR = os.path.join(ORIGINAL_DATA_DIR, "train", "sharp")
TEST_SHARP_DIR = os.path.join(ORIGINAL_DATA_DIR, "test", "sharp")

def organize_images(num_train=700):
    """
    Move the flat ``.png`` images of ORIGINAL_DATA_DIR into train/sharp and test/sharp.

    Files are sorted by name. The first ``num_train`` go to TRAIN_SHARP_DIR and
    the rest to TEST_SHARP_DIR. Nothing happens if either target directory
    already exists.

    Args:
        num_train: Number of (sorted) images assigned to the training split.
    """
    # Check if the script has already been run
    if os.path.exists(TRAIN_SHARP_DIR) or os.path.exists(TEST_SHARP_DIR):
        print("Train/Test directories seem to already exist. Skipping organization.")
        return

    print("Organizing sharp images into train/test splits...")

    # Find the source images BEFORE creating any directory. Creating the
    # train/test folders first meant that a wrong or empty ORIGINAL_DATA_DIR left
    # empty folders behind. The existence check above then skipped organization
    # on every later run.
    if os.path.isdir(ORIGINAL_DATA_DIR):
        image_files = sorted(f for f in os.listdir(ORIGINAL_DATA_DIR) if f.endswith('.png'))
    else:
        image_files = []

    if not image_files:
        print(f"Error: No .png images found in {ORIGINAL_DATA_DIR}. Please check the path.")
        return

    os.makedirs(TRAIN_SHARP_DIR, exist_ok=True)
    os.makedirs(TEST_SHARP_DIR, exist_ok=True)

    splits = (
        (TRAIN_SHARP_DIR, image_files[:num_train]),
        (TEST_SHARP_DIR, image_files[num_train:]),
    )
    for target_dir, files in splits:
        print(f"Moving {len(files)} images to {target_dir}...")
        for filename in tqdm(files):
            shutil.move(os.path.join(ORIGINAL_DATA_DIR, filename), os.path.join(target_dir, filename))

    print("Image organization complete!")

# --- Run the organization script ---
organize_images()

Train/Test directories seem to already exist. Skipping organization.


In [5]:
def generate_motion_blur_kernel(size, angle):
    """
    Generate a normalized linear motion-blur kernel.

    A one-pixel-thick line spanning the whole kernel is drawn THROUGH the centre
    at ``angle`` degrees. Positive angles point "up" in image coordinates, because
    the y offset is subtracted. The kernel is then divided by its sum.

    The line extends symmetrically on both sides of the centre. The previous
    version drew only from the centre to one edge, which puts all of the kernel
    mass on one side of the anchor ``cv2.filter2D`` uses by default (the kernel
    centre). That shifts each blurred image relative to its sharp target and
    misaligns the training pairs.

    Args:
        size: Kernel width and height in pixels.
        angle: Line orientation in degrees.

    Returns:
        A ``(size, size)`` float array summing to 1.
    """
    kernel = np.zeros((size, size))
    center = (size - 1) / 2
    radian_angle = np.deg2rad(angle)
    dx = center * np.cos(radian_angle)
    dy = center * np.sin(radian_angle)
    start = (int(round(center - dx)), int(round(center + dy)))
    end = (int(round(center + dx)), int(round(center - dy)))
    cv2.line(kernel, start, end, 1, thickness=1)
    return kernel / kernel.sum()

def generate_defocus_kernel(size, radius):
    """
    Build a normalized defocus (filled-disk) blur kernel.

    Args:
        size: Side length of the square kernel.
        radius: Disk radius in pixels, centred on pixel (size // 2, size // 2).

    Returns:
        A (size, size) float array whose entries sum to 1.
    """
    mid = size // 2
    disk = np.zeros((size, size))
    # thickness=-1 makes OpenCV fill the circle rather than draw its outline.
    cv2.circle(disk, (mid, mid), radius, 1, thickness=-1)
    total = disk.sum()
    return disk / total

def apply_blur_to_directory(source_dir, target_dir):
    """
    Write a randomly blurred copy of every ``.png`` in ``source_dir`` to ``target_dir``.

    For each image, a blur type is drawn from motion / gaussian / defocus and its
    kernel parameters are sampled with ``np.random``. Seed NumPy for reproducible
    data. The output keeps the source filename so sharp/blur pairs match by name.

    Files OpenCV cannot decode are skipped with a warning. Previously they
    crashed inside ``cv2.filter2D``.

    Args:
        source_dir: Folder of sharp ``.png`` images.
        target_dir: Output folder; created if missing.
    """
    os.makedirs(target_dir, exist_ok=True)

    image_files = sorted(f for f in os.listdir(source_dir) if f.endswith('.png'))

    print(f"Generating blurry images for {len(image_files)} sharp images...")

    for filename in tqdm(image_files):
        sharp_path = os.path.join(source_dir, filename)
        image = cv2.imread(sharp_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            print(f"Warning: could not read '{sharp_path}'. Skipping.")
            continue

        # Randomly choose a blur type
        blur_type = np.random.choice(['motion', 'gaussian', 'defocus'])

        if blur_type == 'motion':
            kernel_size = np.random.randint(15, 35)
            kernel_angle = np.random.randint(0, 180)
            kernel = generate_motion_blur_kernel(int(kernel_size), kernel_angle)
        elif blur_type == 'defocus':
            kernel_size = np.random.randint(15, 35)
            radius = np.random.randint(3, kernel_size // 2)
            # OpenCV drawing/kernel APIs expect plain Python ints.
            kernel = generate_defocus_kernel(int(kernel_size), int(radius))
        else:  # gaussian
            kernel_size = int(np.random.choice([15, 19, 21, 25]))
            sigma = np.random.uniform(3, 8)
            kernel_1d = cv2.getGaussianKernel(kernel_size, sigma)
            kernel = np.dot(kernel_1d, kernel_1d.T)  # separable -> 2-D kernel

        blurry_image = cv2.filter2D(image, -1, kernel)
        cv2.imwrite(os.path.join(target_dir, filename), blurry_image)

    print("Blurry images generated successfully.")


def _png_names(folder):
    """Return the set of ``.png`` filenames in ``folder`` (empty if the folder is missing)."""
    if not os.path.isdir(folder):
        return set()
    return {f for f in os.listdir(folder) if f.endswith('.png')}


for split in ['train', 'test']:
    sharp_folder = os.path.join(ORIGINAL_DATA_DIR, split, 'sharp')
    blur_folder = os.path.join(ORIGINAL_DATA_DIR, split, 'blur')
    blur_names = _png_names(blur_folder)
    # Regenerate unless every sharp PNG already has a blurred counterpart. A bare
    # "folder is non-empty" check would accept a half-finished run forever.
    if not blur_names or not _png_names(sharp_folder) <= blur_names:
        print(f"Generating blurry images for '{split}' set...")
        apply_blur_to_directory(sharp_folder, blur_folder)
    else:
        print(f"Blurry images for '{split}' set already exist. Skipping generation.")

Blurry images for 'train' set already exist. Skipping generation.
Blurry images for 'test' set already exist. Skipping generation.


In [6]:
def create_crops(source_base, target_base, crop_size=(224, 224), crops_per_image=10):
    """
    Cut aligned random crops from every sharp/blur image pair.

    For each split (``train``, ``test``) under ``source_base``, every sharp image
    and the blurred image with the same filename are cropped at the same random
    window. Crops are saved as ``<stem>_crop_<i><ext>`` under
    ``target_base/<split>/{sharp,blur}``.

    Args:
        source_base: Directory containing ``<split>/sharp`` and ``<split>/blur``.
        target_base: Output root. The whole step is skipped if it already
            exists and is non-empty.
        crop_size: ``(height, width)`` of each crop.
        crops_per_image: Random crops per image. Images smaller than
            ``crop_size``, or whose blurred counterpart is missing or has a
            different size, are skipped.
    """
    if os.path.exists(target_base) and len(os.listdir(target_base)) > 0:
        print("Cropped dataset directory already exists and is not empty. Skipping cropping.")
        return

    crop_h, crop_w = crop_size
    print("Starting dataset pre-processing with random crops...")
    for split in ['train', 'test']:
        sharp_source_dir = os.path.join(source_base, split, 'sharp')
        blur_source_dir = os.path.join(source_base, split, 'blur')

        if not os.path.exists(sharp_source_dir):
            continue

        sharp_target_dir = os.path.join(target_base, split, 'sharp')
        blur_target_dir = os.path.join(target_base, split, 'blur')
        os.makedirs(sharp_target_dir, exist_ok=True)
        os.makedirs(blur_target_dir, exist_ok=True)

        print(f"Processing images in: {sharp_source_dir}")
        # Sorted: os.listdir order is filesystem-dependent, and the crop windows
        # come from the global `random` stream. A fixed file order is needed for
        # a seeded run to reproduce the same crops.
        for filename in tqdm(sorted(os.listdir(sharp_source_dir))):
            blur_path = os.path.join(blur_source_dir, filename)
            if not os.path.exists(blur_path):
                print(f"Warning: no blurred counterpart for '{filename}'. Skipping.")
                continue

            sharp_img = Image.open(os.path.join(sharp_source_dir, filename)).convert("RGB")
            blur_img = Image.open(blur_path).convert("RGB")
            if blur_img.size != sharp_img.size:
                print(f"Warning: size mismatch for '{filename}'. Skipping.")
                continue

            img_w, img_h = sharp_img.size
            if img_w < crop_w or img_h < crop_h:
                continue

            base_name, ext = os.path.splitext(filename)
            for i in range(crops_per_image):
                top = random.randint(0, img_h - crop_h)
                left = random.randint(0, img_w - crop_w)
                box = (left, top, left + crop_w, top + crop_h)
                new_filename = f"{base_name}_crop_{i}{ext}"
                sharp_img.crop(box).save(os.path.join(sharp_target_dir, new_filename))
                blur_img.crop(box).save(os.path.join(blur_target_dir, new_filename))
    print(f"Cropped images saved to: {target_base}")

# --- Run Cropping ---
create_crops(ORIGINAL_DATA_DIR, CROPPED_DATA_DIR)

Cropped dataset directory already exists and is not empty. Skipping cropping.


In [7]:
def create_validation_split(base_dir, ratio=0.1):
    """
    Move a random subset of training crops to ``<base_dir>/validation``.

    Crops are grouped by their source image, i.e. the filename part before
    ``_crop_`` as written by ``create_crops``. Whole groups are moved, so no
    source photo contributes crops to both train and validation. Splitting
    individual crops instead leaks overlapping windows of the same photo into
    validation and makes the validation loss optimistic. Files without
    ``_crop_`` each form their own group.

    Args:
        base_dir: Cropped dataset root with ``train/sharp`` and ``train/blur``.
        ratio: Fraction of source images moved to validation. With a fixed
            number of crops per image, this is also the fraction of crops.
    """
    train_dir = os.path.join(base_dir, "train")
    val_dir = os.path.join(base_dir, "validation")

    if os.path.exists(val_dir):
        print("Validation directory already exists. Skipping creation.")
        return

    print("Creating validation split...")
    os.makedirs(os.path.join(val_dir, "sharp"), exist_ok=True)
    os.makedirs(os.path.join(val_dir, "blur"), exist_ok=True)

    sharp_train_dir = os.path.join(train_dir, "sharp")
    groups = {}
    for filename in sorted(os.listdir(sharp_train_dir)):
        source_id = filename.rsplit("_crop_", 1)[0]
        groups.setdefault(source_id, []).append(filename)

    source_ids = sorted(groups)
    random.shuffle(source_ids)
    num_val_sources = int(len(source_ids) * ratio)
    val_images = [f for sid in source_ids[:num_val_sources] for f in groups[sid]]

    print(f"Moving {len(val_images)} images from train to validation set...")
    for filename in tqdm(val_images):
        shutil.move(os.path.join(train_dir, "sharp", filename), os.path.join(val_dir, "sharp", filename))
        shutil.move(os.path.join(train_dir, "blur", filename), os.path.join(val_dir, "blur", filename))
    print("Validation set created successfully.")

# --- Run Validation Split ---
create_validation_split(CROPPED_DATA_DIR)

Validation directory already exists. Skipping creation.


In [8]:
class DeblurDataset(Dataset):
    """
    Paired (blurred, sharp) image dataset for deblurring.

    Expects ``<root_dir>/<split>/sharp`` and ``<root_dir>/<split>/blur`` holding
    identically named files. Each item is ``(blur_tensor, sharp_tensor)`` as
    produced by ``TF.to_tensor``.

    Args:
        root_dir: Dataset root.
        split: Sub-folder name, e.g. ``"train"``, ``"validation"`` or ``"test"``.
        augment: Whether to apply a random horizontal flip (same flip to both
            images). ``None`` (default) enables it only for ``split == "train"``.
            This keeps validation and test samples deterministic.
    """
    def __init__(self, root_dir, split="train", augment=None):
        self.split_dir = os.path.join(root_dir, split)
        self.sharp_dir = os.path.join(self.split_dir, 'sharp')
        self.blur_dir = os.path.join(self.split_dir, 'blur')
        self.image_files = sorted(os.listdir(self.sharp_dir))
        # Augmenting evaluation splits would make validation loss and test
        # metrics depend on RNG state rather than on the model alone.
        self.augment = (split == "train") if augment is None else bool(augment)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(blur_tensor, sharp_tensor)`` for the ``idx``-th filename."""
        img_name = self.image_files[idx]
        sharp_image = Image.open(os.path.join(self.sharp_dir, img_name)).convert("RGB")
        blur_image = Image.open(os.path.join(self.blur_dir, img_name)).convert("RGB")

        # Flip both images together so the pair stays aligned.
        if self.augment and random.random() > 0.5:
            sharp_image = TF.hflip(sharp_image)
            blur_image = TF.hflip(blur_image)

        return TF.to_tensor(blur_image), TF.to_tensor(sharp_image)
    
# --- Create Datasets and DataLoaders ---
train_dataset = DeblurDataset(root_dir=CROPPED_DATA_DIR, split='train')
val_dataset = DeblurDataset(root_dir=CROPPED_DATA_DIR, split='validation')
test_dataset = DeblurDataset(root_dir=CROPPED_DATA_DIR, split='test')

# Page-locked host memory only speeds up host->GPU copies. On a CPU-only run it
# is wasted, and recent PyTorch versions warn about it.
PIN_MEMORY = DEVICE == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY, num_workers=0)

print(f"Training dataset loaded with {len(train_dataset)} images.")
print(f"Validation dataset loaded with {len(val_dataset)} images.")
print(f"Test dataset loaded with {len(test_dataset)} images.")

Training dataset loaded with 6300 images.
Validation dataset loaded with 700 images.
Test dataset loaded with 1000 images.


In [9]:
class ConvBlock(nn.Module):
    """3x3 convolution (no bias) -> BatchNorm -> ReLU; spatial size is preserved.

    Layers live in ``self.conv`` (indices 0/1/2), so state-dict keys look like
    ``conv.0.weight`` and ``conv.1.running_mean``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # No conv bias: the following BatchNorm supplies its own shift term.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class ViTDecoder(nn.Module):
    """
    Convolutional decoder turning ViT token embeddings back into an image.

    The first (CLS) token is dropped. The remaining patch tokens are reshaped
    into a square ``(B, in_features, g, g)`` grid. Four stride-2 transposed
    convolutions then upsample by 16x (14x14 -> 224x224 for a 224px ViT-B/16
    input). A 1x1 convolution and a sigmoid produce ``num_classes`` channels in
    ``[0, 1]``.

    Args:
        in_features: Token embedding width (768 for the ViT-Base encoder used here).
        num_classes: Number of output channels (3 for RGB).
    """
    def __init__(self, in_features=768, num_classes=3):
        super().__init__()
        # Stored so forward() reshapes with the configured width. Previously
        # `view` hard-coded 768, which broke any other in_features.
        self.in_features = in_features

        self.upconv1 = nn.ConvTranspose2d(in_features, 256, kernel_size=2, stride=2) # 14x14 -> 28x28
        self.conv1 = ConvBlock(256, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) # 28x28 -> 56x56
        self.conv2 = ConvBlock(128, 128)

        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) # 56x56 -> 112x112
        self.conv3 = ConvBlock(64, 64)

        self.upconv4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) # 112x112 -> 224x224
        self.conv4 = ConvBlock(32, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """
        Decode a token sequence into an image.

        Args:
            x: ``(B, 1 + g*g, in_features)`` tokens. Index 0 is discarded
               (assumed to be the CLS token of ``ViTModel.last_hidden_state``).

        Returns:
            ``(B, num_classes, 16*g, 16*g)`` tensor in ``[0, 1]``.

        Raises:
            ValueError: if the patch-token count is not a perfect square or the
                embedding width differs from ``in_features``.
        """
        patches = x[:, 1:, :]
        batch_size, num_patches, channels = patches.shape

        grid = int(round(num_patches ** 0.5))
        if grid * grid != num_patches:
            raise ValueError(f"Expected a square number of patch tokens, got {num_patches}.")
        if channels != self.in_features:
            raise ValueError(f"Expected embedding width {self.in_features}, got {channels}.")

        # (B, N, C) -> (B, C, N) -> (B, C, g, g); reshape copies the permuted view as needed.
        patches = patches.permute(0, 2, 1).reshape(batch_size, channels, grid, grid)

        x = self.conv1(self.upconv1(patches))
        x = self.conv2(self.upconv2(x))
        x = self.conv3(self.upconv3(x))
        x = self.conv4(self.upconv4(x))

        x = self.final_conv(x)
        return torch.sigmoid(x)

class ViTForDeblurring(nn.Module):
    """
    Deblurring model: pretrained ``google/vit-base-patch16-224-in21k`` encoder + ``ViTDecoder``.

    Args:
        freeze_encoder: If True, every encoder parameter is frozen and only the
            decoder trains.
        num_unfrozen_layers: Used when ``freeze_encoder`` is False. The LAST
            ``num_unfrozen_layers`` transformer layers plus the final
            ``layernorm`` are trainable; all other encoder parameters stay
            frozen. Values <= 0 unfreeze no transformer layer (only the
            layernorm).
    """
    def __init__(self, freeze_encoder=True, num_unfrozen_layers=4):
        super().__init__()
        self.encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.decoder = ViTDecoder()

        # Both modes start from a fully frozen encoder.
        for param in self.encoder.parameters():
            param.requires_grad = False

        if not freeze_encoder:
            print(f"Fine-tuning ViT: Unfreezing the last {num_unfrozen_layers} transformer layers.")
            # Guard needed: `layer[-0:]` is the WHOLE list, so 0 used to
            # silently unfreeze every layer (and negatives unfroze most of them).
            if num_unfrozen_layers > 0:
                for layer in self.encoder.encoder.layer[-num_unfrozen_layers:]:
                    for param in layer.parameters():
                        param.requires_grad = True

            for param in self.encoder.layernorm.parameters():
                param.requires_grad = True

    def forward(self, x):
        """
        Encode ``x`` with the ViT and decode the token sequence into an image.

        Args:
            x: Image batch passed to the ViT unchanged. Presumably
               ``(B, 3, 224, 224)`` in [0, 1] from ``TF.to_tensor`` -- TODO confirm.

        Returns:
            ``ViTDecoder`` output (values in [0, 1]).
        """
        # NOTE(review): no mean/std normalization is applied before the encoder.
        # Confirm this matches how the pretrained checkpoint expects its inputs.
        encoder_output = self.encoder(x)
        last_hidden_state = encoder_output.last_hidden_state
        decoded_output = self.decoder(last_hidden_state)
        return decoded_output

In [10]:
class VGGPerceptualLoss(nn.Module):
    """
    Perceptual (feature-space) L1 loss based on a frozen, ImageNet-pretrained VGG19.

    Both images are ImageNet-normalized and passed through ``vgg19().features[:18]``
    (layers up to and including the ReLU after conv3_4). The resulting feature
    maps are compared with mean absolute error. The feature extractor is moved to
    the module-level ``DEVICE`` when constructed.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        # A shallow slice of the feature stack keeps the loss cheap.
        truncated = backbone.features[:18].eval()
        self.features = nn.Sequential(*truncated).to(DEVICE)

        # VGG was trained on ImageNet-normalized inputs.
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # The loss network is a fixed feature extractor, never trained.
        self.features.requires_grad_(False)

        self.l1 = nn.L1Loss()

    def forward(self, input_img, target_img):
        """Return the L1 distance between VGG features of ``input_img`` and ``target_img``."""
        feats_in, feats_tgt = (
            self.features(self.normalize(img)) for img in (input_img, target_img)
        )
        return self.l1(feats_in, feats_tgt)

In [None]:
# Cell 6
def _train_one_epoch(model, loader, optimizer, criterion_l1, criterion_vgg, lambda_vgg, device, desc):
    """Run one optimization pass over ``loader``; return mean per-batch (L1, VGG) losses."""
    model.train()
    running_l1_loss = 0.0
    running_vgg_loss = 0.0
    for blurry_imgs, sharp_imgs in tqdm(loader, desc=desc):
        # non_blocking lets pinned-memory batches copy asynchronously to the GPU.
        blurry_imgs = blurry_imgs.to(device, non_blocking=True)
        sharp_imgs = sharp_imgs.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(blurry_imgs)

        loss_l1 = criterion_l1(outputs, sharp_imgs)
        loss_vgg = criterion_vgg(outputs, sharp_imgs)
        total_loss = loss_l1 + (lambda_vgg * loss_vgg)
        total_loss.backward()
        optimizer.step()

        running_l1_loss += loss_l1.item()
        running_vgg_loss += loss_vgg.item()
    return running_l1_loss / len(loader), running_vgg_loss / len(loader)


def _validate(model, loader, criterion_l1, device, desc):
    """Return the mean per-batch L1 loss of ``model`` on ``loader`` without gradients."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for blurry_imgs, sharp_imgs in tqdm(loader, desc=desc):
            blurry_imgs = blurry_imgs.to(device, non_blocking=True)
            sharp_imgs = sharp_imgs.to(device, non_blocking=True)
            val_loss += criterion_l1(model(blurry_imgs), sharp_imgs).item()
    return val_loss / len(loader)


def train_model(model, model_name, train_loader, val_loader, optimizer, epochs, device, lambda_vgg=0.01, save_best=True):
    """
    Train ``model`` with L1 + ``lambda_vgg`` * VGG-perceptual loss and save its weights.

    Each epoch runs one training pass and one no-grad validation pass (L1 only).
    The learning rate is halved by ``ReduceLROnPlateau`` (patience 5) when the
    validation L1 stops improving.

    Args:
        model: Module mapping blurred batches to restored batches.
        model_name: Prefix of the checkpoint file ``<model_name>_deblur.pth``.
        train_loader: Yields ``(blurry, sharp)`` training batches.
        val_loader: Yields ``(blurry, sharp)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of epochs.
        device: Device to train on.
        lambda_vgg: Weight of the perceptual term.
        save_best: If True (default), the checkpoint holds the weights of the
            epoch with the lowest validation L1, and they are loaded back into
            ``model`` before returning. If False, the last-epoch weights are
            saved (previous behaviour). Also falls back to last-epoch weights
            when no finite validation loss was ever recorded.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: if either loader is empty (the per-batch averages would divide by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must both be non-empty.")

    criterion_l1 = nn.L1Loss()
    criterion_vgg = VGGPerceptualLoss().to(device)
    model.to(device)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
    save_path = f'{model_name}_deblur.pth'
    best_val_loss = float('inf')

    print(f"--- Starting Training for {model_name} ---")

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"
        avg_train_l1, avg_train_vgg = _train_one_epoch(
            model, train_loader, optimizer, criterion_l1, criterion_vgg, lambda_vgg, device, f"{tag} [Train]"
        )
        avg_val_loss = _validate(model, val_loader, criterion_l1, device, f"{tag} [Val]")

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{epochs}] complete. Train L1: {avg_train_l1:.4f}, "
              f"Train VGG: {avg_train_vgg:.4f}, Val L1: {avg_val_loss:.4f}, LR: {current_lr}")
        scheduler.step(avg_val_loss)

        # Select the checkpoint by validation loss instead of keeping whatever
        # the final epoch produced (possibly an over-fitted one).
        if save_best and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best Val L1: {best_val_loss:.4f}. Checkpoint saved to {save_path}")

    print(f"--- Finished Training for {model_name} ---")
    if save_best and best_val_loss < float('inf'):
        model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
        print(f"Restored best weights (Val L1: {best_val_loss:.4f}) from {save_path}")
    else:
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")
    return model

In [None]:
# Cell 7
# --- Initialize U-Net from the segmentation-models-pytorch library ---
unet_model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
    decoder_use_batchnorm=True,
    decoder_attention_type='scse'
)

optimizer_unet = optim.AdamW(unet_model.parameters(), lr=LEARNING_RATE)
UNET_MODEL_NAME = "unet"
print("U-Net with ResNet50 backbone created successfully.")

# --- Initialize Transformer model ---
vit_model = ViTForDeblurring(freeze_encoder=False)

VIT_ENCODER_LR = 1e-5  # lower LR for the partially unfrozen pretrained encoder
VIT_DECODER_LR = 1e-4  # higher LR for the freshly initialised decoder
# Give the optimizer only parameters that can receive gradients. Most of the
# encoder is frozen by ViTForDeblurring.
optimizer_vit = optim.Adam([
    {'params': [p for p in vit_model.encoder.parameters() if p.requires_grad], 'lr': VIT_ENCODER_LR},
    {'params': [p for p in vit_model.decoder.parameters() if p.requires_grad], 'lr': VIT_DECODER_LR},
])
VIT_MODEL_NAME = "vit"
print("ViT-based model created successfully.")

In [None]:
# --- Train U-Net Baseline ---
# Fit the ResNet50 U-Net baseline; weights end up in "<UNET_MODEL_NAME>_deblur.pth".
trained_unet = train_model(
    unet_model,
    UNET_MODEL_NAME,
    train_loader,
    val_loader,
    optimizer=optimizer_unet,
    epochs=EPOCHS,
    device=DEVICE,
)

In [None]:
# --- Train ViT Model ---
# Fit the ViT encoder/decoder model; weights end up in "<VIT_MODEL_NAME>_deblur.pth".
trained_vit = train_model(
    vit_model,
    VIT_MODEL_NAME,
    train_loader,
    val_loader,
    optimizer=optimizer_vit,
    epochs=EPOCHS,
    device=DEVICE,
)

In [None]:
def evaluate_on_cropped(unet_arch, vit_arch, unet_path, vit_path, test_dataset, device, num_images=10):
    """
    Load saved U-Net and ViT weights and compare them on random cropped test samples.

    For each sampled index, prints PSNR/SSIM (``data_range=1.0``) for both
    models. It also plots the blurry input, both predictions and the ground truth
    side by side. Returns early with a message if a weights file is missing.

    Args:
        unet_arch: U-Net architecture to load ``unet_path`` into.
        vit_arch: ViT architecture to load ``vit_path`` into.
        unet_path: Path of the saved U-Net ``state_dict``.
        vit_path: Path of the saved ViT ``state_dict``.
        test_dataset: Yields ``(blurry, sharp)`` CHW tensors, expected in [0, 1].
        device: Device for inference.
        num_images: Number of random samples (capped at the dataset size).
    """
    try:
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        unet_arch.load_state_dict(torch.load(unet_path, map_location=device, weights_only=True))
        vit_arch.load_state_dict(torch.load(vit_path, map_location=device, weights_only=True))
    except FileNotFoundError as e:
        print(f"Model weights not found: {e}. Skipping evaluation.")
        return
    unet_arch.to(device).eval()
    vit_arch.to(device).eval()

    indices = random.sample(range(len(test_dataset)), min(num_images, len(test_dataset)))

    for i in indices:
        blurry_img, sharp_img = test_dataset[i]
        blurry_img_batch = blurry_img.unsqueeze(0).to(device)

        with torch.no_grad():
            # squeeze(0) drops only the batch axis.
            unet_out = unet_arch(blurry_img_batch).squeeze(0).cpu()
            vit_out = vit_arch(blurry_img_batch).squeeze(0).cpu()

        sharp_np = sharp_img.permute(1, 2, 0).numpy()
        blurry_np = blurry_img.permute(1, 2, 0).numpy()
        unet_np = unet_out.permute(1, 2, 0).numpy().clip(0, 1)
        vit_np = vit_out.permute(1, 2, 0).numpy().clip(0, 1)

        # channel_axis alone selects per-channel SSIM. `multichannel` is the
        # deprecated spelling of the same option.
        psnr_unet = psnr(sharp_np, unet_np, data_range=1.0)
        ssim_unet = ssim(sharp_np, unet_np, data_range=1.0, channel_axis=-1)
        psnr_vit = psnr(sharp_np, vit_np, data_range=1.0)
        ssim_vit = ssim(sharp_np, vit_np, data_range=1.0, channel_axis=-1)

        print(f"--- Image Index {i} ---")
        print(f"U-Net --> PSNR: {psnr_unet:.2f}, SSIM: {ssim_unet:.4f}")
        print(f"ViT ----> PSNR: {psnr_vit:.2f}, SSIM: {ssim_vit:.4f}")

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(blurry_np); axes[0].set_title("Blurry Input")
        axes[1].imshow(unet_np); axes[1].set_title(f"U-Net\nPSNR: {psnr_unet:.2f}")
        axes[2].imshow(vit_np); axes[2].set_title(f"ViT\nPSNR: {psnr_vit:.2f}")
        axes[3].imshow(sharp_np); axes[3].set_title("Ground Truth")
        for ax in axes: ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop

# --- Run Cropped Image Evaluation ---
# encoder_weights=None: the checkpoint overwrites all weights anyway, so skip
# downloading the ImageNet encoder weights.
unet_for_eval = smp.Unet("resnet50", encoder_weights=None, in_channels=3, classes=3, decoder_use_batchnorm=True, decoder_attention_type='scse')
vit_for_eval = ViTForDeblurring(freeze_encoder=False)
evaluate_on_cropped(
    unet_for_eval, vit_for_eval,
    f"{UNET_MODEL_NAME}_deblur.pth", f"{VIT_MODEL_NAME}_deblur.pth",
    test_dataset, DEVICE
)

In [None]:
# --- Helper functions for patch-based prediction ---
def get_pad(h, w, patch_size, stride):
    """
    Return ``(pad_h, pad_w)`` so ``patch_size`` tiles moved by ``stride`` cover the padded image exactly.

    For a side ``length >= patch_size``, the pad is the smallest value with
    ``(length + pad - patch_size) % stride == 0``. For ``length < patch_size``,
    the side is padded up to exactly ``patch_size`` (one tile). The old formula
    fell short of ``patch_size`` when ``patch_size - length >= stride``, which
    produced no tiles at all.
    """
    def _pad(length):
        if length <= patch_size:
            return patch_size - length
        return (stride - (length - patch_size) % stride) % stride

    return _pad(h), _pad(w)

def predict_on_full_image(model, full_blurry_tensor, patch_size=224, overlap=32, device='cuda'):
    """
    Run ``model`` over an arbitrarily sized CHW image with overlapping tiles.

    The image is padded on the bottom/right so that ``patch_size`` tiles with
    stride ``patch_size - overlap`` cover it exactly. Each tile is passed through
    ``model`` and overlapping predictions are averaged. The result is cropped back
    to the input size.

    Args:
        model: Maps a ``(1, C, patch_size, patch_size)`` batch to an output of the same shape.
        full_blurry_tensor: ``(C, H, W)`` input tensor.
        patch_size: Tile side length.
        overlap: Pixels shared by neighbouring tiles; requires ``0 <= overlap < patch_size``.
        device: Device for inference.

    Returns:
        ``(C, H, W)`` CPU tensor.

    Raises:
        ValueError: if ``overlap`` is out of range (stride would be <= 0).
    """
    if not 0 <= overlap < patch_size:
        raise ValueError(f"overlap must be in [0, {patch_size}), got {overlap}.")

    model.eval()
    batch = full_blurry_tensor.unsqueeze(0).to(device)
    _, _, h, w = batch.shape
    stride = patch_size - overlap
    pad_h, pad_w = get_pad(h, w, patch_size, stride)
    # Reflect padding requires pad < input size; small images fall back to edge replication.
    mode = 'reflect' if pad_h < h and pad_w < w else 'replicate'
    padded = F.pad(batch, (0, pad_w, 0, pad_h), mode=mode)
    _, _, padded_h, padded_w = padded.shape

    output_sum = torch.zeros_like(padded)
    # One channel is enough for the hit counter; it broadcasts when dividing.
    hits = torch.zeros((1, 1, padded_h, padded_w), device=padded.device, dtype=padded.dtype)
    with torch.no_grad():
        for top in range(0, padded_h - patch_size + 1, stride):
            for left in range(0, padded_w - patch_size + 1, stride):
                window = padded[:, :, top:top + patch_size, left:left + patch_size]
                output_sum[:, :, top:top + patch_size, left:left + patch_size] += model(window)
                hits[:, :, top:top + patch_size, left:left + patch_size] += 1
    return (output_sum / hits)[:, :, :h, :w].squeeze(0).cpu()

def evaluate_on_full_images(unet_arch, vit_arch, unet_path, vit_path, original_test_dir, device, num_images=10):
    """
    Load saved U-Net and ViT weights and compare them on random full-resolution test images.

    Images are read from ``<original_test_dir>/test/{sharp,blur}`` and restored
    tile-by-tile with ``predict_on_full_image``. For each image, PSNR/SSIM
    (``data_range=1.0``) are printed. The blurry input, both predictions and the
    ground truth are plotted. Returns early with a message if the weights or the
    test folder are missing. Images without a blurred counterpart are skipped.

    Args:
        unet_arch: U-Net architecture to load ``unet_path`` into.
        vit_arch: ViT architecture to load ``vit_path`` into.
        unet_path: Path of the saved U-Net ``state_dict``.
        vit_path: Path of the saved ViT ``state_dict``.
        original_test_dir: Dataset root containing ``test/sharp`` and ``test/blur``.
        device: Device for inference.
        num_images: Number of random images (capped at the number available).
    """
    try:
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        unet_arch.load_state_dict(torch.load(unet_path, map_location=device, weights_only=True))
        vit_arch.load_state_dict(torch.load(vit_path, map_location=device, weights_only=True))
    except FileNotFoundError as e:
        print(f"Model weights not found: {e}. Skipping evaluation.")
        return
    unet_arch.to(device).eval()
    vit_arch.to(device).eval()

    sharp_dir = os.path.join(original_test_dir, 'test', 'sharp')
    blur_dir = os.path.join(original_test_dir, 'test', 'blur')

    try:
        # Sorted before sampling so a seeded `random` picks the same images every run.
        image_names = sorted(os.listdir(sharp_dir))
    except FileNotFoundError:
        print(f"Test images not found in '{sharp_dir}'. Skipping.")
        return
    selected_images = random.sample(image_names, min(num_images, len(image_names)))

    for image_name in selected_images:
        blur_path = os.path.join(blur_dir, image_name)
        if not os.path.exists(blur_path):
            print(f"Warning: no blurred counterpart for '{image_name}'. Skipping.")
            continue

        print(f"--- Processing Full Image: {image_name} ---")
        blurry_tensor = TF.to_tensor(Image.open(blur_path).convert("RGB"))
        sharp_tensor = TF.to_tensor(Image.open(os.path.join(sharp_dir, image_name)).convert("RGB"))

        unet_output = predict_on_full_image(unet_arch, blurry_tensor, device=device)
        vit_output = predict_on_full_image(vit_arch, blurry_tensor, device=device)

        sharp_np = sharp_tensor.permute(1, 2, 0).numpy()
        blurry_np = blurry_tensor.permute(1, 2, 0).numpy()
        unet_np = unet_output.permute(1, 2, 0).numpy().clip(0, 1)
        vit_np = vit_output.permute(1, 2, 0).numpy().clip(0, 1)

        # channel_axis alone selects per-channel SSIM. `multichannel` is the
        # deprecated spelling of the same option.
        psnr_unet = psnr(sharp_np, unet_np, data_range=1.0)
        ssim_unet = ssim(sharp_np, unet_np, data_range=1.0, channel_axis=-1)
        psnr_vit = psnr(sharp_np, vit_np, data_range=1.0)
        ssim_vit = ssim(sharp_np, vit_np, data_range=1.0, channel_axis=-1)

        print(f"U-Net --> PSNR: {psnr_unet:.2f}, SSIM: {ssim_unet:.4f}")
        print(f"ViT ----> PSNR: {psnr_vit:.2f}, SSIM: {ssim_vit:.4f}")

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(blurry_np); axes[0].set_title("Blurry Input")
        axes[1].imshow(unet_np); axes[1].set_title(f"U-Net\nPSNR: {psnr_unet:.2f}")
        axes[2].imshow(vit_np); axes[2].set_title(f"ViT\nPSNR: {psnr_vit:.2f}")
        axes[3].imshow(sharp_np); axes[3].set_title("Ground Truth")
        for ax in axes: ax.axis('off')
        plt.show()
        plt.close(fig)  # full-resolution figures are large; release them each iteration

# --- Run Full Image Evaluation ---
# encoder_weights=None: the checkpoint overwrites all weights anyway, so skip
# downloading the ImageNet encoder weights.
unet_for_eval_full = smp.Unet("resnet50", encoder_weights=None, in_channels=3, classes=3, decoder_use_batchnorm=True, decoder_attention_type='scse')
vit_for_eval_full = ViTForDeblurring(freeze_encoder=False)
evaluate_on_full_images(
    unet_for_eval_full, vit_for_eval_full,
    f"{UNET_MODEL_NAME}_deblur.pth", f"{VIT_MODEL_NAME}_deblur.pth",
    ORIGINAL_DATA_DIR, DEVICE
)