In [1]:
import os
# Must be set before cv2 is imported (next cell) so OpenCV is allowed to read .exr files.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Display current working directory
print(os.getcwd())
# The notebook expects to run from the WDSS project root (relative paths such as
# config/config.json are used below). If it was launched from a subfolder such as
# jupyter_notebooks, move one level up. Compare the final path component exactly:
# a suffix check like cwd[-4:] would also accept directories such as "MyWDSS".
if os.path.basename(os.getcwd()) != 'WDSS':
    os.chdir('..')
print(os.getcwd())

c:\Dev\MinorProject\WDSS\jupyter_notebooks
c:\Dev\MinorProject\WDSS


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

from typing import List, Tuple, Dict

from config import device, Settings
from commons import initialize
from utils import *
from tqdm import tqdm
from network.losses import ImageEvaluator

In [3]:
# Load the run configuration for the "WDSSV5" model from config/config.json.
settings = Settings("config/config.json", "WDSSV5")
# Project-level setup for this run. The job name and model/log paths shown in this cell's
# output are printed here (presumably by initialize) — TODO confirm what else it configures.
initialize(settings)
# Private helper: presumably eagerly builds the LPIPS network behind ImageEvaluator.lpips
# (the output shows LPIPS VGG weights being loaded). NOTE(review): relies on a private API.
ImageEvaluator._setup_lpips()

Job: Navin_relu_mse, Model: WDSSV5, Device: cuda
Model path: out\Navin_relu_mse-WDSSV5\model
Log path: out\Navin_relu_mse-WDSSV5\logs
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Dev\MinorProject\WDSS\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth


In [4]:
from network.dataset import *

# Two tonemapping variants are evaluated side by side; each gets its own Preprocessor
# instance built from the same config, differing only in the attached tonemapper.
hable_tonemapper = BaseTonemapper.from_name("Hable-10")
srgb_tonemapper = BaseTonemapper.from_name("SRGB")


def _preprocessor_with(tonemapper):
    """Build a Preprocessor from the shared settings and attach `tonemapper` to it."""
    pre = Preprocessor.from_config(settings.preprocessor_config)
    pre.tonemapper = tonemapper
    return pre


def _compressed_dataset(root, pre):
    """Open a WDSSDatasetCompressed over `root` using preprocessor `pre`.

    The remaining positional arguments are the fixed evaluation arguments shared by every
    dataset in this notebook. NOTE(review): 2.0 presumably is the upscale factor that also
    appears in score_sequence — confirm against the WDSSDatasetCompressed signature.
    """
    return WDSSDatasetCompressed(root, settings.frames_per_zip, 0, 2.0, False, 8, pre, True)


preprocessor_hable = _preprocessor_with(hable_tonemapper)
preprocessor_srgb = _preprocessor_with(srgb_tonemapper)

val_hable = _compressed_dataset(settings.val_dir, preprocessor_hable)
val_srgb = _compressed_dataset(settings.val_dir, preprocessor_srgb)
test_hable = _compressed_dataset(settings.test_dir, preprocessor_hable)
test_srgb = _compressed_dataset(settings.test_dir, preprocessor_srgb)

In [5]:
from network.models.WDSS import get_wdss_model
from network.losses import CriterionSSIM_L1, CriterionSSIM_MSE
from network.trainer import Trainer

# Model under evaluation, built from the model section of the settings and moved to `device`.
model = get_wdss_model(settings.model_config)
model = model.to(device)

# Loss, optimiser and LR schedule. In this notebook they exist to satisfy the Trainer
# constructor; TODO confirm whether load_best_checkpoint() also restores their state.
criterion = CriterionSSIM_MSE()
criterion = criterion.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

# NOTE(review): val_hable is passed for all three dataset slots. The visible cells only use the
# trainer to load a checkpoint, so the choice presumably does not affect the scores — confirm.
trainer = Trainer(settings, model, optimizer, scheduler, criterion, val_hable, val_hable, val_hable)

In [6]:
# Restore the best checkpoint for evaluation (use trainer.load_latest_checkpoint() instead to
# score the most recent one).
try:
    trainer.load_best_checkpoint()
    print(f"Checkpoint loaded epoch: {trainer.total_epochs}")
except Exception as exc:
    # Catch Exception rather than using a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate, and report why loading failed. If loading fails, the scores computed below are
    # for an untrained model.
    print(f"No checkpoint found ({type(exc).__name__}: {exc})")

Checkpoint loaded epoch: 77


