# Video Super Resolution

## By Daniel Shkreli and Victor Reyes

### Import what we need

In [2]:
import torch
import torch.nn as nn
import numpy as np
import time
import lpips
import torchvision
from torchvision.datasets.video_utils import VideoClips
from torchvision import datasets, models, transforms
!pip install tqdm
from tqdm.notebook import tqdm
import os
import copy

# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda:0") if gpu_available else torch.device("cpu")
print("Using the GPU!" if gpu_available else "WARNING: Could not find GPU! Using CPU only")

Using the GPU!


### Import the dataset
We use videos downloaded from the National Geographic channel, stored locally under `yt-dls/` and split into `train/`, `val/`, and `test/` folders.

In [5]:
# Index a single training video into 10-frame clips, resampled to 10 fps,
# with consecutive clips starting 10 frames apart.
sample_video_filename = "yt-dls/train/0/Secrets of the Whales _ Official Trailer _ Disney+.mp4"
sample_video = VideoClips(
    [sample_video_filename],
    clip_length_in_frames=10,
    frames_between_clips=10,
    frame_rate=10,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  "follow-up version. Please use pts_unit 'sec'.")


In [95]:
# Shape walk-through for one frame: raw frame -> transformed batch -> latent -> decoded image.
# NOTE: `data_transforms`, `resnet` and `decoder` are defined in cells further down
# this notebook, so those cells must have been run before this one.
clip = sample_video.get_clip(0)[0]  # element 0 of the returned tuple is the frame tensor
frame_0 = clip[0]
print(frame_0.shape)

# Add a leading batch dimension so the encoder sees a batch of one image.
transformed_frame = data_transforms['train'](frame_0).unsqueeze(0)
print(transformed_frame.shape)

# Append two singleton spatial dims so the latent can feed the ConvTranspose2d decoder.
latent = resnet(transformed_frame)[..., None, None]
print(latent.shape)
print(decoder(latent).shape)

  "follow-up version. Please use pts_unit 'sec'.")


torch.Size([720, 1280, 3])
torch.Size([1, 3, 256, 256])
torch.Size([1, 1024, 1, 1])
torch.Size([1, 3, 511, 511])


In [None]:
# Quick sanity check: shape of the sample clip loaded above and the installed PyTorch version.
# (`clips` was never defined at notebook scope; `clip` is the tensor from the previous cell.)
print(clip.shape)
print(torch.__version__)

### Create the model
We start from a pretrained ResNet-18 as the encoder and pair it with a transposed-convolution decoder.
The design is adapted from https://github.com/hsinyilin19/ResNetVAE


In [1]:
def initialize_resnet_model(resume_from = None):
    """Build the encoder: a pretrained ResNet-18 whose classifier head is replaced
    by a small MLP that emits a 1024-dimensional latent vector.

    Args:
        resume_from: Optional path to a ``state_dict`` checkpoint saved with
            ``torch.save``. When given, the weights are loaded into the network
            after the new head has been attached, so the checkpoint must match
            this architecture.

    Returns:
        The ``torchvision`` ResNet-18 module with its ``fc`` layer replaced.
    """
    resnet = models.resnet18(pretrained=True)
    num_ftrs = resnet.fc.in_features
    print("num_ftrs:", num_ftrs)
    # TODO Freeze weights
    # Replace the 1000-way classifier with a projection to the latent space.
    resnet.fc = nn.Sequential(
        nn.Linear(num_ftrs, 1024),
        nn.BatchNorm1d(1024, momentum=0.01),
        nn.Linear(1024, 1024),
    )

    if resume_from is not None:
        print(f"Loading weights from {resume_from}")
        # Load into `resnet` (the original referenced an undefined `model`).
        # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies the values onto the module's own device.
        resnet.load_state_dict(torch.load(resume_from, map_location="cpu"))
    return resnet

