In [1]:
import os

# Allow OpenCV to read .exr images (the flag is set ahead of the cv2 import in a later cell).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Show the directory the kernel started in.
print(os.getcwd())

# The rest of the notebook uses paths relative to the WDSS folder; when the
# kernel starts somewhere else (e.g. inside jupyter_notebooks/), step up one level.
if not os.getcwd().endswith('WDSS'):
    os.chdir('..')

# Show the directory that is in effect after the adjustment.
print(os.getcwd())

c:\Dev\MinorProject\WDSS\jupyter_notebooks
c:\Dev\MinorProject\WDSS


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

from typing import List, Tuple, Dict

from config import device, Settings
from commons import initialize

In [3]:
# Root output folder: later cells write exported frames into its res/, hr/ and
# lr/ subfolders and the assembled clip to res_video.mp4.
video_out_dir = "out/video"

In [4]:
# Build the job settings from config/config.json with "WDSSV5" as the second
# argument (presumably the model name — the cell output reports "Model: WDSSV5"),
# then run the project's initialize() step with those settings.
settings = Settings("config/config.json", "WDSSV5")
initialize(settings=settings)

Job: Navin_relu_mse, Model: WDSSV5, Device: cuda
Model path: out\Navin_relu_mse-WDSSV5\model
Log path: out\Navin_relu_mse-WDSSV5\logs


In [5]:
# NOTE(review): this wildcard import supplies WDSSDatasetCompressed and
# presumably also ImageUtils, which later cells use without importing it —
# confirm, then consider replacing `*` with explicit names.
from network.dataset import *

# Create the train, validation and test datasets from the job settings.
train_dataset, val_dataset, test_dataset = WDSSDatasetCompressed.get_datasets(settings)

In [6]:
from network.models.WDSS import get_wdss_model
    
# Model: build the network from the model section of the settings and move it
# to the configured device.
model = get_wdss_model(settings.model_config).to(device)

In [7]:
from network.losses import CriterionSSIM_L1, CriterionSSIM_MSE

# Loss object handed to the Trainer; judging by its name it combines SSIM and
# MSE terms — TODO confirm in network.losses.
criterion = CriterionSSIM_MSE().to(device)
# Adam over all model parameters with an initial learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.5 after every 20 scheduler steps.
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Dev\MinorProject\WDSS\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [8]:
from network.trainer import Trainer

# Hand the settings, model, optimizer, scheduler, loss and the three datasets
# to the project's Trainer; later cells reach the model through trainer.model.
trainer = Trainer(settings, model, optimizer, scheduler, criterion, train_dataset, val_dataset, test_dataset)

In [9]:
# Restore the best saved checkpoint into the trainer before running inference.
# Only ordinary errors are caught: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide the real reason the load failed
# (missing file, incompatible weights, ...), so the error is reported instead.
try:
    trainer.load_best_checkpoint()
except Exception as exc:
    print(f"No checkpoint loaded ({type(exc).__name__}: {exc})")
else:
    # Reached only when loading succeeded.
    print(f"Checkpoint loaded epoch: {trainer.total_epochs}")

Checkpoint loaded epoch: 77


In [10]:
# Number of test-set frames (indices 0 .. total_frames - 1) that the next cell
# runs through the model and exports as images.
total_frames = 120
print(f"Total frames: {total_frames}")

Total frames: 120


In [11]:
from tqdm import tqdm

def _save_frame(tensor, path):
    """Write a batched image tensor to ``path`` as an 8-bit image.

    The tensor is detached, moved to the CPU and clamped to [0, 1], converted
    with ``ImageUtils.tensor_to_opencv_image``, scaled to 0-255 and has its
    first and third channels swapped before being passed to ``cv2.imwrite``.

    Args:
        tensor: image tensor with a leading batch dimension of size 1
            (as produced by ``unsqueeze(0)`` / the model in the loop below).
        path: destination file path, including the extension.

    Raises:
        OSError: if OpenCV reports that the file could not be written.
    """
    frame_cv = ImageUtils.tensor_to_opencv_image(tensor.detach().cpu().clamp(0, 1))
    frame_cv = (frame_cv * 255).astype(np.uint8)
    frame_cv = frame_cv[..., [2, 1, 0]]  # Swap red and blue channels
    # cv2.imwrite reports failure (e.g. a missing folder) by returning False
    # rather than raising, so check it to avoid silently losing frames.
    if not cv2.imwrite(path, frame_cv):
        raise OSError(f"Could not write frame to {path}")


