In [None]:
import os
from pathlib import Path
import torch
from tqdm.auto import tqdm
from torchvision.transforms import functional as TF
from torchvision.utils import save_image

from data import PairedSRDataset, build_loader
from model import SRResNet
from metrics import psnr, ssim

# Config: hyperparameters and data locations. Data paths are relative to the
# project root and are made absolute below once the root has been located.
cfg = {
    'scale': 4,
    'hr_crop': 128,
    'paths': {
        'train_lr': 'train/train_lr',
        'train_hr': 'train/train_hr',
        'val_lr': 'val/val_lr',
        'val_hr': 'val/val_hr',
        'test_lr': 'test/test_lr',
        'test_hr': 'test/test_hr',
    },
}

# Resolve the project root so the notebook can run from any cwd: walk up from
# the cwd and take the first directory that contains 'train/train_lr'.
cwd = Path.cwd().resolve()
base_dir = next(
    (parent for parent in (cwd, *cwd.parents) if (parent / 'train' / 'train_lr').is_dir()),
    None,
)
if base_dir is None:
    raise FileNotFoundError(f"Cannot find 'train/train_lr' from cwd: {cwd}")

# Make every data path absolute under the project root (one rule for all
# splits instead of six hand-written assignments).
cfg['paths'] = {name: str(base_dir / rel_path) for name, rel_path in cfg['paths'].items()}