def _deconv_block(in_channels, out_channels, kernel_size, stride, activation):
    """Return one decoder stage: ConvTranspose2d -> BatchNorm2d -> activation."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride, padding=(0, 0)),
        nn.BatchNorm2d(out_channels, momentum=0.01),
        activation,
    )


def initialize_decoder(stride=(2, 2)):
    """Build the transposed-convolution decoder that maps a latent to an RGB image.

    The decoder is a stack of seven ConvTranspose2d/BatchNorm2d stages. The first
    six use ReLU; the last one maps to 3 channels and ends in a Sigmoid.

    Args:
        stride: Stride shared by every transposed convolution. The original code
            read an undefined global named ``stride``; the default of ``(2, 2)``
            turns a ``(N, 1024, 1, 1)`` latent into a ``(N, 3, 511, 511)`` output.

    Returns:
        An ``nn.Sequential`` of seven ``nn.Sequential`` stages (same nesting and
        parameter names as the hand-written version, so state dicts stay compatible).
    """
    # (in_channels, out_channels, kernel_size) for each stage, in order.
    layer_specs = [
        (1024, 512, (7, 7)),
        (512, 256, (3, 3)),
        (256, 128, (3, 3)),
        (128, 64, (3, 3)),
        (64, 64, (3, 3)),
        (64, 32, (3, 3)),
        (32, 3, (3, 3)),
    ]
    last_index = len(layer_specs) - 1

    stages = []
    for index, (in_channels, out_channels, kernel_size) in enumerate(layer_specs):
        # Only the final stage squashes its output with a Sigmoid.
        activation = nn.Sigmoid() if index == last_index else nn.ReLU(inplace=True)
        stages.append(_deconv_block(in_channels, out_channels, kernel_size, stride, activation))
    return nn.Sequential(*stages)

    
def video_loader(filename):
    """Load one representative clip from the video at ``filename``.

    The video is indexed into 10-frame clips at 10 fps, with clips starting
    10 frames apart, and the frames of the middle clip are returned.
    """
    video_clips = VideoClips(
        [filename],
        clip_length_in_frames=10,
        frames_between_clips=10,
        frame_rate=10,
    )
    middle_index = video_clips.num_clips() // 2
    middle_clip = video_clips.get_clip(middle_index)
    # currently just return a clip: element 0 of the tuple is the frame tensor
    return middle_clip[0]
input_size = 256

# Per-channel mean/std used for normalisation (presumably the ImageNet statistics
# that torchvision's pretrained models expect; verify if the encoder changes).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _frame_to_chw(frame):
    """Reorder a single decoded video frame from (H, W, C) to (C, H, W).

    Frames read from the videos are channel-last (e.g. ``[720, 1280, 3]``), but
    ``ToPILImage`` interprets tensors as channel-first, so without this step the
    pixel data is misread.
    """
    return frame.permute(2, 0, 1)


def _build_frame_transform():
    """Return the pipeline turning one (H, W, C) frame tensor into a normalised
    ``(3, input_size, input_size)`` float tensor."""
    return transforms.Compose([
        transforms.Lambda(_frame_to_chw),
        transforms.ToPILImage(),
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        # Augmentations are currently disabled for every split:
        #transforms.RandomHorizontalFlip(),
        #transforms.RandomRotation([-30,30]),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ])


# All three splits receive tensor frames, so they share the same pipeline.
# Previously 'val' and 'test' lacked ToPILImage, so ToTensor rejected their tensor input.
data_transforms = {
    'train': _build_frame_transform(),
    'val': _build_frame_transform(),
    'test': _build_frame_transform(),
}

def get_dataloaders(batch_size, input_size = 256, shuffle = True):
    """Create DataLoaders for the train/val/test video folders under ``yt-dls/``.

    Args:
        batch_size: Number of clips per batch for every split.
        input_size: Currently unused; kept so existing callers keep working.
        shuffle: Whether each loader shuffles its dataset. Previously this
            argument was ignored and shuffling was always on; the default
            (``True``) preserves that behaviour.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'`` mapping to a
        ``torch.utils.data.DataLoader`` over the ``.mp4`` files of that split,
        each sample being loaded with ``video_loader``.
    """
    dataloaders_dict = {}
    for split in ("train", "val", "test"):
        dataset = datasets.DatasetFolder(
            root=f"yt-dls/{split}/",
            loader=video_loader,
            extensions=("mp4",),  # a real one-element tuple; ("mp4") is just a string
        )
        dataloaders_dict[split] = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle
        )
    return dataloaders_dict

NameError: name 'transforms' is not defined

In [None]:
def train(model, dataloaders, num_epochs=25, optimizer=None, percept=None,
          alpha=1.0, beta=1.0):
    """Train ``model`` on single video frames and keep the best validation weights.

    For every batch, the first frame of each clip is passed through
    ``data_transforms[phase]`` and fed to the model; the loss is
    ``beta * MSE + alpha * perceptual`` between the prediction and the
    transformed frame.

    Args:
        model: Network to optimise. Its output must have the same shape as the
            transformed input frames for the MSE term to be valid.
        dataloaders: Dict with ``'train'`` and ``'val'`` DataLoaders yielding
            ``(clips, _)`` pairs.
        num_epochs: Number of passes over the training and validation sets.
        optimizer: Optimiser stepping ``model``'s parameters. Defaults to
            ``torch.optim.Adam(model.parameters())``.
        percept: Optional callable ``percept(prediction, label)`` returning a
            perceptual loss tensor. When ``None`` the perceptual term is omitted.
        alpha: Weight of the perceptual loss term.
        beta: Weight of the MSE loss term.

    Returns:
        ``(model, val_loss_history)``: the model with the lowest-validation-loss
        weights loaded, and the list of per-epoch validation losses.
    """
    since = time.time()

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters())
    # Run the data on whatever device the model's parameters already live on.
    model_device = next(model.parameters()).device

    val_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float("inf")  # lower is better, so start from +inf rather than 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            num_samples = 0

            # Iterate over data.
            # TQDM has nice progress bars
            for inputs, _ in tqdm(dataloaders[phase]):
                # Each batch element is a clip of frames (e.g. 10 x 720 x 1280 x 3);
                # for now, use only the first frame of every clip.
                # NOTE(review): data_transforms[phase] must accept one (H, W, C) frame.
                frames = torch.stack(
                    [data_transforms[phase](clip[0]) for clip in inputs]
                ).to(model_device)
                # NOTE(review): the target is the transformed input frame itself;
                # confirm this is the intended label for super resolution.
                label = frames

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    prediction = model(frames)
                    mse_loss = nn.functional.mse_loss(prediction, label)
                    loss = beta * mse_loss
                    if percept is not None:
                        # Reduce to a scalar in case the metric returns one value per sample.
                        perceptual_loss = percept(prediction, label).mean()
                        loss = loss + alpha * perceptual_loss
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                batch_count = frames.size(0)
                running_loss += loss.item() * batch_count
                num_samples += batch_count

            # Guard against an empty loader so an empty split does not divide by zero.
            epoch_loss = running_loss / max(num_samples, 1)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            if phase == 'val':
                val_loss_history.append(epoch_loss)
                if num_samples > 0 and epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Restore the weights from the best validation epoch (or the initial ones if none ran).
    model.load_state_dict(best_model_wts)
    return model, val_loss_history
                

In [93]:
# Build the encoder and decoder and switch both to evaluation mode
# (nn.Module.eval() returns the module itself, so calling it separately is equivalent).
resnet = initialize_resnet_model()
resnet.eval()
decoder = initialize_decoder()
decoder.eval()

# Dump both architectures for inspection.
print(resnet)
print(decoder)

num_ftrs: 512
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(in

In [None]:
# Smoke test: build the loaders with batch size 1 and push one frame of every
# validation clip through the training transform.
dataloaders_dict = get_dataloaders(1)
print(dataloaders_dict)

for inputs, _ in tqdm(dataloaders_dict['val']):
    print(inputs.shape)
    # `inputs` is a batch of whole clips (batch, frames, H, W, C); the transform
    # starts with ToPILImage, which only accepts a single 2-D/3-D image, so pick
    # the first frame of the first clip instead of passing the 5-D batch.
    inputs2 = data_transforms['train'](inputs[0, 0])

    print(inputs2.shape)

In [None]:
# Kick off training with the default number of epochs.
# NOTE(review): `resnet` is only the encoder (it emits a 1024-d latent); confirm
# whether the encoder+decoder pair should be trained together instead.
train(model=resnet, dataloaders=dataloaders_dict)