In [1]:
# Mount Google Drive so the project files are accessible from Colab
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [26]:
!pwd

/content/drive/MyDrive/6.8300 Final Project/cvproject


In [2]:
# I then manage path to make path-dependent commands simpler
import os
path = "/content/drive/MyDrive/6.8300 Final Project/cvproject/"


os.chdir(path)
os.listdir(path)

['val',
 'Copy of hello.avi',
 '.ipynb_checkpoints',
 'network_swinir.py',
 'train_sharp.zip',
 'train_sharp_bicubic.zip',
 'uc?id=1a4PrjqT-hShvY9IyJm3sPF0ZaXyrCozR',
 'util_calculate_psnr_ssim.py',
 'val_sharp.zip',
 'val_sharp_bicubic.zip',
 'video_dataset.py',
 'vimeo_super_resolution_test.zip',
 'train',
 'train_blurry',
 'val_blurry',
 '__pycache__',
 'SwinIR',
 'hello.avi',
 'imageSRModel.pth',
 'vanilla_sr_new.avi']

In [3]:
!pip install Pillow
!pip install -U image
!pip install opencv-python
!pip install tqdm
!pip install torch
!pip install torchvision

from tqdm import tqdm
from io import BytesIO

import cv2
import numpy as np
import PIL.Image
from IPython.display import Image, clear_output, display

# PyTorch will be our main tool for playing with neural networks
import torch
import torch.hub
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms

# For reproducibility
torch.manual_seed(1234)

