In [None]:
# ## Image Deblurring: Inference with Pretrained MPRNet
# This notebook demonstrates how to download a pretrained MPRNet model from Torch Hub, apply it to a blurred image, and save the deblurred result. Set `input_path` below to your own image before running.

In [None]:
import torch
from PIL import Image
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def preprocess(image):
    """Turn a PIL image into a batched float tensor for the model.

    For an 8-bit RGB image, ``ToTensor`` yields a (3, H, W) float tensor with
    values in [0, 1]; a leading batch axis is then added, giving (1, 3, H, W).
    """
    to_tensor = T.ToTensor()
    chw = to_tensor(image)
    # Indexing with None inserts the batch dimension (same as unsqueeze(0)).
    return chw[None]

def postprocess(tensor):
    """Convert a model output tensor back into a PIL image.

    Args:
        tensor: float tensor shaped (1, C, H, W) or (C, H, W), with values
            nominally in [0, 1]. Out-of-range values are clamped. C is
            expected to be 3 (RGB) or 1 (grayscale).

    Returns:
        A uint8 PIL image: RGB for 3 channels, grayscale ("L") for 1 channel.
    """
    # detach() lets this work on tensors that still track gradients
    # (e.g. when called outside torch.no_grad()); .numpy() refuses those.
    img = tensor.detach().squeeze(0).cpu().clamp(0, 1).numpy()
    # CHW -> HWC, the layout PIL expects.
    img = np.transpose(img, (1, 2, 0))
    # Round rather than truncate: a bare astype would map 0.999 * 255 to 254.
    img = np.round(img * 255.0).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array, so drop the channel
    # axis for single-channel output to get a grayscale image.
    if img.shape[-1] == 1:
        img = img[..., 0]
    return Image.fromarray(img)

In [None]:
# --- Loading the MPRNet model ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# NOTE(review): assumes the 'swz30/MPRNet' repository exposes an 'MPRNet'
# Torch Hub entry point (hubconf.py) that accepts `pretrained=True` — confirm
# before relying on it.
model = torch.hub.load('swz30/MPRNet', 'MPRNet', pretrained=True)
# Bind the result instead of leaving it as the cell's last expression, which
# would make Jupyter print the entire module repr.
model = model.to(device).eval()

In [None]:
# --- Loading a blurred image ---
# Replace with the path to your own blurred image.
input_path = 'path_to_blurred_image.jpg'
# Open inside a context manager so the underlying file handle is closed;
# convert('RGB') returns a new, fully loaded image that stays usable afterwards.
with Image.open(input_path) as source_img:
    blurred_img = source_img.convert('RGB')

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title('Input Blurred Image')
ax.axis('off')
ax.imshow(blurred_img)
plt.show()

In [None]:
# --- Inference ---
# NOTE(review): the upstream MPRNet demo reflect-pads H and W up to a multiple
# of 8 before the forward pass and crops the result back; this mirrors that.
# Confirm the value for the Torch Hub entry point actually in use.
IMG_MULTIPLE_OF = 8

input_tensor = preprocess(blurred_img).to(device)
height, width = input_tensor.shape[-2:]

# Pad only the bottom/right edges so the original pixels keep their
# coordinates and can be recovered with a plain [:height, :width] crop.
pad_h = (-height) % IMG_MULTIPLE_OF
pad_w = (-width) % IMG_MULTIPLE_OF
padded_input = torch.nn.functional.pad(input_tensor, (0, pad_w, 0, pad_h), mode='reflect')

with torch.no_grad():
    model_output = model(padded_input)

# Multi-stage restoration models may return a list with one image per stage
# (the upstream MPRNet puts the final stage first); a bare tensor is also accepted.
if isinstance(model_output, (list, tuple)):
    model_output = model_output[0]
output_tensor = model_output[..., :height, :width]

deblurred_img = postprocess(output_tensor)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title('Deblurred Output Image')
ax.axis('off')
ax.imshow(deblurred_img)
plt.show()

In [None]:
# --- Saving the result ---
output_path = 'deblurred_output.jpg'
# Pillow's default JPEG quality (75) adds visible compression artifacts that
# partly undo the restoration, so save at high quality instead. Writers that
# do not recognise `quality` (e.g. PNG) silently ignore it.
deblurred_img.save(output_path, quality=95)
print(f"Deblurred image saved to {output_path}")