In [21]:
import os

# Show the directory the kernel started in, so the effect of the chdir below is visible.
print(os.getcwd())
# Enable OpenEXR support in OpenCV so that .exr files can be loaded.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# The rest of the notebook uses paths relative to the WDSS project root.
# Compare the final path component rather than the last four characters, so that
# a directory such as ".../MyWDSS" is not mistaken for the project root.
if os.path.basename(os.getcwd()) != 'WDSS':
    os.chdir('..')
print(os.getcwd())

d:\minorProject\WDSS
d:\minorProject\WDSS


In [22]:
import torch
import torch.nn as nn

from network.modules import *

from utils import *

from datetime import datetime

from typing import List, Dict, Any

from config import device, Settings

In [23]:
# Path of the JSON configuration file and the entry to load from it.
# NOTE(review): the second argument is presumably the model/experiment name — confirm against Settings.
CONFIG_PATH = "config/config.json"
CONFIG_NAME = "WDSSV5"
settings = Settings(CONFIG_PATH, CONFIG_NAME)

In [24]:
from network.dataset import *

# Build the dataset splits from `settings`, then unpack them into train / validation / test.
dataset_splits = WDSSDatasetCompressed.get_datasets(settings)
train_dataset, val_dataset, test_dataset = dataset_splits

In [25]:
# Fetch the model section of the settings once and read every entry from it.
model_cfg = settings.model_config

# Boolean switches selecting which branches of the network are active.
sum_lr_wavelet: bool = model_cfg["sum_lr_wavelet"]
has_feature_fusion: bool = model_cfg["has_feature_fusion"]
has_fminr: bool = model_cfg["has_fminr"]

# Per-submodule configuration dictionaries, consumed when the modules are built.
lr_feat_extractor_config: Dict[str, Any] = model_cfg["lr_feat_extractor"]
temporal_feat_extractor_config: Dict[str, Any] = model_cfg["temporal_feat_extractor"]
hr_gb_feat_extractor_config: Dict[str, Any] = model_cfg["hr_gb_feat_extractor"]
feature_fusion_config: Dict[str, Any] = model_cfg["feature_fusion"]
fminr_config: Dict[str, Any] = model_cfg["fminr"]

In [26]:
# Build the three feature extractors from their config sections and move them to `device`.
lr_frame_feat_extractor = BaseLRFeatExtractor.from_config(lr_feat_extractor_config).to(device)
temporal_feat_extractor = BaseTemporalFeatExtractor.from_config(temporal_feat_extractor_config).to(device)
hr_gb_feat_extractor = BaseGBFeatExtractor.from_config(hr_gb_feat_extractor_config).to(device)

# Optional branches. Both names are always bound (to None when the branch is
# disabled) so that an object left over from an earlier run with different
# flags cannot be picked up silently from the kernel namespace.
fminr = get_fminr(fminr_config).to(device) if has_fminr else None
fusion = BaseFeatureFusion.from_config(feature_fusion_config).to(device) if has_feature_fusion else None

# Final 3x3 convolution mapping 12 -> 12 channels; stride 1 with padding 1
# leaves the spatial size unchanged.
final_conv = nn.Sequential(
    nn.Conv2d(12, 12, kernel_size=3, padding=1, stride=1)
).to(device)

In [27]:
# Index of the test frame to run through the network.
image_no = 0
raw_frame = test_dataset.get_inference_frame(image_no)


def _as_batch(group):
    """Pick one frame group out of `raw_frame`, prepend an axis of size 1 and move it to `device`."""
    return raw_frame[group.value].unsqueeze(0).to(device)


# Low-resolution input frame.
lr_frame = _as_batch(FrameGroup.LR)
print(lr_frame.shape)

# High-resolution G-buffer.
hr_gbuffer = _as_batch(FrameGroup.GB)
print(hr_gbuffer.shape)

# Temporal input.
temporal = _as_batch(FrameGroup.TEMPORAL)
print(temporal.shape)

# Factor passed to the upsampling steps further down.
upscale_factor = 2

torch.Size([1, 3, 360, 640])
torch.Size([1, 12, 720, 1280])
torch.Size([1, 8, 720, 1280])


In [28]:
# Pixel unshuffle: fold every 2x2 spatial neighbourhood of the LR frame into the channel axis.
# Accessed through `torch` explicitly: `F` is not imported by name in the visible
# cells and would otherwise have to leak in through one of the wildcard imports.
lr_frame_ps = torch.nn.functional.pixel_unshuffle(lr_frame, 2)


In [29]:
# Pixel unshuffle of the HR G-buffer by a factor of 2 (space-to-channel rearrangement).
# Accessed through `torch` explicitly: `F` is not imported by name in the visible
# cells and would otherwise have to leak in through one of the wildcard imports.
hr_gbuffer_ps = torch.nn.functional.pixel_unshuffle(hr_gbuffer, 2)

