# 03_evaluate_zero_shot.ipynb

Evaluate the trained denoiser zero-shot on a noise level it was not trained on (Gaussian noise, sigma = 0.4): visualize sample results and compute PSNR.

In [None]:
import torch
from src.models import TinyUNet
from src.diffusion import SimpleDiffusion
from src.metrics import psnr
from torchvision import transforms
import os
from PIL import Image
import matplotlib.pyplot as plt

img_dir = '../data/processed'
img_size = 32
transform = transforms.Compose([
    transforms.ToTensor()
])
# Image extensions to load, matched case-insensitively (so '.PNG' counts too).
img_exts = ('.jpg', '.jpeg', '.png')

# Load every image in img_dir as an RGB tensor resized to img_size x img_size.
# sorted() makes the image order reproducible -- os.listdir returns entries in
# arbitrary order, which would change which images are plotted below.
img_list = []
for fname in sorted(os.listdir(img_dir)):
    if not fname.lower().endswith(img_exts):
        continue
    # The context manager closes the underlying file handle; convert() returns
    # a new, fully loaded image, so it remains usable after the file is closed.
    with Image.open(os.path.join(img_dir, fname)) as src:
        img = src.convert('RGB')
    img = img.resize((img_size, img_size))
    img = transform(img)
    img_list.append(img)

if not img_list:
    # torch.stack([]) would fail with an unhelpful message; say what is wrong.
    raise RuntimeError(f'No .jpg/.jpeg/.png images found in {img_dir!r}')
imgs = torch.stack(img_list)

# Add *unseen* noise (higher sigma)
# NOTE(review): presumably higher than the sigma used during training --
# confirm against the training notebook.
noise_sigma = 0.4
# A dedicated, seeded generator makes the noise (and hence the reported PSNR)
# reproducible across runs without touching torch's global RNG state.
noise_gen = torch.Generator().manual_seed(0)
noise = torch.randn(imgs.shape, generator=noise_gen, dtype=imgs.dtype)
noisy_imgs = imgs + noise_sigma * noise
# Keep pixels in the valid [0, 1] range (this clips the Gaussian tails).
noisy_imgs = torch.clamp(noisy_imgs, 0., 1.)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Restore the trained TinyUNet weights. map_location loads the tensors onto
# the device chosen above, regardless of where the checkpoint was saved.
checkpoint_path = '../results/denoised/tinyunet.pth'
model = TinyUNet().to(device)
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints (or pass weights_only=True on torch >= 1.13).
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
# Evaluation mode (affects layers such as dropout/batch-norm, if present).
model.eval()

# Denoise one image at a time: each slice from split(1) keeps a batch
# dimension of 1; the prediction is moved back to the CPU and that batch
# dimension is dropped again. The model output is used directly as the
# denoised image.
# NOTE(review): if TinyUNet predicts the noise rather than the clean image
# (common in diffusion setups; SimpleDiffusion is imported above), the
# denoised image would be `noisy - pred` instead -- confirm.
with torch.no_grad():
    denoised_imgs = torch.stack([
        model(single.to(device)).cpu().squeeze(0)
        for single in noisy_imgs.split(1)
    ])

# Visualize a few examples and compute PSNR against the clean originals.
# Denoised outputs are clamped to the valid [0, 1] pixel range once, so the
# PSNR printed is computed on exactly the image that is displayed.
n_show = min(5, len(imgs))  # avoid IndexError when fewer than 5 images exist
for i in range(n_show):
    denoised = denoised_imgs[i].clamp(0, 1)
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(imgs[i].permute(1, 2, 0))
    axs[0].set_title('Original')
    axs[1].imshow(noisy_imgs[i].permute(1, 2, 0))
    axs[1].set_title('Noisy (unseen sigma)')
    axs[2].imshow(denoised.permute(1, 2, 0))
    axs[2].set_title('Denoised')
    for a in axs:
        a.axis('off')
    plt.show()
    print(f'PSNR: {psnr(imgs[i], denoised):.2f} dB')

# Dataset-level summary: mean PSNR of the noisy input vs. the denoised output
# over all images, so the zero-shot gain of the model is visible at a glance.
# Assumes psnr returns a scalar (float or 0-dim tensor), which the ':.2f'
# formatting above already relies on.
noisy_psnrs = [float(psnr(c, n)) for c, n in zip(imgs, noisy_imgs)]
denoised_psnrs = [float(psnr(c, d.clamp(0, 1))) for c, d in zip(imgs, denoised_imgs)]
print(f'Mean PSNR over {len(imgs)} images: '
      f'noisy {sum(noisy_psnrs) / len(noisy_psnrs):.2f} dB -> '
      f'denoised {sum(denoised_psnrs) / len(denoised_psnrs):.2f} dB')