In [None]:
import torch
import numpy as np
def denorm(t):
    """Map values from the [-1, 1] range back to [0, 1].

    This is the inverse of the affine map ``x -> (x - 0.5) / 0.5``.
    It works on anything that supports ``*`` and ``+`` with Python floats
    (torch tensors, NumPy arrays, plain numbers). No clamping is applied,
    so values outside [-1, 1] map outside [0, 1].
    """
    half = 0.5
    return half + half * t

def to_gray(img):
    """Collapse the trailing colour axis of an image to a single luma value.

    Applies the weights (0.299, 0.587, 0.114), i.e. the ITU-R BT.601 luma
    coefficients, to the first three entries of the last axis. Any further
    channels (e.g. alpha) are ignored. The result has the input's shape
    without its last axis.

    NOTE(review): assumes channels-last data in R, G, B order; confirm
    against callers.
    """
    luma_weights = [0.299, 0.587, 0.114]
    rgb = img[..., :3]
    return np.dot(rgb, luma_weights)
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [None]:
def _batch_to_hwc_numpy(batch):
    """Turn an NCHW batch into an NHWC NumPy array clamped to [0, 1] via ``denorm``."""
    return denorm(batch).clamp(0, 1).permute(0, 2, 3, 1).numpy()


def model_evaluation(model, dataloader, device, max_batches=None, verbose=False):
    """Evaluate an image-restoration model with mean PSNR and SSIM.

    Runs ``model`` on each blurred input from ``dataloader`` and scores every
    prediction against its sharp target. PSNR is computed on the full
    channels-last image. SSIM is computed on the ``to_gray`` versions. Both
    metrics use ``data_range=1.0`` after ``denorm`` and clamping to [0, 1],
    so tensors are presumably normalised to [-1, 1] (TODO confirm against
    the dataset transform).

    Args:
        model: module mapping a batch of inputs to a batch of predictions
            with the same shape as the targets (assumed NCHW).
        dataloader: iterable yielding ``(blur_img, sharp_img)`` tensor pairs.
        device: device the model runs on. Only the inputs are moved there.
        max_batches: if given, stop after this many batches.
        verbose: if True, print the PSNR/SSIM of every sample.

    Returns:
        ``(avg_psnr, avg_ssim)``: means over all evaluated samples (not
        over batches). Note that a pixel-identical prediction gives an
        infinite PSNR, which propagates into the mean.

    Raises:
        ValueError: if no samples were evaluated (empty dataloader or
            ``max_batches == 0``).

    The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    count = 0

    try:
        with torch.no_grad():
            for i, (blur_img, sharp_img) in enumerate(dataloader):
                if max_batches is not None and i >= max_batches:
                    break

                # Only the input needs to be on `device`. The target is
                # scored on the CPU, so it never has to visit the GPU.
                output = model(blur_img.to(device))
                # bfloat16 has no NumPy dtype (`.numpy()` would raise), so
                # upcast half-precision outputs, e.g. from autocast.
                if output.dtype in (torch.float16, torch.bfloat16):
                    output = output.float()

                preds = _batch_to_hwc_numpy(output.cpu())
                targets = _batch_to_hwc_numpy(sharp_img.cpu())

                for pred_np, target_np in zip(preds, targets):
                    psnr_val = psnr(target_np, pred_np, data_range=1.0)
                    ssim_val = ssim(to_gray(target_np), to_gray(pred_np), data_range=1.0)

                    total_psnr += psnr_val
                    total_ssim += ssim_val
                    count += 1

                    if verbose:
                        print(f"Sample {count}: PSNR={psnr_val:.2f}, SSIM={ssim_val:.4f}")
    finally:
        # Don't leak eval mode to the caller (e.g. a training loop that
        # evaluates between epochs and then keeps training).
        model.train(was_training)

    if count == 0:
        raise ValueError("No samples were evaluated")

    return total_psnr / count, total_ssim / count