In [1]:
import os
# Report the directory the kernel was started in
print(os.getcwd())
# Ask OpenCV to enable its OpenEXR (.exr) support via the environment flag
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Step up one level unless the working-directory path already ends in 'WDSS'.
# This is a plain suffix check on the full path, not an exact directory-name
# match, so any directory whose name merely ends in 'WDSS' is accepted as-is.
if not os.getcwd().endswith('WDSS'):
    os.chdir(os.pardir)
# Report the directory the rest of the notebook will run from
print(os.getcwd())

c:\Dev\MinorProject\WDSS\jupyter_notebooks
c:\Dev\MinorProject\WDSS


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time

from typing import List, Tuple, Dict

from config import device, Settings
from commons import initialize

In [3]:
# Build the run settings from config/config.json. The path is relative, so this
# relies on the working directory chosen in the first cell.
# NOTE(review): the second argument presumably selects the model name (the
# recorded output prints "Model: WDSSV5") — confirm against config.Settings.
settings = Settings("config/config.json", "WDSSV5")
# Hand the settings to the project's initialize() helper.
# NOTE(review): the recorded output of this cell lists the job, device and the
# model/log paths — presumably printed by initialize(); confirm.
initialize(settings=settings)

Job: WithoutINRL1, Model: WDSSV5, Device: cuda
Model path: out\WithoutINRL1-WDSSV5\model
Log path: out\WithoutINRL1-WDSSV5\logs


In [4]:
from network.dataset import *

train_dataset, val_dataset, test_dataset = WDSSDatasetCompressed.get_datasets(settings)

In [5]:
from network.modules import *

class WDSSModelV5(nn.Module):
    """WDSS V5 network.

    Two feature extractors (one for the low-resolution frame, one for the
    high-resolution G-buffer) feed two branches, a feature-fusion branch and an
    INR branch. The branch outputs are added element-wise and passed through a
    final 3x3 convolution.

    All layer sizes are fixed in the constructor; the building blocks
    (LRFrameFeatureExtractor, HRGBufferFeatureExtractor, FeatureFusion,
    FourierMappedINR) are project modules defined outside this notebook.
    """

    def __init__(self):
        """Create the sub-modules (registration order is kept stable on purpose)."""
        super().__init__()

        # NOTE(review): positional arguments below are presumably
        # (in_channels, out_channels, hidden layout) — confirm in network.modules.
        self.lr_frame_feature_extractor = LRFrameFeatureExtractor(12, 64, [32, 48, 48])
        self.hr_gbuffer_feature_extractor = HRGBufferFeatureExtractor(44, 5, 64)
        self.feature_fusion = FeatureFusion(64, 12, [64, 48])
        self.inr = FourierMappedINR(
            lr_feat_c=32,
            gb_feat_c=32,
            out_channels=12,
            mlp_inp_channels=64,
            hidden_channels=[64, 64, 64],
        )

        # 12 -> 12 channels, 3x3 kernel, stride 1, padding 1 (spatial size kept)
        self.final_conv = nn.Conv2d(12, 12, 3, 1, 1)

    def forward(self, lr_frame: torch.Tensor, hr_gbuffer: torch.Tensor, temporal: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on one batch.

        Args:
            lr_frame: Low-resolution frame tensor. It is pixel-unshuffled by a
                factor of 2 before feature extraction.
            hr_gbuffer: High-resolution G-buffer tensor. It is pixel-unshuffled
                by a factor of 2 before feature extraction.
            temporal: Accepted as part of the call signature but not read by
                this method.

        Returns:
            A pair ``(out, image)`` where ``out`` is the output of the final
            convolution and ``image`` is ``WaveletProcessor.batch_iwt(out)``.
            NOTE(review): presumably wavelet coefficients and the reconstructed
            frame respectively — confirm against WaveletProcessor.
        """
        # Fold each 2x2 spatial neighbourhood into the channel dimension
        lr_unshuffled = F.pixel_unshuffle(lr_frame, 2)
        gb_unshuffled = F.pixel_unshuffle(hr_gbuffer, 2)

        # Per-input feature extraction
        lr_features = self.lr_frame_feature_extractor(lr_unshuffled)
        gb_features = self.hr_gbuffer_feature_extractor(gb_unshuffled)

        # Each feature map is cut along the channel axis into 32-channel pieces;
        # the first piece goes to the fusion branch, the second to the INR branch.
        # Unpacking into two names requires exactly two pieces.
        lr_for_fusion, lr_for_inr = lr_features.split(32, dim=1)
        gb_for_fusion, gb_for_inr = gb_features.split(32, dim=1)

        # Fusion branch: the LR part is upsampled by 2 before being concatenated
        # with the G-buffer part along the channel axis
        lr_for_fusion = ImageUtils.upsample(lr_for_fusion, 2)
        fusion_input = torch.cat([lr_for_fusion, gb_for_fusion], dim=1)
        fused = self.feature_fusion(fusion_input)

        # INR branch
        inr_out = self.inr(lr_for_inr, gb_for_inr)

        # Sum both branches, then apply the final convolution
        out = self.final_conv(fused + inr_out)

        return out, WaveletProcessor.batch_iwt(out)
    
# Instantiate the network and move it to the device imported from config
model = WDSSModelV5().to(device)

In [6]:
from network.losses import CriterionSSIM_L1

# Training objective, moved to the same device as the model.
# NOTE(review): the recorded output shows an LPIPS (VGG) network being set up
# while this cell runs — presumably inside CriterionSSIM_L1; confirm.
criterion = CriterionSSIM_L1().to(device=device)
# Adam over all model parameters, initial learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.5 every 5 scheduler steps.
# NOTE(review): if the trainer steps the scheduler once per epoch, the 100
# epochs requested at the end of the notebook would shrink the learning rate to
# about 1e-9 (0.001 * 0.5**20) — confirm this decay is intended.
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Dev\MinorProject\WDSS\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [None]:
from network.trainer import Trainer

trainer = Trainer(settings, model, optimizer, scheduler, criterion, train_dataset, val_dataset, test_dataset)

In [None]:
trainer.train(num_epochs=100)