In [30]:
# Pixel unshuffle of the temporal input by a factor of 2 (space-to-channel rearrangement).
# Accessed through `torch` explicitly: `F` is not imported by name in the visible
# cells and would otherwise have to leak in through one of the wildcard imports.
temporal_ps = torch.nn.functional.pixel_unshuffle(temporal, 2)


In [31]:
# Extract features
# Run the pixel-unshuffled LR frame through the LR feature extractor.
lr_frame_feat = lr_frame_feat_extractor(lr_frame_ps)

In [32]:
# Run the pixel-unshuffled temporal input through the temporal feature extractor,
# then print the module so its layer structure appears in the cell output.
temporal_feat = temporal_feat_extractor(temporal_ps)
print(temporal_feat_extractor)

TemporalFeatExtractor(
  (net): Sequential(
    (0): LightWeightGatedConv2D(
      (feature): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (gate): Sequential(
        (0): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Sigmoid()
      )
    )
    (1): LightWeightGatedConv2D(
      (feature): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (gate): Sequential(
        (0): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Sigmoid()
      )
    )
    (2): LightWeightGatedConv2D(
      (feature): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (gate): Sequential(
        (0): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Sigmoid()
      )
    )
  )
)


In [33]:

# Run the pixel-unshuffled HR G-buffer through the G-buffer feature extractor,
# then print the module so its layer structure appears in the cell output.
hr_gb_feat = hr_gb_feat_extractor(hr_gbuffer_ps)
print(hr_gb_feat_extractor)

GBFeatureExtractor(
  (net): Sequential(
    (0): Conv2d(48, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): ResBlock(
      (expand_conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (fea_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (reduce_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): ReLU()
    (4): ResBlock(
      (expand_conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (fea_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (reduce_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (5): ReLU()
    (6): ResBlock(
      (expand_conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (fea_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (reduce_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (7): ReLU()
    (8): ResBlock(
      (expand_conv): Conv2d(64, 128,

In [34]:
# Route the extracted features to the active branches.
if not has_feature_fusion:
    # No fusion branch: all features are handed to the INR inputs.
    lr_inr, gb_inr = lr_frame_feat, hr_gb_feat
elif not has_fminr:
    # Fusion branch only: all features are handed to the fusion inputs.
    lr_ff, gb_ff = lr_frame_feat, hr_gb_feat
else:
    # Both branches: split along the channel axis in pieces of half the channel
    # count; the first piece feeds fusion and the second feeds the INR.
    lr_half = lr_frame_feat.shape[1] // 2
    gb_half = hr_gb_feat.shape[1] // 2
    lr_ff, lr_inr = torch.split(lr_frame_feat, lr_half, dim=1)
    gb_ff, gb_inr = torch.split(hr_gb_feat, gb_half, dim=1)


In [35]:
# Accumulator for the predicted wavelet output; starts empty and is filled in
# by whichever of the following branch cells are enabled.
wavelet_out: torch.Tensor | None = None


In [36]:
if has_fminr:
    # Call the module itself instead of `.forward()` so that the regular
    # nn.Module call path (including any registered hooks) is used.
    # NOTE(review): assumes `fminr` is an nn.Module, as its `.to(device)` call suggests.
    wavelet_out = fminr(lr_inr, gb_inr, upscale_factor)



In [37]:
if has_feature_fusion:
    # Upsample the LR fusion features by `upscale_factor`, then concatenate them
    # with the G-buffer and temporal features along the channel axis.
    lr_ff_upsampled = ImageUtils.upsample(lr_ff, upscale_factor)
    fusion_input = torch.cat([lr_ff_upsampled, gb_ff, temporal_feat], dim=1)
    ff_out = fusion(fusion_input)
    # Add to the FMINR result when that branch ran; otherwise the fusion output
    # is the wavelet prediction on its own.
    if has_fminr:
        wavelet_out = wavelet_out + ff_out
    else:
        wavelet_out = ff_out



In [38]:
if sum_lr_wavelet:
    # Transform the LR frame with WaveletProcessor.batch_wt, upsample the result
    # by `upscale_factor`, and add it to the prediction (out of place).
    lr_wavelet = WaveletProcessor.batch_wt(lr_frame)
    lr_wavelet_ups = ImageUtils.upsample(lr_wavelet, upscale_factor)
    wavelet_out = torch.add(wavelet_out, lr_wavelet_ups)


In [39]:

# Apply the final convolution to the accumulated wavelet prediction.
# Note: this overwrites `wavelet_out`, so re-running this cell on its own
# applies `final_conv` a second time.
wavelet_out = final_conv(wavelet_out)



In [40]:
# Convert the wavelet prediction to the output image via WaveletProcessor.batch_iwt
# (presumably the inverse of the batch_wt transform — confirm against WaveletProcessor).
image = WaveletProcessor.batch_iwt(wavelet_out)