# CPU / GPU
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting image
  Downloading image-1.5.33.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting django (from image)
  Downloading Django-4.2.1-py3-none-any.whl (8.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.0/8.0 MB[0m [31m80.0 MB/s[0m eta [36m0:00:00[0m
Collecting asgiref<4,>=3.6.0 (from django->image)
  Downloading asgiref-3.6.0-py3-none-any.whl (23 kB)
Building wheels for collected packages: image
  Building wheel for image (setup.py) ... [?25l[?25hdone
  Created wheel for image: filename=image-1.5.33-py2.py3-none-any.whl size=19483 sha256=d42bef6baa104270d80da777c2351ef9c3fd3c0bab17bce418834eeccf8e3e19
  Stored in directory: /root/.cache/pip/wheels/70/0c/a4/7cfa53a5c6225c2db2bfec08e782b43d0f25fdae2e995b69be
Successfully built 

# Loss Functions and Metrics

In [4]:
device

device(type='cuda', index=0)

# Datasets

In [5]:
# we use a library by RaivoKoot to make the video dataset easier
# !git clone https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch.git

# and move video_dataset.py to the main project folder to avoid hassle with import in python!

# now let's generate annotation.txt that is necessary for the library to function

def format_seq_num(number, total_digits):
    """
    Renders `number` as text, left-padded with "0" up to `total_digits` characters.
    Text that is already at least `total_digits` long is returned without padding.
    """
    return str(number).rjust(total_digits, "0")

# Number of clips in each split and the split names, kept in matching order.
num_samples = [240, 30]
name_folder = ["train", "val"]

# for sharp/GT videos
# One list of rows per split. Every row is "<3-digit clip folder> 0 99 <clip index>";
# presumably "<folder> <first frame> <last frame> <label>" as expected by
# video_dataset.py -- TODO confirm against the library.
annotation_content = [
    [format_seq_num(i, 3) + " 0 99 " + str(i) + "\n" for i in range(num_sample)]
    for num_sample, name in zip(num_samples, name_folder)
]

# The sharp and the bicubic-downsampled copy of a split receive identical rows.
annotation_targets = [
    ("train/annotation.txt", 0),
    ("train_blurry/train_sharp_bicubic/annotation.txt", 0),
    ("val/annotation.txt", 1),
    ("val_blurry/val_sharp_bicubic/annotation.txt", 1),
]
for annotation_path, split_index in annotation_targets:
    with open(annotation_path, "w") as annotation:
        annotation.writelines(annotation_content[split_index])

In [6]:
# import first before creating dataset
from video_dataset import VideoFrameDataset, ImglistToTensor
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import os

import tqdm
import matplotlib
from matplotlib import animation, rc

import cv2

%matplotlib inline

In [7]:
# Parameters
NUM_FRAMES = 100

In [8]:
# Every split directory holds an annotation.txt next to the folder with the frames.
# Sharp (ground-truth) training clips.
sharp_train_annotation_file = os.path.join(path, "train", "annotation.txt")
sharp_train_root = os.path.join(path, "train", "train_sharp")

# Bicubic-downsampled training clips (the X4 folder).
blurry_train_annotation_file = os.path.join(path, "train_blurry", "train_sharp_bicubic", "annotation.txt")
blurry_train_root = os.path.join(path, "train_blurry", "train_sharp_bicubic", "X4")

# Sharp (ground-truth) validation clips.
sharp_val_annotation_file = os.path.join(path, "val", "annotation.txt")
sharp_val_root = os.path.join(path, "val", "val_sharp")

# Bicubic-downsampled validation clips (the X4 folder).
blurry_val_annotation_file = os.path.join(path, "val_blurry", "val_sharp_bicubic", "annotation.txt")
blurry_val_root = os.path.join(path, "val_blurry", "val_sharp_bicubic", "X4")

In [9]:
# Transform handed to every VideoFrameDataset below.
# The mean/std values are the ones `denormalize` inverts, so the two must be kept in sync.
preprocess = transforms.Compose([
    ImglistToTensor(),  # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # per-channel (x - mean) / std
])

def denormalize(video_tensor):
    """
    Undoes mean/standard deviation normalization, zero to one scaling,
    and channel rearrangement for a batch of images.

    Values are clamped to [0, 1] and rounded before the uint8 cast. A plain
    cast truncates (so 127.9999 becomes 127) and wraps anything that lands
    slightly outside the valid range. The tensor is detached and moved to the
    CPU first, so tensors on the GPU or tensors tracking gradients are accepted.
    args:
        video_tensor: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor with 3 channels
    returns:
        a (FRAMES x HEIGHT x WIDTH x CHANNELS) uint8 numpy array with values in [0, 255]
    """
    # Same constants as the Normalize step in `preprocess`, shaped to broadcast over (C, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(-1, 1, 1)

    frames = video_tensor.detach().cpu().float()
    frames = frames * std + mean  # inverse of (x - mean) / std
    frames = (frames.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
    return frames.permute(0, 2, 3, 1).numpy()

In [10]:
# might want to lower the resolution for better speed

# Sharp (ground-truth) training clips: NUM_FRAMES segments of one frame each per clip.
sharp_train_dataset = VideoFrameDataset(
    root_path = sharp_train_root,
    annotationfile_path=sharp_train_annotation_file,
    num_segments=NUM_FRAMES,
    frames_per_segment=1,
    imagefile_template="{:08d}.png",
    transform=preprocess,
    test_mode=False
)

# Quick look at one item. `frames` and `label` are rebound by every dataset cell.
sample = sharp_train_dataset[3]
frames = sample[0]  # clip after `preprocess`: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor, not a list of PIL images
label = sample[1]  # integer label

In [11]:
# might want to lower the resolution for better speed

# Low-resolution (bicubic X4) training clips: NUM_FRAMES segments of one frame each per clip.
blurry_train_dataset = VideoFrameDataset(
    root_path = blurry_train_root,
    annotationfile_path=blurry_train_annotation_file,
    num_segments=NUM_FRAMES,
    frames_per_segment=1,
    imagefile_template="{:08d}.png",
    transform=preprocess,
    test_mode=False
)

# Quick look at one item. `frames` and `label` are rebound by every dataset cell.
sample = blurry_train_dataset[3]
frames = sample[0]  # clip after `preprocess`: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor, not a list of PIL images
label = sample[1]  # integer label

In [12]:
# might want to lower the resolution for better speed

# Sharp (ground-truth) validation clips: NUM_FRAMES segments of one frame each per clip.
sharp_val_dataset = VideoFrameDataset(
    root_path = sharp_val_root,
    annotationfile_path=sharp_val_annotation_file,
    num_segments=NUM_FRAMES,
    frames_per_segment=1,
    imagefile_template="{:08d}.png",
    transform=preprocess,
    test_mode=False
)

# Quick look at one item. `frames` and `label` are rebound by every dataset cell.
sample = sharp_val_dataset[3]
frames = sample[0]  # clip after `preprocess`: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor, not a list of PIL images
label = sample[1]  # integer label

In [13]:
# might want to lower the resolution for better speed

# Low-resolution (bicubic X4) validation clips: NUM_FRAMES segments of one frame each per clip.
blurry_val_dataset = VideoFrameDataset(
    root_path = blurry_val_root,
    annotationfile_path=blurry_val_annotation_file,
    num_segments=NUM_FRAMES,
    frames_per_segment=1,
    imagefile_template="{:08d}.png",
    transform=preprocess,
    test_mode=False
)

# Quick look at one item. The `frames` bound here is the one the later
# "testing the input" cell feeds to TrNet when the cells run top to bottom.
sample = blurry_val_dataset[3]
frames = sample[0]  # clip after `preprocess`: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor, not a list of PIL images
label = sample[1]  # integer label

In [14]:
# convert video from BGR to RGB

def convertVideoBGRtoRGB(video):
    """
    Reverses the order of the three entries along the fourth axis of `video`.
    args:
        video: a 4-D array, (FRAMES x HEIGHT x WIDTH x CHANNELS) in this notebook
    The same reversal turns BGR into RGB and RGB into BGR.
    """
    reversed_channels = [2, 1, 0]
    return video[:, :, :, reversed_channels]

def get_video_output(normalized_video):
    """
    Turns a normalized clip tensor into a channels-last array with reversed channel order.
    Runs `denormalize` first and then `convertVideoBGRtoRGB` on its result.
    """
    denormalized_frames = denormalize(normalized_video)
    return convertVideoBGRtoRGB(denormalized_frames)

In [24]:
# frames = sample[0]
# # frames = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in frames]
# # frames = np.array(frames)
# frames = denormalize(frames)
# frames = convertVideoBGRtoRGB(frames)

# print(frames.shape)


# # test_video = AnimationWrapper(rows=1, cols=1, frames=frames)
# # test_video.generate()
# # test_video.anim

# _, height, width, layers = frames.shape

# video = cv2.VideoWriter("hello.avi", 0, 24, (width,height))

# for image in frames:
#     video.write(image)
# video.release()
# cv2.destroyAllWindows()

(100, 720, 1280, 3)


# SwinIR: Image Superresolution Model

In [25]:
!git clone https://github.com/JingyunLiang/SwinIR.git

fatal: destination path 'SwinIR' already exists and is not an empty directory.


In [15]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.14.1 timm-0.6.13


In [16]:
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests

from network_swinir import SwinIR as net
import util_calculate_psnr_ssim as util

In [17]:
# IMAGE SR model-specific parameter
SCALE_FACTOR = 4
MODEL_PATH = "model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth"
ARG_TASK = "lightweight_sr"
WINDOW_SIZE = 8
BORDER = SCALE_FACTOR
# python main_test_swinir.py --task lightweight_sr --scale 4 
# --model_path model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x4.pth
# --folder_lq testsets/Set5/LR_bicubic/X4 --folder_gt testsets/Set5/HR


In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# set up model
# The pretrained weights are fetched once and cached in the working directory;
# re-running the cell reuses the cached file instead of downloading it again.
url = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}'.format(os.path.basename(MODEL_PATH))
checkpoint_path = os.path.join(os.getcwd(), "imageSRModel.pth")
if not os.path.exists(checkpoint_path):
    print('downloading model')
    r = requests.get(url, allow_redirects=True, timeout=60)
    # Stop on HTTP errors rather than saving an error page as the checkpoint.
    r.raise_for_status()
    with open(checkpoint_path, 'wb') as checkpoint_file:
        checkpoint_file.write(r.content)

# window_size uses WINDOW_SIZE so the model and the padding code in the
# inference loops cannot drift apart.
model = net(upscale=SCALE_FACTOR, in_chans=3, img_size=64, window_size=WINDOW_SIZE,
                    img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
param_key_g = 'params'

# map_location makes the checkpoint loadable regardless of the device it was saved from.
pretrained_model = torch.load(checkpoint_path, map_location=device)
# The weights are either stored under the 'params' key or are the checkpoint itself.
model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

model.eval()
model = model.to(device)

downloading model


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [35]:
# Baseline: run the pretrained SwinIR `model` frame by frame on one training clip,
# keep the upscaled frames in `vanilla_sr_video`, and score each frame against
# the matching sharp frame.
test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []
test_results['psnr_y'] = []
test_results['ssim_y'] = []
test_results['psnrb'] = []
test_results['psnrb_y'] = []
psnr, ssim, psnr_y, ssim_y, psnrb, psnrb_y = 0, 0, 0, 0, 0, 0

vanilla_sr_video = []
lr_video = blurry_train_dataset[3][0] # clip at index 3 of the blurry training set
lr_video = get_video_output(lr_video)
# lr_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in lr_video]
gt_video = sharp_train_dataset[3][0] # clip at index 3 of the sharp training set
gt_video = get_video_output(gt_video)
# gt_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in gt_video]

# get_video_output returns uint8 frames; rescale them to floats in [0, 1]
lr_video = [frame / 255 for frame in lr_video]
gt_video = [frame / 255 for frame in gt_video]

for idx, (img_lq, img_gt) in enumerate(zip(lr_video, gt_video)):
    # read image
    # reverse the channel order again (undoing the swap done by get_video_output) and move channels first
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC-BGR to CHW-RGB
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

    # inference
    with torch.no_grad():
        # pad input image to be a multiple of window_size
        # (padding is built by mirroring the image; a full extra window is added when the size already divides evenly)
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old
        w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = model(img_lq)
        # drop the rows/columns that correspond to the padding
        output = output[..., :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR]

    # save image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
    
    # ----------------------------------------------------
    # appending to later make a video
    vanilla_sr_video.append(output)

    # ----------------------------------------------------

    # evaluate psnr/ssim/psnr_b
    if img_gt is not None:
        img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float in [0, 1] back to uint8
        img_gt = img_gt[:h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR, ...]  # crop gt
        img_gt = np.squeeze(img_gt)

        psnr = util.calculate_psnr(output, img_gt, crop_border=BORDER)
        ssim = util.calculate_ssim(output, img_gt, crop_border=BORDER)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        if img_gt.ndim == 3:  # RGB image
            psnr_y = util.calculate_psnr(output, img_gt, crop_border=BORDER, test_y_channel=True)
            ssim_y = util.calculate_ssim(output, img_gt, crop_border=BORDER, test_y_channel=True)
            test_results['psnr_y'].append(psnr_y)
            test_results['ssim_y'].append(ssim_y)
        # ARG_TASK is "lightweight_sr" in this notebook, so the PSNRB branches are skipped
        # and psnrb / psnrb_y keep their initial value of 0 in the printout
        if ARG_TASK in ['jpeg_car', 'color_jpeg_car']:
            psnrb = util.calculate_psnrb(output, img_gt, crop_border=BORDER, test_y_channel=False)
            test_results['psnrb'].append(psnrb)
            if ARG_TASK in ['color_jpeg_car']:
                psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=BORDER, test_y_channel=True)
                test_results['psnrb_y'].append(psnrb_y)
        print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'
                'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; PSNRB_Y: {:.2f} dB.'.
                format(idx, "imgname", psnr, ssim, psnrb, psnr_y, ssim_y, psnrb_y))
    else:
        print('Testing {:d} {:20s}'.format(idx, "imgname"))

# summarize psnr/ssim
# NOTE(review): `img_gt` here is the value left over from the last loop iteration
if img_gt is not None:
    ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
    ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
    print('\n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(ave_psnr, ave_ssim))
    if img_gt.ndim == 3:
        ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
        ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
        print('-- Average PSNR_Y/SSIM_Y: {:.2f} dB; {:.4f}'.format(ave_psnr_y, ave_ssim_y))
    if ARG_TASK in ['jpeg_car', 'color_jpeg_car']:
        ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
        print('-- Average PSNRB: {:.2f} dB'.format(ave_psnrb))
        if ARG_TASK in ['color_jpeg_car']:
            ave_psnrb_y = sum(test_results['psnrb_y']) / len(test_results['psnrb_y'])
            print('-- Average PSNRB_Y: {:.2f} dB'.format(ave_psnrb_y))

Testing 0 imgname              - PSNR: 27.66 dB; SSIM: 0.7275; PSNRB: 0.00 dB;PSNR_Y: 29.02 dB; SSIM_Y: 0.7539; PSNRB_Y: 0.00 dB.
Testing 1 imgname              - PSNR: 27.48 dB; SSIM: 0.7242; PSNRB: 0.00 dB;PSNR_Y: 28.83 dB; SSIM_Y: 0.7502; PSNRB_Y: 0.00 dB.
Testing 2 imgname              - PSNR: 27.48 dB; SSIM: 0.7216; PSNRB: 0.00 dB;PSNR_Y: 28.84 dB; SSIM_Y: 0.7476; PSNRB_Y: 0.00 dB.
Testing 3 imgname              - PSNR: 27.45 dB; SSIM: 0.7202; PSNRB: 0.00 dB;PSNR_Y: 28.81 dB; SSIM_Y: 0.7465; PSNRB_Y: 0.00 dB.
Testing 4 imgname              - PSNR: 27.57 dB; SSIM: 0.7252; PSNRB: 0.00 dB;PSNR_Y: 28.93 dB; SSIM_Y: 0.7513; PSNRB_Y: 0.00 dB.
Testing 5 imgname              - PSNR: 27.52 dB; SSIM: 0.7237; PSNRB: 0.00 dB;PSNR_Y: 28.87 dB; SSIM_Y: 0.7499; PSNRB_Y: 0.00 dB.
Testing 6 imgname              - PSNR: 27.56 dB; SSIM: 0.7319; PSNRB: 0.00 dB;PSNR_Y: 28.91 dB; SSIM_Y: 0.7575; PSNRB_Y: 0.00 dB.
Testing 7 imgname              - PSNR: 27.46 dB; SSIM: 0.7323; PSNRB: 0.00 dB;PSNR_Y: 28.8

In [36]:
# frames = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in frames]
# Stack the per-frame outputs into a single array so the frame size can be read off its shape.
vanilla_sr_video_new = np.array(vanilla_sr_video)

_, height, width, layers = vanilla_sr_video_new.shape
frame_size = (width, height)  # the writer takes (width, height), not (height, width)

# Positional arguments: output file, fourcc code 0, 24 frames per second, frame size.
video = cv2.VideoWriter("vanilla_sr_new.avi", 0, 24, frame_size)
for image in vanilla_sr_video_new:
    video.write(image)

video.release()
cv2.destroyAllWindows()

In [19]:
# let's actually make dataloaders
BATCH_SIZE = 2
NUM_WORKERS = 2

# The two loaders differ only in the dataset they read from.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Low-resolution training clips.
train_blurry_dataloader = torch.utils.data.DataLoader(dataset=blurry_train_dataset, **loader_options)

# Low-resolution validation clips.
val_blurry_dataloader = torch.utils.data.DataLoader(dataset=blurry_val_dataset, **loader_options)


In [20]:
len(train_blurry_dataloader)

120

# Model Structure

In [21]:
device

device(type='cuda')

In [43]:
# model for improved video superresolution
import math
class TrNet(nn.Module):
    """
    Small two-layer convolutional network that operates on frame-to-frame differences.
    args:
        kernel_size: side length of both convolution kernels
        filter1_size: number of channels produced by the first convolution
        filter2_size: accepted for compatibility; no layer uses it
    """

    def __init__(self, kernel_size=5, filter1_size=5, filter2_size=16):
        super().__init__()
        # Half the kernel size on each side keeps height and width unchanged for odd kernels.
        same_padding = kernel_size // 2
        self.conv1 = nn.Conv2d(3, filter1_size, kernel_size, padding=same_padding)
        self.conv2 = nn.Conv2d(filter1_size, 3, kernel_size, padding=same_padding)

    def forward(self, x):
        """
        args:
            x: a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor with 3 channels
        returns:
            a tensor with the same shape as x; the final ReLU makes every value non-negative
        """
        # Prepending the first frame keeps the frame count and makes the first difference all zeros.
        first_frame = x[0].unsqueeze(0)
        frame_deltas = torch.diff(x, dim=0, prepend=first_frame)

        hidden = F.relu(self.conv1(frame_deltas))
        return F.relu(self.conv2(hidden))

custom_model = TrNet().to(device)

In [44]:
# testing the input
# `frames` is the global left behind by whichever dataset cell ran last;
# this cell only checks that TrNet accepts a clip and returns a tensor of the same shape.
image1 = frames
print(image1.shape)
print()
image1 = image1.to(device)
custom_model(image1).shape

torch.Size([100, 3, 180, 320])



torch.Size([100, 3, 180, 320])

In [23]:
# Inference-only copy of the baseline evaluation cell above: it rebuilds
# `vanilla_sr_video` with the pretrained SwinIR `model` and computes no metrics.
vanilla_sr_video = []
lr_video = blurry_train_dataset[3][0] # clip at index 3 of the blurry training set
lr_video = get_video_output(lr_video)
# lr_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in lr_video]
gt_video = sharp_train_dataset[3][0] # clip at index 3 of the sharp training set
gt_video = get_video_output(gt_video)
# gt_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in gt_video]

# get_video_output returns uint8 frames; rescale them to floats in [0, 1]
lr_video = [frame / 255 for frame in lr_video]
gt_video = [frame / 255 for frame in gt_video]

# NOTE(review): `img_gt` is unpacked but not used anywhere in this loop
for idx, (img_lq, img_gt) in enumerate(zip(lr_video, gt_video)):
    # read image
    # reverse the channel order again (undoing the swap done by get_video_output) and move channels first
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC-BGR to CHW-RGB
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

    # inference
    with torch.no_grad():
        # pad input image to be a multiple of window_size
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old
        w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = model(img_lq)
        # drop the rows/columns that correspond to the padding
        output = output[..., :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR]

    # save image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
    
    # ----------------------------------------------------
    # appending to later make a video
    vanilla_sr_video.append(output)

torch.Size([100, 3, 180, 320])


In [46]:
image_sr_model = model

In [47]:
def train(dataloader, gt_data, model, optimizer, epoch, sr_model=None):
    """
    Trains the transition network for one epoch on top of a frozen image SR model.

    For every low-resolution clip the SR model upscales each frame on its own
    (no gradients). `model` predicts a per-frame correction from the
    low-resolution clip, and that correction is resized to the SR resolution
    and added to the SR frames. The sum is compared with the sharp clip using
    mean squared error.
    args:
        dataloader: yields (clips, labels); clips is a
            (BATCH x FRAMES x CHANNELS x HEIGHT x WIDTH) tensor normalized by `preprocess`
        gt_data: dataset of sharp clips; gt_data[label][0] is used as the ground truth
            of the clip carrying that label (the annotation files use the clip index as label)
        model: network being trained; maps a (FRAMES x CHANNELS x HEIGHT x WIDTH)
            tensor to a tensor of the same shape
        optimizer: optimizer over the parameters of `model`
        epoch: epoch number, only used in the progress bar text
        sr_model: frozen image SR network; defaults to the global `image_sr_model`
    returns:
        a list with one float per batch: the MSE averaged over the clips of that batch
    """
    # Imported locally under a private name: a later cell runs `import tqdm`, which
    # rebinds the global name `tqdm` to the module and makes `tqdm(...)` uncallable.
    from tqdm import tqdm as progress_bar

    if sr_model is None:
        sr_model = image_sr_model
    sr_model.eval()  # the SR model is only used for inference
    model.train()
    train_loss = []

    # Constants of the Normalize step in `preprocess`, shaped to broadcast over (FRAMES, C, H, W).
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    batches = progress_bar(enumerate(dataloader), total=len(dataloader))
    batches.set_description("Epoch NA: Loss (NA) Accuracy (NA %)")
    for batch_idx, (data, target) in batches:
        # Zero out gradients
        optimizer.zero_grad()
        num_clips = data.shape[0]
        batch_loss = 0.0

        # Clips are processed one at a time: TrNet-style models expect a single
        # (FRAMES x C x H x W) clip, and this keeps only one clip's graph in memory.
        for lr_clip, video_idx in zip(data, target):
            # Undo the dataset normalization so both networks see values in [0, 1].
            lr_clip = (lr_clip.to(device) * std + mean).clamp(0, 1)
            gt_clip = (gt_data[int(video_idx)][0].to(device) * std + mean).clamp(0, 1)

            sr_frames = []
            with torch.no_grad():
                for img_lq in lr_clip:
                    img_lq = img_lq.unsqueeze(0)  # CHW to NCHW
                    # pad input image to be a multiple of window_size
                    _, _, h_old, w_old = img_lq.size()
                    h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old
                    w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old
                    img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
                    img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
                    output = sr_model(img_lq)
                    # drop the rows/columns that correspond to the padding
                    output = output[..., :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR]
                    sr_frames.append(output.clamp(0, 1))
            video = torch.cat(sr_frames, dim=0)  # (FRAMES x C x H*scale x W*scale)

            # Compute forward pass, loss, and gradients
            transitions = model(lr_clip)
            # The correction is predicted at low resolution; resize it to match the SR frames.
            transitions = F.interpolate(
                transitions, size=video.shape[-2:], mode="bilinear", align_corners=False
            )
            gt_clip = gt_clip[..., :video.shape[-2], :video.shape[-1]]  # crop gt
            loss = F.mse_loss(video + transitions, gt_clip) / num_clips
            loss.backward()  # gradients accumulate over the clips of the batch
            batch_loss += loss.item()  # plain float, so no graph is kept alive

        # Update parameters
        optimizer.step()
        train_loss.append(batch_loss)

        batches.set_description(
            "Epoch {:d}: Loss ({:.2e})".format(
                epoch, batch_loss
            )
        )

    return train_loss

In [None]:
# Copy of the baseline evaluation cell above without the final averaging step:
# runs the pretrained SwinIR `model` frame by frame on one training clip and
# prints per-frame metrics.
test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []
test_results['psnr_y'] = []
test_results['ssim_y'] = []
test_results['psnrb'] = []
test_results['psnrb_y'] = []
psnr, ssim, psnr_y, ssim_y, psnrb, psnrb_y = 0, 0, 0, 0, 0, 0

vanilla_sr_video = []
lr_video = blurry_train_dataset[3][0] # clip at index 3 of the blurry training set
lr_video = get_video_output(lr_video)
# lr_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in lr_video]
gt_video = sharp_train_dataset[3][0] # clip at index 3 of the sharp training set
gt_video = get_video_output(gt_video)
# gt_video = [cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB) for frame in gt_video]

# get_video_output returns uint8 frames; rescale them to floats in [0, 1]
lr_video = [frame / 255 for frame in lr_video]
gt_video = [frame / 255 for frame in gt_video]

for idx, (img_lq, img_gt) in enumerate(zip(lr_video, gt_video)):
    # read image
    # reverse the channel order again (undoing the swap done by get_video_output) and move channels first
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC-BGR to CHW-RGB
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

    # inference
    with torch.no_grad():
        # pad input image to be a multiple of window_size
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old
        w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = model(img_lq)
        # drop the rows/columns that correspond to the padding
        output = output[..., :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR]

    # save image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
    
    # ----------------------------------------------------
    # appending to later make a video
    vanilla_sr_video.append(output)

    # ----------------------------------------------------

    # evaluate psnr/ssim/psnr_b
    if img_gt is not None:
        img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float in [0, 1] back to uint8
        img_gt = img_gt[:h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR, ...]  # crop gt
        img_gt = np.squeeze(img_gt)

        psnr = util.calculate_psnr(output, img_gt, crop_border=BORDER)
        ssim = util.calculate_ssim(output, img_gt, crop_border=BORDER)
        test_results['psnr'].append(psnr)
        test_results['ssim'].append(ssim)
        if img_gt.ndim == 3:  # RGB image
            psnr_y = util.calculate_psnr(output, img_gt, crop_border=BORDER, test_y_channel=True)
            ssim_y = util.calculate_ssim(output, img_gt, crop_border=BORDER, test_y_channel=True)
            test_results['psnr_y'].append(psnr_y)
            test_results['ssim_y'].append(ssim_y)
        # ARG_TASK is "lightweight_sr" in this notebook, so the PSNRB branches are skipped
        if ARG_TASK in ['jpeg_car', 'color_jpeg_car']:
            psnrb = util.calculate_psnrb(output, img_gt, crop_border=BORDER, test_y_channel=False)
            test_results['psnrb'].append(psnrb)
            if ARG_TASK in ['color_jpeg_car']:
                psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=BORDER, test_y_channel=True)
                test_results['psnrb_y'].append(psnrb_y)
        print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'
                'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; PSNRB_Y: {:.2f} dB.'.
                format(idx, "imgname", psnr, ssim, psnrb, psnr_y, ssim_y, psnrb_y))
    else:
        print('Testing {:d} {:20s}'.format(idx, "imgname"))

In [48]:

def evaluate(dataloader, gt_data, model, sr_model=None, verbose=False):
    """
    Scores the SR pipeline (frozen image SR model plus transition network) on a dataset.

    Every low-resolution frame is upscaled by the SR model, and the correction
    predicted by `model` for that frame is resized and added to it. The result
    is compared with the sharp frame using PSNR and SSIM.
    args:
        dataloader: yields (clips, labels); clips is a
            (BATCH x FRAMES x CHANNELS x HEIGHT x WIDTH) tensor normalized by `preprocess`
        gt_data: dataset of sharp clips; gt_data[label][0] is used as the ground truth
            of the clip carrying that label (the annotation files use the clip index as label)
        model: transition network mapping a (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
            to a tensor of the same shape, or None to score the SR model alone
        sr_model: frozen image SR network; defaults to the global `image_sr_model`
        verbose: when True, print the metrics of every frame
    returns:
        (average PSNR, average SSIM) over all evaluated frames
    raises:
        ValueError: if the dataloader yields no frames
    """
    if sr_model is None:
        sr_model = image_sr_model
    sr_model.eval()
    if model is not None:
        model.eval()

    # Constants of the Normalize step in `preprocess`, shaped to broadcast over (FRAMES, C, H, W).
    # They stay on the CPU so the large ground-truth clip never has to be moved to the GPU.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def to_uint8_bgr(frame):
        # (3 x H x W) RGB tensor in [0, 1] to (H x W x 3) uint8 array with reversed
        # channels, the layout the earlier evaluation cells hand to `util`.
        frame = (frame.clamp(0, 1) * 255.0).round().to(torch.uint8)
        return frame.permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]]

    psnr_values = []
    ssim_values = []
    with torch.no_grad():
        for video, video_idx in dataloader:
            # The loader yields batches; handle one clip at a time.
            for lr_clip, clip_label in zip(video, video_idx):
                # Undo the dataset normalization so the networks see values in [0, 1].
                lr_clip = (lr_clip * std + mean).clamp(0, 1).to(device)
                gt_clip = (gt_data[int(clip_label)][0] * std + mean).clamp(0, 1)

                # The transition network needs the whole clip because it works on frame differences.
                transitions = model(lr_clip) if model is not None else None

                for idx, (img_lq, img_gt) in enumerate(zip(lr_clip, gt_clip)):
                    img_lq = img_lq.unsqueeze(0)  # CHW to NCHW

                    # pad input image to be a multiple of window_size
                    _, _, h_old, w_old = img_lq.size()
                    h_pad = (h_old // WINDOW_SIZE + 1) * WINDOW_SIZE - h_old
                    w_pad = (w_old // WINDOW_SIZE + 1) * WINDOW_SIZE - w_old
                    img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
                    img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
                    output = sr_model(img_lq)
                    # drop the rows/columns that correspond to the padding
                    output = output[..., :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR]

                    if transitions is not None:
                        # Resize this frame's correction to the SR resolution before adding it.
                        output = output + F.interpolate(
                            transitions[idx:idx + 1], size=output.shape[-2:],
                            mode="bilinear", align_corners=False
                        )

                    output = to_uint8_bgr(output[0])
                    img_gt = to_uint8_bgr(img_gt[:, :h_old * SCALE_FACTOR, :w_old * SCALE_FACTOR])  # crop gt

                    psnr = util.calculate_psnr(output, img_gt, crop_border=BORDER)
                    ssim = util.calculate_ssim(output, img_gt, crop_border=BORDER)
                    psnr_values.append(psnr)
                    ssim_values.append(ssim)
                    if verbose:
                        print('Testing {:d} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(idx, psnr, ssim))

    if not psnr_values:
        raise ValueError("evaluate() received no frames from the dataloader")
    ave_psnr = sum(psnr_values) / len(psnr_values)
    ave_ssim = sum(ssim_values) / len(ssim_values)
    return ave_psnr, ave_ssim

            

        

In [None]:
num_epochs = 10
lr = 0.001

def create_optimizer(net, lr):
    """Builds an Adam optimizer over all parameters of `net` with learning rate `lr`."""
    trainable_params = net.parameters()
    return torch.optim.Adam(trainable_params, lr=lr)

optim = create_optimizer(custom_model, lr)
train_loader = train_blurry_dataloader
val_loader = val_blurry_dataloader

for epoch in range(num_epochs):
    # evaluate returns (average PSNR, average SSIM). Multiplying that tuple by 100
    # repeats it instead of scaling it, and the result cannot be formatted as a float.
    val_psnr, val_ssim = evaluate(val_loader, sharp_val_dataset, custom_model)
    print('Epoch: {}\tValidation PSNR: {:.2f} dB\tValidation SSIM: {:.4f}'.format(epoch, val_psnr, val_ssim))
    train(train_loader, sharp_train_dataset, custom_model, optim, epoch)