In [1]:
import os, sys
repo_root = os.path.abspath(os.pardir)  # parent of the current working directory
if sys.path.count(repo_root) == 0:
    # Make the repository root importable so the `Model.*` imports below resolve.
    sys.path.append(repo_root)
    
import torch
import importlib
from Model.get_pretrained import get_pretrained_large
import Model.evaluation as evaluation
from Model.evaluation import evaluate_over_precomputed
import lpips


Load the pre-computed val and test results.

In [2]:
# Pre-computed generation results, one file per split.
# NOTE(review): torch.load unpickles its input, which can execute arbitrary code;
# only point these paths at result files you generated yourself.
path_val = "../gen_results/val_all.pt"
path_test = "../gen_results/test_all.pt"
data_val = torch.load(path_val, map_location="cpu")
data_test = torch.load(path_test, map_location="cpu")

# The two splits are merged key by key, so they must expose exactly the same keys.
# Without this check a val-only key raises an opaque KeyError and a test-only key
# is silently dropped from the merged set.
if data_val.keys() != data_test.keys():
    raise ValueError(
        "val/test result files have different keys: "
        f"val={list(data_val)}, test={list(data_test)}"
    )

# Concatenate val then test along the first dimension; key order follows the val file.
data = {key: torch.cat((data_val[key], data_test[key]), dim=0) for key in data_val}

print(type(data))
print(data.keys())

clean  = data["clean"]
cloudy = data["cloudy"]
pred   = data["pred"]

print("clean:",  clean.shape,  clean.dtype)
print("cloudy:", cloudy.shape, cloudy.dtype)
print("pred:",   pred.shape,   pred.dtype)

# Assumes the three tensors are paired entry-by-entry (they are evaluated together
# below); differing shapes would indicate misaligned result files.
assert clean.shape == cloudy.shape == pred.shape, "clean/cloudy/pred shapes differ"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

<class 'dict'>
dict_keys(['cloudy', 'clean', 'pred'])
clean: torch.Size([4101, 4, 128, 128]) torch.float32
cloudy: torch.Size([4101, 4, 128, 128]) torch.float32
pred: torch.Size([4101, 4, 128, 128]) torch.float32


Load pretrained models

In [3]:
# Checkpoint files for the large pretrained model (cloud encoder + denoiser).
cloud_enc_pth, denoiser_pth = (
    '../pretrained/cloud_enc_200e_FullData.pth',
    '../pretrained/denoiser_200e_FullData.pth',
)

# NOTE(review): none of the three returned models is passed to
# evaluate_over_precomputed below, which works on the pre-computed tensors only;
# confirm this load is still needed in this notebook.
cloud_encoder, forwarder, denoiser = get_pretrained_large(
    device=device,
    denoiser_pth=denoiser_pth,
    cloud_enc_pth=cloud_enc_pth,
)

Pretrained large model loaded successfully.
The model has 5125701 parameters.


Evaluate on the combined val + test set.

In [4]:
# LPIPS perceptual metric with a VGG trunk; `.to()` and `.eval()` both return the
# module itself, so this yields the same object in inference mode.
lpips_model = lpips.LPIPS(net='vgg').to(device).eval()

# NOTE(review): max_val=1.0 presumably is the data range used for PSNR/SSIM, i.e.
# tensors are assumed to lie in [0, 1]; confirm against how the results were saved.
# NOTE(review): the sample count is not a multiple of batch_size, so the last batch
# is smaller; if the summary averages per-batch values it is slightly over-weighted.
all_metrics, summary = evaluate_over_precomputed(
    data=data,
    lpips_model=lpips_model,
    device=device,
    batch_size=32,
    max_val=1.0,
)

print("Per-batch metrics example:", all_metrics[0])
print("Dataset summary:", summary)



Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: D:\PyCharm 2024.2.1\Projects\pythonProject\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth


100%|██████████| 129/129 [00:12<00:00,  9.96it/s]

Per-batch metrics example: {'MAE': 0.010509842075407505, 'PSNR': 24.852405548095703, 'SSIM': 0.9145927429199219, 'LPIPS': 0.07596463710069656}
Dataset summary: {'MAE_mean': 0.017053022980690002, 'MAE_std': 0.0037360077258199453, 'PSNR_mean': 22.694576263427734, 'PSNR_std': 1.3518364429473877, 'SSIM_mean': 0.8884437084197998, 'SSIM_std': 0.017557499930262566, 'LPIPS_mean': 0.10063502192497253, 'LPIPS_std': 0.02081654965877533}