# Checkpoint location. Avoid a hard-coded absolute (machine-specific) path:
# the SRRESNET_WEIGHTS environment variable wins if set, otherwise use
# <project root>/weights/best_srresnet.pth.
weight_path = os.environ.get('SRRESNET_WEIGHTS') or str(base_dir / 'weights' / 'best_srresnet.pth')
if not os.path.isfile(weight_path):
    raise FileNotFoundError(f"Weight not found: {weight_path}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _strip_wrapper_prefixes(key):
    """Remove leading 'module.' / '_orig_mod.' wrapper prefixes from a state-dict key.

    Prefixes are stripped repeatedly, so both 'module._orig_mod.x' and
    '_orig_mod.module.x' become 'x'. Only *leading* occurrences are removed;
    a 'module.' segment in the middle of a key is left untouched.
    """
    changed = True
    while changed:
        changed = False
        for prefix in ('module.', '_orig_mod.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


def load_srresnet(weight_path, scale, device, model_cls=None):
    """Build an SR model and load its weights from a checkpoint file.

    The checkpoint may be a bare state dict, or a dict wrapping the weights
    under 'state_dict', 'model' or 'generator' (checked in that order). A
    wrapped value that is a whole ``nn.Module`` is converted with
    ``.state_dict()``. Leading DataParallel/DDP ('module.') and torch.compile
    ('_orig_mod.') prefixes are stripped in any nesting order.

    Args:
        weight_path: Path to the checkpoint file.
        scale: Upscaling factor passed to the model constructor.
        device: Device the model is moved to and tensors are mapped onto.
        model_cls: Optional constructor called as ``model_cls(scale=scale)``;
            defaults to ``SRResNet``.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        RuntimeError: if none of the model's state-dict entries were found in
            the checkpoint (otherwise the model would silently stay randomly
            initialised).
    """
    if model_cls is None:
        model_cls = SRResNet
    model = model_cls(scale=scale).to(device)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(weight_path, map_location=device)

    # Unwrap training checkpoints that nest the weights under a known key.
    state = ckpt
    if isinstance(ckpt, dict):
        for wrapper_key in ('state_dict', 'model', 'generator'):
            if wrapper_key in ckpt:
                state = ckpt[wrapper_key]
                break
    # Some training scripts save the whole module rather than its state dict.
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()

    # Rebuild the dict only when a key actually changes, so an untouched state
    # dict keeps its original type (and any metadata attached to it).
    if isinstance(state, dict) and any(_strip_wrapper_prefixes(k) != k for k in state):
        state = {_strip_wrapper_prefixes(k): v for k, v in state.items()}

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print('Missing keys:', missing)
        print('Unexpected keys:', unexpected)
    # strict=False tolerates partial matches, but a checkpoint that matches
    # nothing leaves the model at its random init and every metric meaningless.
    n_model_keys = len(model.state_dict())
    if n_model_keys and len(missing) == n_model_keys:
        raise RuntimeError(
            f"No weights from {weight_path} matched the model; "
            "check the checkpoint format / architecture."
        )
    model.eval()
    return model

model = load_srresnet(weight_path, cfg['scale'], device)

# Block 1: evaluate metrics on the test set.
# NOTE(review): the loader is built with hr_crop; whether train=False disables
# cropping is up to build_loader -- confirm metrics are computed on full images.
_, test_loader = build_loader(
    cfg['paths']['test_lr'], cfg['paths']['test_hr'],
    scale=cfg['scale'], hr_crop=cfg['hr_crop'],
    batch_size=1, num_workers=4, train=False
)

psnr_vals = []
ssim_vals = []
with torch.no_grad():
    for lr, hr in tqdm(test_loader, desc='Test'):
        lr = lr.to(device)
        hr = hr.to(device)
        sr = model(lr)
        # NOTE(review): sr is not clamped to the image range here; confirm that
        # psnr/ssim handle out-of-range model outputs as intended.
        # float() accepts Python numbers and single-element tensors alike, so
        # device tensors are not accumulated and the averages are plain floats.
        psnr_vals.append(float(psnr(sr, hr)))
        ssim_vals.append(float(ssim(sr, hr)))

if psnr_vals:
    avg_psnr = sum(psnr_vals) / len(psnr_vals)
    avg_ssim = sum(ssim_vals) / len(ssim_vals)
else:
    # An empty test set must not masquerade as a real score of 0.
    print(f"Warning: no test samples found in {cfg['paths']['test_lr']}")
    avg_psnr = avg_ssim = float('nan')
print(f'Test PSNR: {avg_psnr:.4f} dB')
print(f'Test SSIM: {avg_ssim:.4f}')


In [None]:
# Block 2: save SR images for test/train/val
def save_sr_split(split_name, lr_dir, hr_dir, out_dir):
    """Super-resolve every LR image of one split and save the results as PNGs.

    Uses the notebook-level ``model``, ``device`` and ``cfg`` defined in the
    previous cell. Each output is written to ``out_dir/<stem>_sr.png``, where
    ``<stem>`` is the HR file name without extension, falling back to the LR
    file's stem when the HR path is missing or yields an empty name.

    Args:
        split_name: Label shown in the progress bar.
        lr_dir: LR image directory passed to ``PairedSRDataset``.
        hr_dir: HR image directory passed to ``PairedSRDataset``.
        out_dir: Output directory (created if it does not exist).

    Returns:
        Number of images written.
    """
    os.makedirs(out_dir, exist_ok=True)
    dataset = PairedSRDataset(lr_dir, hr_dir, scale=cfg['scale'], hr_crop=cfg['hr_crop'], train=False)
    pairs = dataset.pairs

    def stem(path):
        # '' for a missing path (None / '') or one without a basename
        # (e.g. trailing separator), so the caller can fall back.
        return os.path.splitext(os.path.basename(path))[0] if path else ''

    n_saved = 0
    with torch.no_grad():
        for lr_path, hr_path in tqdm(pairs, desc=f'Saving {split_name}'):
            # NOTE(review): relies on the dataset's private _load helper; assumes
            # it returns something TF.to_tensor accepts (PIL image / ndarray).
            lr_img = TF.to_tensor(dataset._load(lr_path)).unsqueeze(0).to(device)
            sr = model(lr_img).cpu()

            base = stem(hr_path) or stem(lr_path)
            out_path = os.path.join(out_dir, f"{base}_sr.png")
            save_image(sr, out_path)
            n_saved += 1
    return n_saved

sr_root = base_dir / 'sr'
# Write SR outputs for each split, in order: test, train, val.
for split_name in ('test', 'train', 'val'):
    save_sr_split(
        split_name,
        cfg['paths'][f'{split_name}_lr'],
        cfg['paths'][f'{split_name}_hr'],
        str(sr_root / split_name),
    )

print('Done saving SR images.')
