In [16]:
import os

# Flag from the original notebook "to make sure opencv imports .exr files";
# it is set here, before the cell that imports cv2.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Show the directory the kernel started in.
startup_dir = os.getcwd()
print(startup_dir)

# The remaining cells use paths relative to the WDSS folder, so step one
# level up whenever the current directory name does not end in 'WDSS'.
if not startup_dir.endswith('WDSS'):
    os.chdir(os.pardir)

# Show the directory the rest of the notebook will run from.
print(os.getcwd())

c:\Dev\MinorProject\WDSS
c:\Dev\MinorProject\WDSS


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

from typing import List, Tuple, Dict

from config import device, Settings
from commons import initialize

In [3]:
# Root folder for the exported frames (res/, hr/ and lr/ sub-folders) and for the assembled video.
video_out_dir = "out/video"

In [18]:
# Build the run settings from config/config.json for the "WDSSV5" model
# (how the second argument is interpreted is defined in config.Settings — TODO confirm).
settings = Settings("config/config.json", "WDSSV5")
# Project-level setup driven by those settings; see commons.initialize for what it does exactly.
initialize(settings=settings)

Job: Navin_relu_mse, Model: WDSSV5, Device: cuda
Model path: out\Navin_relu_mse-WDSSV5\model
Log path: out\Navin_relu_mse-WDSSV5\logs


In [20]:
from network.dataset import *

# Fetch all dataset splits in a single call, then unpack them into the
# three names the rest of the notebook refers to.
dataset_splits = WDSSDatasetCompressed.get_datasets(settings)
train_dataset, val_dataset, test_dataset = dataset_splits

# Report how many samples the test split holds.
test_dataset_size = len(test_dataset)
print("Test dataset size: ", test_dataset_size)

Test dataset size:  360


In [21]:
from network.models.WDSS import get_wdss_model
    
# Model: build the network described by settings.model_config and move it to the configured device.
model = get_wdss_model(settings.model_config).to(device)

In [22]:
from network.losses import CriterionSSIM_L1, CriterionSSIM_MSE

# Loss, optimizer and learning-rate schedule that are handed to the Trainer.
criterion = CriterionSSIM_MSE().to(device)
# Adam over all model parameters with an initial learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# StepLR multiplies the learning rate by 0.5 every 20 scheduler steps.
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

In [23]:
from network.trainer import Trainer

# Bundle the model, the optimisation objects and the three dataset splits into the project's Trainer;
# later cells use it to load the best checkpoint and to reach the network through trainer.model.
trainer = Trainer(settings, model, optimizer, scheduler, criterion, train_dataset, val_dataset, test_dataset)

In [24]:
# Restore the best saved weights before any inference / evaluation cell runs.
try:
    trainer.load_best_checkpoint()
except Exception as exc:
    # `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt
    # and SystemExit propagate. The cause is printed as well, because a failed
    # load means every later cell runs on weights that were never restored.
    print(f"No checkpoint found ({type(exc).__name__}: {exc})")
else:
    # Only reached when loading succeeded, so the epoch count is trustworthy.
    print(f"Checkpoint loaded epoch: {trainer.total_epochs}")

Checkpoint loaded epoch: 77


In [10]:
# Number of test-set frames (starting at index 0) that are exported for the video.
total_frames = 120
print("Total frames:", total_frames)

Total frames: 120


In [None]:
from tqdm import tqdm


