# Diffusion Model for Video Super Resolution

Inspiration gathered from:

https://github.com/CompVis/latent-diffusion

https://ar5iv.labs.arxiv.org/html/2311.15908

In [1]:
import torch as torch
import torchvision 
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

# Select the CUDA device when torch reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [2]:
# Display the installed torchvision version so the environment used for this run is recorded.
torchvision.__version__

'0.17.2+cu121'

In [4]:
# Datasets ending in _sharp are the correct/ground-truth images.
# Datasets ending in _blur_bicubic have been blurred and downsampled
# using bicubic interpolation.
# Any archive not already present under REDS/ is downloaded here
# (downloading all of them takes a while).
datasets = ['train_sharp', 'train_blur_bicubic', 'val_sharp', 'val_blur_bicubic']
for dataset_name in datasets:  # not named `set`: that would shadow the builtin for the whole session
    print(dataset_name)
    if not os.path.isfile(f"REDS/{dataset_name}.zip"):
        cmdlet = f"python download_REDS.py --{dataset_name}"
        print(cmdlet)
        exit_status = os.system(cmdlet)
        # Surface a failed download instead of silently continuing.
        if exit_status != 0:
            print(f"WARNING: '{cmdlet}' exited with status {exit_status}")
        

train_sharp
train_blur_bicubic
val_sharp
val_blur_bicubic


In [5]:
# Set up data into dataset and dataloader
# It assumes the project file structure as downloaded from above
# Built based on docs: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class REDS(Dataset):
    """Map-style dataset over the REDS videos: one sample per video folder.

    Each sample is a window of 5 consecutive low-resolution (LR) frames taken
    at a random position in the video, together with the high-resolution (HR)
    frame that corresponds to the middle LR frame.

    Args:
        train: if True read from the ``train_*`` folders, otherwise from the
            ``val_*`` folders. The paths assume the directory layout produced
            by the download step of this project.
    """

    def __init__(self, train=True):
        self.type = 'train' if train else 'test'
        if self.type == 'train':
            self.hr_dir = "REDS/train_sharp/train/train_sharp"
            self.lr_dir = "REDS/train_blur_bicubic/train/train_blur_bicubic/X4"
        else:
            self.hr_dir = "REDS/val_sharp/val/val_sharp"
            self.lr_dir = "REDS/val_blur_bicubic/val/val_blur_bicubic/X4"

    def __len__(self):
        """Return the number of entries in the HR directory (one per video)."""
        return len(os.listdir(self.hr_dir))  # training size = 240 videos, testing size = 30 videos

    def __getitem__(self, idx):
        """Return ``(lr_frames, hr_frame)`` for the video numbered ``idx``.

        Args:
            idx: integer video index; video folders are named with the index
                zero-padded to 3 digits (``000``, ``001``, ...).

        Returns:
            lr_frames: list of 5 images read from sequential LR frames.
            hr_frame: the HR image with the same frame id as ``lr_frames[2]``.

        Raises:
            IndexError: if there is no HR video folder for ``idx``.
            ValueError: if the video holds fewer than 5 frames.
        """
        idx = int(idx)
        # Zero-pad with a format spec so every index (including idx >= 100) maps to a folder name.
        video = f"{idx:03d}"
        hr_video_dir = f"{self.hr_dir}/{video}"
        if not os.path.isdir(hr_video_dir):
            raise IndexError(f"REDS index {idx} out of range: no folder {hr_video_dir}")

        # Count the frames of the selected video rather than assuming video 000 is representative.
        num_video_frames = len(os.listdir(hr_video_dir))
        if num_video_frames < 5:
            raise ValueError(
                f"video {video} has {num_video_frames} frames; at least 5 are required"
            )

        # Pick the centre frame so that the +/-2 neighbours stay inside the video
        # (np.random.randint excludes the upper bound).
        rand_frame_id = np.random.randint(2, num_video_frames - 2)
        # Frame files are named with the frame id zero-padded to 8 digits.
        lr_frame_idx = [f"{rand_frame_id + offset:08d}" for offset in range(-2, 3)]

        # Actually reading in the images; the HR target is the middle frame of the window.
        hr_frame = torchvision.io.read_image(f"{hr_video_dir}/{lr_frame_idx[2]}.png")
        lr_frames = [
            torchvision.io.read_image(f"{self.lr_dir}/{video}/{frame_name}.png")
            for frame_name in lr_frame_idx
        ]

        return lr_frames, hr_frame

In [6]:
# Because REDS implements the PyTorch Dataset interface, DataLoader can wrap it
# directly and provide iterable, batched, shuffled access for training/testing.
loader_options = dict(batch_size=10, shuffle=True)  # same settings for both splits

train_dataset, test_dataset = REDS(train=True), REDS(train=False)
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

In [7]:
# Show an example of getting data from either of the datasets.
# lr_imgs holds the blurred/downsampled low-res frames; hr_img is the
# high-res image matching the middle low-res frame.
lr_imgs, hr_img = train_dataset[25]  # subscription instead of calling __getitem__ directly
hr_img

tensor([[[154, 157, 159,  ...,  68,  68,  68],
         [158, 159, 160,  ...,  68,  68,  68],
         [162, 162, 162,  ...,  68,  68,  68],
         ...,
         [164, 164, 164,  ..., 134, 134, 134],
         [164, 164, 164,  ..., 134, 134, 134],
         [164, 164, 164,  ..., 134, 134, 134]],

        [[161, 162, 164,  ...,  59,  59,  59],
         [163, 164, 164,  ...,  59,  59,  59],
         [165, 165, 165,  ...,  59,  59,  59],
         ...,
         [ 61,  61,  61,  ..., 125, 125, 125],
         [ 61,  61,  61,  ..., 125, 125, 125],
         [ 61,  61,  61,  ..., 125, 125, 125]],

        [[162, 163, 165,  ...,  49,  49,  49],
         [164, 165, 166,  ...,  49,  49,  49],
         [167, 167, 167,  ...,  49,  49,  49],
         ...,
         [ 47,  47,  47,  ..., 108, 108, 108],
         [ 47,  47,  47,  ..., 108, 108, 108],
         [ 47,  47,  47,  ..., 108, 108, 108]]], dtype=torch.uint8)

In [2]:
# Draw the HR frame on an explicit Axes. permute(1, 2, 0) reorders the tensor's
# dimensions so the first one (size 3 for hr_img) becomes the last before plotting.
fig, ax = plt.subplots()
ax.imshow(hr_img.permute(1, 2, 0))
plt.show()

NameError: name 'plt' is not defined

In [1]:
# Set up loss functions

# PSNR (nothing is defined for it in this cell yet)

# perceptual_loss
# Load VGG19 with the IMAGENET1K_V1 weights and keep only its `.features`
# sub-module, so the rest of the classification network is not used.
# eval() puts the module in evaluation mode; freezing the weights is done
# separately by requires_grad_(False), which disables gradients for every parameter.
vgg = (
    torchvision.models.vgg19(weights='VGG19_Weights.IMAGENET1K_V1')
    .features
    .eval()
    .requires_grad_(False)
)

NameError: name 'torchvision' is not defined

In [19]:
# Sanity check: pass the HR frame through the VGG feature extractor and sum the resulting activations.
# NOTE(review): hr_img is only cast to float here, with no rescaling or normalization
# and no batch dimension added — confirm this is the input format the pretrained weights expect.
vgg(hr_img.float()).sum()

tensor(1077611.7500)

In [11]:
# Dimensions of the HR frame tensor.
hr_img.shape

torch.Size([3, 720, 1280])