In [1]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Tuple

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("[INFO] Using device", DEVICE)

[INFO] Using device cuda


In [3]:
def read_image(path: Path) -> np.ndarray:
    """Load an image file and return it in RGB channel order.

    Args:
        path: Location of the image file to read.

    Returns:
        The image returned by ``cv2.imread`` after a BGR -> RGB conversion.

    Raises:
        ValueError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    # Hand OpenCV the ``path`` argument, not the ``Path`` class: stringifying
    # the class yields "<class 'pathlib.Path'>", which can never be read.
    img = cv2.imread(str(path))
    if img is None:
        raise ValueError(f"Failed to read the image {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def to_tensor(img: np.ndarray) -> torch.Tensor:
    """Convert an image array into a batched, channels-first float tensor.

    The array is cast to float32 and divided by 255, its axes are reordered
    with ``permute(2, 0, 1)`` (last axis first), and a leading batch
    dimension of size 1 is added.
    """
    scaled = img.astype(np.float32) / 255.0
    channels_first = torch.from_numpy(scaled).permute(2, 0, 1)
    return channels_first.unsqueeze(0)

def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Convert a batched, channels-first tensor into a uint8 image array.

    The leading dimension is squeezed (if it has size 1), the axes are
    reordered with ``permute(1, 2, 0)`` (first axis last), and values are
    scaled by 255, rounded to the nearest integer, clipped to [0, 255] and
    cast to ``uint8``.

    Args:
        t: Tensor to convert. It may live on any device and may require grad.

    Returns:
        A ``uint8`` NumPy array with the channel axis last.
    """
    # detach() so tensors that still track gradients can be exported.
    img = t.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round before the integer cast: a bare astype(uint8) truncates, which
    # turns e.g. 254.75 into 254 and biases every pixel downwards.
    return np.rint(img * 255.0).clip(0, 255).astype(np.uint8)

In [4]:
class ResBlock(nn.Module):
    """Residual block: conv -> ReLU -> conv, with the input added to the result."""

    def __init__(self, channels):
        super().__init__()
        # 3x3 kernels with padding 1 keep the spatial size, so the skip
        # connection in forward() can add tensors of matching shape.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual

class DeblurGenerator(nn.Module):
    """Convolutional generator: 7x7 conv head, a stack of ResBlocks, 7x7 conv tail.

    Args:
        n_resblocks: Number of ``ResBlock(64)`` modules placed in the body.
    """

    def __init__(self, n_resblocks=9):
        super().__init__()

        # 3 input channels -> 64 feature channels.
        head_layers = [nn.Conv2d(3, 64, 7, padding=3), nn.ReLU(inplace=True)]
        self.head = nn.Sequential(*head_layers)

        body_layers = []
        for _ in range(n_resblocks):
            body_layers.append(ResBlock(64))
        self.body = nn.Sequential(*body_layers)

        # 64 feature channels -> 3 output channels, squashed by Tanh.
        tail_layers = [nn.Conv2d(64, 3, 7, padding=3), nn.Tanh()]
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x):
        features = self.body(self.head(x))
        squashed = self.tail(features)
        # Tanh yields values in [-1, 1]; shift and scale them into [0, 1].
        return (squashed + 1) / 2


In [5]:
MODEL_PATH = Path("../outputs/models/deblur/best_model.pth")

# Build the generator on the selected device and switch it to eval mode.
model = DeblurGenerator().to(DEVICE)
model.eval()

# Restore trained weights when a checkpoint file exists; otherwise keep the
# freshly initialised weights and say so.
if not MODEL_PATH.exists():
    print("⚠️ No pretrained weights found. Model is untrained.")
else:
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    state = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state["model_state"])
    print("✅ Loaded pretrained deblur model")


⚠️ No pretrained weights found. Model is untrained.


In [6]:
def deblur_frame(frame: np.ndarray) -> np.ndarray:
    """
    Run the module-level ``model`` on one frame with gradients disabled.

    Input: RGB image (H, W, 3)
    Output: Deblurred RGB image
    """
    with torch.no_grad():
        batch = to_tensor(frame).to(DEVICE)
        restored = model(batch)
        return to_numpy(restored)


In [7]:
def visualize_deblur(img_path: Path):
    """Show the image at ``img_path`` next to its deblurred version."""
    blurred = read_image(img_path)
    restored = deblur_frame(blurred)

    # (title, image) for the left and right panel, in display order.
    panels = (
        ("Original (Blurred)", blurred),
        ("Deblurred", restored),
    )

    plt.figure(figsize=(10,4))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis("off")

    plt.show()


In [8]:
def evaluate_deblur(blur_img, sharp_img):
    """Deblur ``blur_img`` and score the result against ``sharp_img``.

    Returns:
        A ``(psnr, ssim)`` tuple. Both metrics receive ``sharp_img`` as the
        first argument and the deblurred image as the second; SSIM is
        computed with ``channel_axis=2``.
    """
    restored = deblur_frame(blur_img)
    scores = (
        psnr(sharp_img, restored),
        ssim(sharp_img, restored, channel_axis=2),
    )
    return scores


In [9]:
def batch_deblur_test(
    blur_dir: Path,
    sharp_dir: Path,
    max_images: int = 20
):
    """Deblur a folder of PNGs and print the average PSNR and SSIM.

    The first ``max_images`` ``*.png`` files of ``blur_dir`` (in sorted path
    order) are considered. Each one is scored against the file of the same
    name in ``sharp_dir``; blurred images without such a counterpart are
    skipped.

    Args:
        blur_dir: Directory containing the blurred ``*.png`` inputs.
        sharp_dir: Directory containing the reference images.
        max_images: Upper bound on the number of blurred files considered.

    Returns:
        None. The averages (or a warning when nothing was evaluated) are
        printed.
    """
    blur_paths = sorted(blur_dir.glob("*.png"))[:max_images]

    psnr_list, ssim_list = [], []

    for blur_path in blur_paths:
        sharp_path = sharp_dir / blur_path.name
        if not sharp_path.exists():
            continue

        blur = read_image(blur_path)
        sharp = read_image(sharp_path)

        psnr_val, ssim_val = evaluate_deblur(blur, sharp)
        psnr_list.append(psnr_val)
        ssim_list.append(ssim_val)

    # np.mean([]) emits a RuntimeWarning and yields nan; say clearly that
    # nothing was evaluated instead of printing "nan" averages.
    if not psnr_list:
        print(f"[WARN] No matching blur/sharp PNG pairs found in {blur_dir} and {sharp_dir}")
        return

    print(f"Avg PSNR: {np.mean(psnr_list):.2f}")
    print(f"Avg SSIM: {np.mean(ssim_list):.4f}")


In [10]:
def enhancement_pipeline(frame: np.ndarray, blur_flag: bool):
    """Return ``frame`` deblurred when ``blur_flag`` is truthy, else unchanged."""
    return deblur_frame(frame) if blur_flag else frame