# cv2.imwrite does not create missing folders, so make sure they exist up front.
for subfolder in ("res", "hr", "lr"):
    os.makedirs(f"{video_out_dir}/{subfolder}", exist_ok=True)

# Inference only: switch the model to evaluation mode.
trainer.model.eval()

# Never index past the end of the test set, even if total_frames is set too high.
frames_to_export = min(total_frames, len(test_dataset))

for frame_no in tqdm(range(frames_to_export)):
    # Fetch one sample and add a batch dimension to every input.
    raw_frames = test_dataset[frame_no]
    lr = raw_frames['LR'].to(device).unsqueeze(0)
    gb = raw_frames['GB'].to(device).unsqueeze(0)
    temp = raw_frames['TEMPORAL'].to(device).unsqueeze(0)
    hr = raw_frames['HR'].to(device).unsqueeze(0)

    with torch.no_grad():
        # The last argument is presumably the upscale factor — TODO confirm.
        wavelet, image = trainer.model.forward(lr, gb, temp, 2.0)

    # Model output, ground-truth HR frame and LR input, one file each.
    _save_frame(image, f"{video_out_dir}/res/frame_{frame_no:04d}.jpg")
    _save_frame(hr, f"{video_out_dir}/hr/frame_{frame_no:04d}.jpg")
    _save_frame(lr, f"{video_out_dir}/lr/frame_{frame_no:04d}.jpg")
    


100%|██████████| 120/120 [03:31<00:00,  1.76s/it]


In [14]:
# Load the output images and create a video
import cv2
import os

image_folder = f"{video_out_dir}/res"
video_name = f"{video_out_dir}/res_video.mp4"

# os.listdir returns entries in arbitrary order. The frames are named with a
# zero-padded index (frame_0000.jpg, ...), so a plain lexicographic sort puts
# them back into playback order.
images = sorted(img for img in os.listdir(image_folder) if img.endswith(".jpg"))
if not images:
    raise FileNotFoundError(f"No .jpg frames found in {image_folder}")

# The first frame determines the video resolution.
frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    raise OSError(f"Could not read {os.path.join(image_folder, images[0])}")
height, width, layers = frame.shape

# Define the codec and create VideoWriter object (60 fps, size given as width x height)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(video_name, fourcc, 60, (width, height))

try:
    # The loop variable is deliberately not called `image`, so it does not
    # overwrite the model-output tensor of that name from the previous cell.
    for image_name in images:
        video.write(cv2.imread(os.path.join(image_folder, image_name)))
finally:
    # Always finalise the file, even if reading or writing a frame fails.
    video.release()


In [19]:
# Compute average losses
from network.losses import ImageEvaluator

# Running sums; the evaluation loop adds one value per test sample and the
# reporting cell divides by the dataset size.

# Model image output compared with the HR frame.
total_ssim = total_mse = total_l1 = total_psnr = total_lpips = 0.0

# Model wavelet output compared with the wavelet transform of the HR frame.
total_wavelet_ssim = total_wavelet_mse = total_wavelet_l1 = 0.0
total_wavelet_psnr = total_wavelet_lpips = 0.0

# Upsampled LR frame (baseline) compared with the HR frame.
total_ssim_bilinear = total_mse_bilinear = total_l1_bilinear = 0.0
total_psnr_bilinear = total_lpips_bilinear = 0.0

In [20]:
from utils.wavelet import WaveletProcessor
from tqdm import tqdm

def _mean_over_subbands(metric, prediction, target, num_bands=4, channels_per_band=3):
    """Average ``metric`` over consecutive channel groups of two tensors.

    The channel axis (dim 1) is cut into ``num_bands`` slices of
    ``channels_per_band`` channels each (0:3, 3:6, 6:9, 9:12 by default);
    ``metric`` is evaluated on each pair of slices and the results are averaged.

    Args:
        metric: callable taking ``(prediction_slice, target_slice)``.
        prediction: tensor indexed as ``[:, channels, :, :]``.
        target: tensor indexed the same way as ``prediction``.
        num_bands: number of channel groups to evaluate.
        channels_per_band: channels in each group.

    Returns:
        The mean of the per-group metric values (same type ``metric`` returns).
    """
    scores = [
        metric(
            prediction[:, start:start + channels_per_band, :, :],
            target[:, start:start + channels_per_band, :, :],
        )
        for start in range(0, num_bands * channels_per_band, channels_per_band)
    ]
    return sum(scores) / num_bands


# Do not rely on an earlier cell having switched the model to evaluation mode.
trainer.model.eval()