In [7]:
# Each entry is (display name, dataset, first frame index, end frame index [exclusive]).
# Every scene occupies a block of 120 consecutive frames in its dataset; the same five
# scenes are listed once per tonemapper, Hable first and then sRGB.
_hable_sequences = [
    ("Asian Village Hable", val_hable, 0, 120),
    ("Industrial Hable", val_hable, 120, 240),
    ("Brass Town Hable", test_hable, 0, 120),
    ("Forest Hable", test_hable, 120, 240),
    ("Supermarket Hable", test_hable, 240, 360),
]
_srgb_sequences = [
    ("Asian Village SRGB", val_srgb, 0, 120),
    ("Industrial SRGB", val_srgb, 120, 240),
    ("Brass Town SRGB", test_srgb, 0, 120),
    ("Forest SRGB", test_srgb, 120, 240),
    ("Supermarket SRGB", test_srgb, 240, 360),
]
test_sequences: List[Tuple[str, WDSSDatasetCompressed, int, int]] = _hable_sequences + _srgb_sequences

In [8]:
def score_sequence(name: str, dataset: "WDSSDatasetCompressed", start: int, end: int,
                   scale: float = 2.0) -> Tuple[float, float, float]:
    """Score the notebook-level `model` on frames [start, end) of `dataset`.

    For every frame the model's reconstructed image and the ground-truth HR frame are both
    passed through `dataset.preprocessor.postprocess`, then compared with ImageEvaluator.
    Uses the notebook globals `model`, `device` and `ImageEvaluator`.

    Args:
        name: Label used for the progress bar and the printed summary.
        dataset: Frame source; must provide `get_inference_frame(i)` and `preprocessor`.
        start: First frame index (inclusive).
        end: Stop frame index (exclusive).
        scale: Factor passed as the model's fourth forward argument (default 2.0, as before).

    Returns:
        Mean (SSIM, PSNR, LPIPS) over the frames; PSNR is computed with a peak value of 1.0.

    Raises:
        ValueError: if `end <= start`. An empty range would otherwise divide by zero, and a
            reversed range would silently return -0.0 for every metric.
    """
    num_frames = end - start
    if num_frames <= 0:
        raise ValueError(f"{name}: empty frame range [{start}, {end})")

    print(f'Scoring {name}...')

    ssim_sum = 0.0
    psnr_sum = 0.0
    lpips_sum = 0.0

    model.eval()
    with torch.no_grad():
        for i in tqdm(range(start, end), desc=name):
            frame = dataset.get_inference_frame(i)

            # Move each input to the evaluation device and add a leading batch dimension.
            lr = frame['LR'].to(device).unsqueeze(0)
            gb = frame['GB'].to(device).unsqueeze(0)
            temp = frame['TEMPORAL'].to(device).unsqueeze(0)
            hr = frame['HR'].to(device).unsqueeze(0)
            # Build a new dict rather than rewriting frame['INFERENCE'] in place, so the dict
            # returned by the dataset is never mutated (matters if the dataset caches frames).
            inference = {key: value.to(device).unsqueeze(0)
                         for key, value in frame['INFERENCE'].items()}

            # The model returns (wavelet, image); only the reconstructed image is scored.
            # Call the module itself (not .forward) so any registered hooks run.
            _, image = model(lr, gb, temp, scale)

            # Apply identical post-processing to prediction and ground truth before comparing.
            pred, _ = dataset.preprocessor.postprocess(image, inference)
            gt, _ = dataset.preprocessor.postprocess(hr, inference)

            ssim_sum += ImageEvaluator.ssim(pred, gt).item()
            psnr_sum += ImageEvaluator.psnr(pred, gt, 1.0).item()
            lpips_sum += ImageEvaluator.lpips(pred, gt).item()

    mean_ssim = ssim_sum / num_frames
    mean_psnr = psnr_sum / num_frames
    mean_lpips = lpips_sum / num_frames
    print(f'{name} SSIM: {mean_ssim:.8f}')
    print(f'{name} PSNR: {mean_psnr:.8f}')
    print(f'{name} LPIPS: {mean_lpips:.8f}')

    return mean_ssim, mean_psnr, mean_lpips

In [9]:
# Accumulator for per-sequence results: one (name, SSIM, PSNR, LPIPS) tuple per scored sequence.
scores: List[Tuple[str, float, float, float]] = list()

In [10]:
# Start from an empty list so re-running this cell (e.g. after an interrupted run) never
# appends duplicate rows for sequences that were already scored.
scores = []
for name, dataset, start, end in test_sequences:
    ssim, psnr, lpips = score_sequence(name, dataset, start, end)
    # Appended per sequence, so results gathered before an interruption remain in `scores`.
    scores.append((name, ssim, psnr, lpips))

Scoring Asian Village Hable...


KeyboardInterrupt: 