def _save_frame(tensor, subdir, frame_no):
    """Write one image tensor to ``{video_out_dir}/{subdir}/frame_XXXX.jpg``.

    The tensor is detached, moved to the CPU and clamped to [0, 1], converted
    with ``ImageUtils.tensor_to_opencv_image``, scaled to 8-bit, and has its
    first and third channels swapped before being passed to ``cv2.imwrite``.

    Args:
        tensor: image tensor accepted by ``ImageUtils.tensor_to_opencv_image``.
        subdir: sub-folder of ``video_out_dir`` to write into.
        frame_no: frame index, used for the zero-padded file name.

    Returns:
        The 8-bit array that was written.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    frame_cv = ImageUtils.tensor_to_opencv_image(tensor.detach().cpu().clamp(0, 1))
    frame_cv = (frame_cv * 255).astype(np.uint8)
    frame_cv = frame_cv[..., [2, 1, 0]]  # Swap red and blue channels
    out_path = f"{video_out_dir}/{subdir}/frame_{frame_no:04d}.jpg"
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(out_path, frame_cv):
        raise OSError(f"Could not write {out_path}")
    return frame_cv


# cv2.imwrite does not create missing folders, so make sure they exist
# before spending minutes on inference.
for subdir in ("res", "hr", "lr"):
    os.makedirs(os.path.join(video_out_dir, subdir), exist_ok=True)

trainer.model.eval()

for frame_no in tqdm(range(total_frames)):
    # Each dataset item is a dict of tensors; add a batch dimension to each.
    raw_frames = test_dataset[frame_no]
    lr = raw_frames['LR'].to(device).unsqueeze(0)
    gb = raw_frames['GB'].to(device).unsqueeze(0)
    temp = raw_frames['TEMPORAL'].to(device).unsqueeze(0)
    hr = raw_frames['HR'].to(device).unsqueeze(0)

    with torch.no_grad():
        # The last argument is presumably the upscale factor — TODO confirm
        # against the model's forward signature. The wavelet output is unused.
        _wavelet, image = trainer.model.forward(lr, gb, temp, 2.0)

    # Store the network output, the ground truth and the low-res input.
    res_cv = _save_frame(image, "res", frame_no)
    hr_cv = _save_frame(hr, "hr", frame_no)
    lr_cv = _save_frame(lr, "lr", frame_no)
    


100%|██████████| 120/120 [03:31<00:00,  1.76s/it]


In [14]:
# Load the output images and create a video
import cv2
import os

image_folder = f"{video_out_dir}/res"
video_name = f"{video_out_dir}/res_video.mp4"
video_fps = 60  # playback rate of the assembled video

# os.listdir returns entries in arbitrary order, so sort explicitly; the
# zero-padded frame_XXXX.jpg names sort lexicographically in frame order.
images = sorted(img for img in os.listdir(image_folder) if img.endswith(".jpg"))
if not images:
    raise FileNotFoundError(f"No .jpg frames found in {image_folder}")

# The first frame determines the video resolution.
frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    # cv2.imread returns None instead of raising when a file cannot be decoded.
    raise OSError(f"Could not read {os.path.join(image_folder, images[0])}")
height, width = frame.shape[:2]

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(video_name, fourcc, video_fps, (width, height))
if not video.isOpened():
    raise OSError(f"Could not open {video_name} for writing")

try:
    # `image_name` (not `image`) so the model output tensor from the
    # inference cell is not overwritten by a file name.
    for image_name in images:
        video.write(cv2.imread(os.path.join(image_folder, image_name)))
finally:
    # Always finalise the container, even if a frame fails part-way through.
    video.release()


In [29]:
# Compute average losses
from network.losses import ImageEvaluator

# Running sums of the per-frame metrics for the network output.
total_ssim = total_mse = total_l1 = total_psnr = total_lpips = 0.0

# Running sums of the same metrics for the bilinearly upsampled baseline.
total_ssim_bilinear = total_mse_bilinear = total_l1_bilinear = 0.0
total_psnr_bilinear = total_lpips_bilinear = 0.0

In [30]:
from utils.wavelet import WaveletProcessor
from tqdm import tqdm

# Test-split indices that are scored (frames 240..359).
eval_frame_indices = range(240, 360)

# Zero the running sums here as well, so that re-running this cell starts
# from scratch instead of adding to the totals of a previous run.
total_ssim = total_mse = total_l1 = total_psnr = total_lpips = 0.0
total_ssim_bilinear = total_mse_bilinear = total_l1_bilinear = 0.0
total_psnr_bilinear = total_lpips_bilinear = 0.0

model.eval()
with torch.no_grad():
    for i in tqdm(eval_frame_indices):
        # Network inputs and the un-preprocessed reference frames for index i.
        frame = test_dataset.get_inference_frame(i)
        log_frames = test_dataset.get_log_frames(i)

        lr = frame['LR'].to(device).unsqueeze(0)
        gb = frame['GB'].to(device).unsqueeze(0)
        temp = frame['TEMPORAL'].to(device).unsqueeze(0)
        inference = frame['INFERENCE']

        # The last argument is presumably the upscale factor — TODO confirm.
        # Only the image output is scored; the wavelet output is unused.
        _wavelet, image = model.forward(lr, gb, temp, 2.0)

        # Postprocess the raw network output; the second return value is unused here.
        final, _extra_frames = test_dataset.preprocessor.postprocess(image, inference)
        gt = log_frames['HR'].to(device).unsqueeze(0)
        lr_frame = log_frames['LR'].to(device).unsqueeze(0)

        # Bilinearly upsample the LR image
        ups = ImageUtils.upsample(lr_frame, 2)

        # Metrics of the network output against the ground truth.
        total_mse += ImageEvaluator.mse(final, gt).item()
        total_ssim += ImageEvaluator.ssim(final, gt).item()
        total_l1 += ImageEvaluator.l1(final, gt).item()
        total_psnr += ImageEvaluator.psnr(final, gt, 1.0).item()
        total_lpips += ImageEvaluator.lpips(final, gt).item()

        # Same metrics for the bilinear baseline.
        total_ssim_bilinear += ImageEvaluator.ssim(ups, gt).item()
        total_mse_bilinear += ImageEvaluator.mse(ups, gt).item()
        total_l1_bilinear += ImageEvaluator.l1(ups, gt).item()
        total_psnr_bilinear += ImageEvaluator.psnr(ups, gt, 1.0).item()
        total_lpips_bilinear += ImageEvaluator.lpips(ups, gt).item()


100%|██████████| 120/120 [07:42<00:00,  3.85s/it]


In [31]:
# Print all averaged metrics: the summed values divided by the number of
# evaluated frames.
length = 120

# NOTE(review): in the recorded output of this cell the bilinear baseline
# scores better than the model on all five metrics — worth confirming that
# `final` and `gt` are in the same value range / colour space.
metric_report = (
    ("Metrics for the model", (
        ("SSIM", total_ssim),
        ("MSE", total_mse),
        ("L1", total_l1),
        ("PSNR", total_psnr),
        ("LPIPS", total_lpips),
    )),
    ("Metrics for the bilinear model", (
        ("SSIM", total_ssim_bilinear),
        ("MSE", total_mse_bilinear),
        ("L1", total_l1_bilinear),
        ("PSNR", total_psnr_bilinear),
        ("LPIPS", total_lpips_bilinear),
    )),
)

for heading, metric_totals in metric_report:
    print(heading)
    for metric_name, metric_total in metric_totals:
        print(f"{metric_name}: {metric_total / length}")

Metrics for the model
SSIM: 0.8619747459888458
MSE: 0.0015514181092536699
L1: 0.027032192486027878
PSNR: 28.267547941207887
LPIPS: 0.2706432558596134
Metrics for the bilinear model
SSIM: 0.9081082423528035
MSE: 0.0008547476674721111
L1: 0.011790875764563679
PSNR: 31.40296723047892
LPIPS: 0.21754737980663777