# NOTE: the total_* accumulators are created in the previous cell and only
# added to here; re-run that cell before re-running this loop, otherwise the
# sums are double counted.
for i in tqdm(range(len(test_dataset))):
    # Fetch one sample and add a batch dimension to every input.
    raw_frames = test_dataset[i]
    lr = raw_frames['LR'].to(device).unsqueeze(0)
    gb = raw_frames['GB'].to(device).unsqueeze(0)
    temp = raw_frames['TEMPORAL'].to(device).unsqueeze(0)
    hr = raw_frames['HR'].to(device).unsqueeze(0)

    with torch.no_grad():
        # The last argument is presumably the upscale factor — TODO confirm.
        wavelet, image = trainer.model.forward(lr, gb, temp, 2.0)

    # Bilinearly upsample the LR image
    ups = ImageUtils.upsample(lr, 2)
    # Wavelet transform the hr image
    hr_wavelet = WaveletProcessor.batch_wt(hr)

    # Metrics of the model's image output against the HR frame
    total_mse += ImageEvaluator.mse(image, hr).item()
    total_ssim += ImageEvaluator.ssim(image, hr).item()
    total_l1 += ImageEvaluator.l1(image, hr).item()
    total_psnr += ImageEvaluator.psnr(image, hr, 1.0).item()
    total_lpips += ImageEvaluator.lpips(image, hr).item()

    # Metrics of the model's wavelet output against the HR wavelet transform.
    # SSIM and LPIPS are evaluated per 3-channel group and averaged.
    total_wavelet_mse += ImageEvaluator.mse(wavelet, hr_wavelet).item()
    total_wavelet_ssim += _mean_over_subbands(ImageEvaluator.ssim, wavelet, hr_wavelet).item()
    total_wavelet_l1 += ImageEvaluator.l1(wavelet, hr_wavelet).item()
    total_wavelet_psnr += ImageEvaluator.psnr(wavelet, hr_wavelet, 2.0).item()
    total_wavelet_lpips += _mean_over_subbands(ImageEvaluator.lpips, wavelet, hr_wavelet).item()

    # Baseline: metrics of the upsampled LR frame against the HR frame
    total_mse_bilinear += ImageEvaluator.mse(ups, hr).item()
    total_ssim_bilinear += ImageEvaluator.ssim(ups, hr).item()
    total_l1_bilinear += ImageEvaluator.l1(ups, hr).item()
    total_psnr_bilinear += ImageEvaluator.psnr(ups, hr, 1.0).item()
    total_lpips_bilinear += ImageEvaluator.lpips(ups, hr).item()

  0%|          | 0/238 [00:00<?, ?it/s]

100%|██████████| 238/238 [14:30<00:00,  3.66s/it]


In [21]:
# Report every accumulated total as a mean over the test set, one group per
# heading: model image output, model wavelet output, bilinear baseline.
num_samples = len(test_dataset)

metric_groups = (
    ("Metrics for the model", (
        ("SSIM", total_ssim),
        ("MSE", total_mse),
        ("L1", total_l1),
        ("PSNR", total_psnr),
        ("LPIPS", total_lpips),
    )),
    ("Metrics for the wavelet model", (
        ("SSIM", total_wavelet_ssim),
        ("MSE", total_wavelet_mse),
        ("L1", total_wavelet_l1),
        ("PSNR", total_wavelet_psnr),
        ("LPIPS", total_wavelet_lpips),
    )),
    ("Metrics for the bilinear model", (
        ("SSIM", total_ssim_bilinear),
        ("MSE", total_mse_bilinear),
        ("L1", total_l1_bilinear),
        ("PSNR", total_psnr_bilinear),
        ("LPIPS", total_lpips_bilinear),
    )),
)

for heading, totals in metric_groups:
    print(heading)
    for label, total in totals:
        print(f"{label}: {total / num_samples}")


Metrics for the model
SSIM: 0.9375269112466764
MSE: 0.0005352461289055841
L1: 0.013167775729123284
PSNR: 33.24649676956049
LPIPS: 0.19275168110342586
Metrics for the wavelet model
SSIM: 0.9348329733900663
MSE: 0.0006377929279276467
L1: 0.00947199624517009
PSNR: 39.104559096969474
LPIPS: 0.13105929968487315
Metrics for the bilinear model
SSIM: 0.9360964976939834
MSE: 0.00046700571506296895
L1: 0.010743548293371036
PSNR: 34.860456667026554
LPIPS: 0.18343993548710807
