In [1]:
import os


def ensure_project_root(root_name: str = "WDSS") -> str:
    """Move one directory up unless the working directory is named ``root_name``.

    Only the final path component is compared. A plain suffix check on the
    path string would also accept look-alike folders such as ``NotWDSS``.

    Args:
        root_name: Expected name of the project root directory.

    Returns:
        The working directory after the (possible) change.
    """
    if os.path.basename(os.getcwd()) != root_name:
        os.chdir("..")
    return os.getcwd()


# Display current working directory
print(os.getcwd())
# To make sure opencv imports .exr files; set here, before cv2 is imported further down
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# If the current directory is not WDSS, then set it to one level up
print(ensure_project_root())

c:\Dev\MinorProject\WDSS\jupyter_notebooks
c:\Dev\MinorProject\WDSS


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

from typing import List, Tuple, Dict

from config import device, Settings
from commons import initialize

In [3]:
# Base folder (relative to the working directory) for rendered frames and videos;
# later cells write into its "res" and "hr" subfolders.
video_out_dir = "out/video"

In [4]:
# Build the run settings from config/config.json and hand them to the project's
# initialize(). The second argument looks like the model name (the recorded
# output reports "Model: WDSSV5") — confirm against the Settings class.
settings = Settings("config/config.json", "WDSSV5")
initialize(settings=settings)

Job: Navin_relu_mse, Model: WDSSV5, Device: cuda
Model path: out\Navin_relu_mse-WDSSV5\model
Log path: out\Navin_relu_mse-WDSSV5\logs


In [5]:
# NOTE(review): star import kept on purpose — ImageUtils, used by the rendering
# cell, is not imported explicitly anywhere in this notebook and presumably
# arrives through it. Confirm, then replace with explicit names.
from network.dataset import *

# Train / validation / test splits built from the run settings.
(
    train_dataset,
    val_dataset,
    test_dataset,
) = WDSSDatasetCompressed.get_datasets(settings)

In [6]:
from network.models.WDSS import get_wdss_model

# Build the network described by settings.model_config ...
model = get_wdss_model(settings.model_config)
# ... then place it on the configured device.
model = model.to(device)

In [7]:
from network.losses import CriterionSSIM_L1, CriterionSSIM_MSE

# Loss module, placed on the same device as the model.
criterion = CriterionSSIM_MSE()
criterion = criterion.to(device)

# Adam over every model parameter.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
)

# Step schedule: learning rate is multiplied by 0.5 every 20 scheduler steps.
scheduler = StepLR(
    optimizer,
    step_size=20,
    gamma=0.5,
)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Dev\MinorProject\WDSS\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [8]:
from network.trainer import Trainer

# Bundle settings, model, optimisation objects and the three dataset splits
# into the project's Trainer.
trainer = Trainer(
    settings,
    model,
    optimizer,
    scheduler,
    criterion,
    train_dataset,
    val_dataset,
    test_dataset,
)

In [9]:
# Restore the best saved checkpoint before rendering. Loading stays best-effort
# (the notebook continues without it), but the actual cause is reported so a
# corrupt or incompatible checkpoint is not mistaken for a missing one, and
# KeyboardInterrupt / SystemExit are no longer swallowed.
try:
    trainer.load_best_checkpoint()
except Exception as exc:
    print(f"No checkpoint found ({type(exc).__name__}: {exc})")
else:
    # Reported outside the try so a failure here is not misread as a load error.
    print(f"Checkpoint loaded epoch: {trainer.total_epochs}")

Checkpoint loaded epoch: 77


In [10]:
# Number of frames to render; the inference loop reads test_dataset[0] up to
# test_dataset[total_frames - 1].
# NOTE(review): hard-coded — presumably the test split holds at least this many
# items; confirm.
total_frames = 120
print(f"Total frames: {total_frames}")

Total frames: 120


In [None]:
from tqdm import tqdm


def save_frame(tensor, path):
    """Clamp an image tensor to [0, 1], convert it to 8-bit and write it to ``path``.

    Args:
        tensor: Image tensor accepted by ``ImageUtils.tensor_to_opencv_image``.
        path: Destination file; the extension selects the OpenCV encoder.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    frame = ImageUtils.tensor_to_opencv_image(tensor.detach().cpu().clamp(0, 1))
    frame = (frame * 255).astype(np.uint8)
    frame = frame[..., [2, 1, 0]]  # Swap red and blue channels
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(path, frame):
        raise OSError(f"Could not write frame to {path}")


# cv2.imwrite does not create missing folders, so make sure they exist first.
res_dir = f"{video_out_dir}/res"
hr_dir = f"{video_out_dir}/hr"
for out_dir in (res_dir, hr_dir):
    os.makedirs(out_dir, exist_ok=True)

trainer.model.eval()

for frame_no in tqdm(range(total_frames)):
    raw_frames = test_dataset[frame_no]
    # Add a leading batch dimension of size 1 to each network input.
    lr = raw_frames['LR'].to(device).unsqueeze(0)
    gb = raw_frames['GB'].to(device).unsqueeze(0)
    temp = raw_frames['TEMPORAL'].to(device).unsqueeze(0)
    # The reference frame is only written to disk, so it never needs to visit
    # the device.
    hr = raw_frames['HR'].unsqueeze(0)

    with torch.no_grad():
        # 2.0 is presumably the upscale factor — TODO confirm against the model.
        wavelet, image = trainer.model.forward(lr, gb, temp, 2.0)

    # Store the output images
    save_frame(image, f"{res_dir}/frame_{frame_no:04d}.jpg")
    save_frame(hr, f"{hr_dir}/frame_{frame_no:04d}.jpg")
    


In [12]:
# Load the output images and create a video
import cv2
import os


def frames_to_video(image_folder, video_name, fps=60, ext=".jpg"):
    """Encode the images in ``image_folder`` into an mp4 file, in file-name order.

    Args:
        image_folder: Directory containing the frame images.
        video_name: Path of the video file to create.
        fps: Frame rate passed to the video writer.
        ext: File-name suffix of the frames to include.

    Returns:
        The number of frames written.

    Raises:
        FileNotFoundError: If ``image_folder`` holds no file ending in ``ext``.
        OSError: If a frame cannot be decoded or the writer cannot be opened.
    """
    # os.listdir gives no ordering guarantee. The frames are saved with
    # zero-padded numbers (frame_0000.jpg, ...), so a plain sort restores
    # playback order.
    names = sorted(name for name in os.listdir(image_folder) if name.endswith(ext))
    if not names:
        raise FileNotFoundError(f"No {ext} frames found in {image_folder}")

    # The first frame fixes the video resolution.
    first = cv2.imread(os.path.join(image_folder, names[0]))
    if first is None:
        raise OSError(f"Could not read frame {names[0]}")
    height, width = first.shape[:2]

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
    try:
        if not video.isOpened():
            raise OSError(f"Could not open video writer for {video_name}")
        for name in names:
            frame = cv2.imread(os.path.join(image_folder, name))
            # cv2.imread returns None instead of raising on failure.
            if frame is None:
                raise OSError(f"Could not read frame {name}")
            video.write(frame)
    finally:
        # Always finalise the file, even when a frame fails to load.
        video.release()
    return len(names)


image_folder = f"{video_out_dir}/hr"
video_name = f"{video_out_dir}/hr_video.mp4"

num_frames_written = frames_to_video(image_folder, video_name)
