In [1]:
# Deblur_Evaluation.ipynb

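# ## Evaluating deblurring quality on paired blurred/sharp images
# Each blurred image is restored with a pretrained MPRNet model, and the PSNR and SSIM metrics are computed
# between the restored image and its original sharp counterpart; dataset-wide averages are reported at the end.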

In [None]:
import os
from glob import glob
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as T
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt

In [None]:
# --- Preprocessing functions ---

def preprocess(image):
    """Turn a PIL image into a batched float tensor ready for the model.

    ``ToTensor`` maps an 8-bit H x W x C image to a C x H x W float tensor with
    values in [0, 1]; a leading batch axis of size 1 is then added, giving
    shape (1, C, H, W).
    """
    chw = T.ToTensor()(image)
    return chw[None, ...]

In [None]:
# --- Post-processing functions ---

def postprocess(tensor):
    """Convert a model output tensor into an H x W x C uint8 image array.

    Parameters
    ----------
    tensor : torch.Tensor
        Image batch of shape (1, C, H, W) with values nominally in [0, 1].
        It may live on any device and may require grad.

    Returns
    -------
    np.ndarray
        uint8 array of shape (H, W, C). Values are clamped to [0, 1], scaled
        to [0, 255] and rounded to the nearest integer.
    """
    # detach() so .numpy() also works on tensors that are part of an autograd graph.
    img = tensor.detach().squeeze(0).cpu().clamp(0, 1).numpy()
    img = np.transpose(img, (1, 2, 0))
    # Round instead of truncating: astype alone floors (0.999 * 255 = 254.7 -> 254),
    # which systematically darkens the image and biases PSNR/SSIM.
    return np.round(img * 255).astype(np.uint8)

In [None]:
# --- Loading the model ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): confirm that the `swz30/MPRNet` repo exposes an `MPRNet` hub entrypoint that
# accepts `pretrained=True`, and which task's weights (deblurring vs. denoising/deraining) it loads.
model = torch.hub.load('swz30/MPRNet', 'MPRNet', pretrained=True)
# Assign the result instead of leaving it as the cell's last expression, so the notebook
# does not render the full (very long) model repr as output.
model = model.to(device).eval()

In [None]:
# --- Paths to test image folders ---

blur_dir = 'dataset/blurred'
sharp_dir = 'dataset/sharp'

# File extensions treated as images (compared case-insensitively).
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'}


def list_images(directory):
    """Return the sorted paths of image files directly inside ``directory``.

    Sub-directories and files whose extension is not in ``IMAGE_EXTENSIONS``
    (e.g. ``notes.txt``, ``Thumbs.db``) are skipped. A missing directory yields
    an empty list.
    """
    return sorted(
        path
        for path in glob(os.path.join(directory, '*'))
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in IMAGE_EXTENSIONS
    )


blur_images = list_images(blur_dir)
sharp_images = list_images(sharp_dir)

# Blurred and sharp images are paired positionally (by sorted order), so a count mismatch
# would make zip() silently drop images and misalign every pair after the first gap.
if len(blur_images) != len(sharp_images):
    raise ValueError(
        f"Found {len(blur_images)} blurred but {len(sharp_images)} sharp images; "
        "pairs are matched by sorted order, so the counts must agree."
    )

psnr_list = []
ssim_list = []

In [None]:
# --- Evaluation loop ---
# Deblur every image, score it against its sharp ground truth, and optionally show the triplet.

SHOW_PLOTS = True   # set to False to skip the per-image figures (useful for large datasets)
PAD_MULTIPLE = 8    # MPRNet's official demo reflect-pads inputs so H and W are multiples of 8


def load_rgb(path):
    """Open ``path`` and return an RGB copy of the image, closing the file handle."""
    with Image.open(path) as img:
        return img.convert('RGB')


def deblur(image):
    """Restore a PIL RGB image with ``model`` and return an H x W x 3 uint8 array.

    The input is reflect-padded on the bottom/right up to a multiple of
    ``PAD_MULTIPLE`` and the output is cropped back, so the result has the
    same spatial size as ``image``.
    """
    input_tensor = preprocess(image).to(device)
    height, width = input_tensor.shape[-2:]
    pad_h = (-height) % PAD_MULTIPLE
    pad_w = (-width) % PAD_MULTIPLE
    if pad_h or pad_w:
        input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_w, 0, pad_h), mode='reflect')
    with torch.no_grad():
        output = model(input_tensor)
    # MPRNet returns a list of per-stage restorations with the final stage first
    # (its official demo indexes the result with [0]); a plain tensor is used as-is.
    if isinstance(output, (list, tuple)):
        output = output[0]
    return postprocess(output[..., :height, :width])


# Reset the accumulators so re-running this cell does not double-count images.
psnr_list = []
ssim_list = []

for blur_path, sharp_path in zip(blur_images, sharp_images):
    blur_img = load_rgb(blur_path)
    sharp_img = load_rgb(sharp_path)

    # Model inference
    output_img = deblur(blur_img)
    sharp_np = np.array(sharp_img)
    if output_img.shape != sharp_np.shape:
        raise ValueError(
            f"Shape mismatch for {os.path.basename(blur_path)}: restored {output_img.shape} "
            f"vs. sharp {sharp_np.shape} ({os.path.basename(sharp_path)})"
        )

    # Metrics calculation. `channel_axis=-1` replaces the `multichannel=True` flag, which
    # scikit-image deprecated in 0.19 and later removed; without it an H x W x 3 image is
    # treated as a 3-D volume and fails SSIM's window-size check.
    current_psnr = psnr(sharp_np, output_img, data_range=255)
    current_ssim = ssim(sharp_np, output_img, channel_axis=-1, data_range=255)

    psnr_list.append(current_psnr)
    ssim_list.append(current_ssim)

    # Output results for each file
    print(f"{os.path.basename(blur_path)} -> PSNR: {current_psnr:.2f}, SSIM: {current_ssim:.4f}")

    # Visualization (optional)
    if SHOW_PLOTS:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = ((blur_img, 'Blurred'), (output_img, 'Deblurred'), (sharp_img, 'Sharp (GT)'))
        for ax, (img, title) in zip(axs, panels):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # free the figure; per-image figures otherwise accumulate in memory


In [None]:
# --- Dataset-level averages ---
# Guard the empty case: np.mean([]) returns nan with a RuntimeWarning instead of a clear message.
if psnr_list:
    print(f"Evaluated {len(psnr_list)} image pairs")
    print(f"Average PSNR: {np.mean(psnr_list):.2f}")
    print(f"Average SSIM: {np.mean(ssim_list):.4f}")
else:
    print("No image pairs were evaluated; check blur_dir and sharp_dir.")