<a href="https://colab.research.google.com/github/BenjaminEngelman/Super-Resolution/blob/master/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [1]:
!pip install torchvision==0.4.0

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data.sampler import SubsetRandomSampler


# Silence every warning for the rest of the session
import warnings
warnings.filterwarnings("ignore")

# Put matplotlib into interactive mode
plt.ion()

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")



## Google Drive 

In [2]:
from google.colab import drive
import os

drive.mount('/content/gdrive', force_remount=True)
root_path = 'gdrive/My Drive/videoEnhancer'
os.chdir(root_path)  #change dir
!ls

Mounted at /content/gdrive
autoencoder.pt	frames	notebook.ipynb	video1080p.mp4


## Dataset

In [0]:
LOW_RES = (540, 960)  # (height, width) the frames are resized to for the network input

class FrameDataset(Dataset):
    """
    Dataset in which each element is composed of a frame in 1080p (original)
    and the identical frame resized to LOW_RES (low_res).

    Both images are converted to tensors and normalized with the same
    per-channel mean / std before being returned.
    """

    def __init__(self, root_dir):
        """
        Args:
            root_dir (string): Directory with all the 1080p images.

        """
        self.root_dir = root_dir
        # Sorted so that index i always refers to the same file; os.listdir
        # gives no ordering guarantee, which would make the seeded
        # train/validation split differ from one run to the next.
        self.img_names = self._list_frames(self.root_dir)
        self.transform = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                             ])

    @staticmethod
    def _list_frames(root_dir):
        """Return the names of all entries in root_dir, sorted alphabetically."""
        return sorted(os.listdir(root_dir))

    def __len__(self):
        """Number of frames found in root_dir."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """
        Load frame `idx` and build its low-resolution counterpart.

        Returns:
            dict: {'original': normalized tensor of the frame as read from disk,
                   'low_res':  normalized tensor of the frame resized to LOW_RES}
        """
        img_name = os.path.join(self.root_dir, self.img_names[idx])
        original = io.imread(img_name)
        # NOTE(review): skimage's resize returns a floating-point image, so the
        # 'low_res' tensor may have a different dtype than 'original' --
        # consumers should cast explicitly (the training loop does).
        low_res = transform.resize(original, LOW_RES)
        sample = {'original': self.transform(original), 'low_res': self.transform(low_res)}

        return sample

## Autoencoder

### Model Definitions

In [0]:
# Inspired by: https://github.com/leaxp/Deep-Learning-Super-Resolution-Image-Reconstruction-DSIR/blob/master/conv_autoencoder.py

import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# define the NN architecture
class ConvAutoencoder(nn.Module):
    """
    Convolutional Autoencoder. The goal is to start from an Image of low resolution
    (960x540) and to transform it into the same image but in a higher resolution
    (1920x1080).

    The encoder halves the spatial size twice (two 2x2 max-pools); the decoder
    doubles it three times (three stride-2 transposed convolutions), so the
    output has twice the height and width of the input and 3 channels.
    """
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # Encoder convolutions (3x3, padding=1 keeps the spatial size)
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 8, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(8)
        # conv3 and convt4 are not used by forward(); they are kept so the
        # parameter set (and therefore the state_dict keys) stays unchanged.
        self.conv3 = nn.Conv2d(8, 8, 3, padding=1)
        self.conv4 = nn.Conv2d(8, 3, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(3)
        # Decoder: each transposed convolution (kernel 2, stride 2) doubles H and W
        self.convt1 = nn.ConvTranspose2d(8, 8, 2, stride=2)
        self.convt2 = nn.ConvTranspose2d(8, 16, 2, stride=2)
        self.convt3 = nn.ConvTranspose2d(16, 8, 2, stride=2)
        self.convt4 = nn.ConvTranspose2d(8, 16, 2, stride=2)

        self.pool = nn.MaxPool2d(2, 2)

        # Xavier initialization for the Conv2d layers (ConvTranspose2d layers
        # are not matched by the isinstance check and keep PyTorch's default).
        # The in-place `_` initializers replace the deprecated
        # nn.init.xavier_normal / nn.init.normal.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.normal_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Args:
            x: batch of images with 3 channels, shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 3, 2*H, 2*W) for H and W divisible by 4.
        """
        # NOTE(review): bn1 and bn2 are each applied at several depths
        # (bn1: after conv1 and convt2; bn2: after conv2, convt1 and convt3),
        # so their running statistics blend activations of different layers.
        # Giving every layer its own BatchNorm would be cleaner, but changes
        # the state_dict layout of existing checkpoints.

        # Encoder: H x W -> H/2 x W/2 -> H/4 x W/4
        x = self.conv1(x)
        x = self.pool(F.relu(self.bn1(x)))
        x = self.conv2(x)
        x = self.pool(F.relu(self.bn2(x)))

        # Decoder: H/4 -> H/2 -> H -> 2H (same for W)
        x = self.convt1(x)
        x = F.relu(self.bn2(x))
        x = self.convt2(x)
        x = F.relu(self.bn1(x))
        x = self.convt3(x)
        x = F.relu(self.bn2(x))
        x = self.conv4(x)
        # NOTE(review): this final ReLU makes the output non-negative, while
        # FrameDataset normalizes the targets with a mean/std that yields
        # negative values for dark pixels -- confirm this is intended.
        x = F.relu(self.bn4(x))

        return x

# print(model)
# summary(model, (3, 540, 960))

### Training

#### Data setup

In [0]:
# Train / validation split through index samplers, following
# https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets

# Split configuration
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42

dataset = FrameDataset('frames/')
dataset_size = len(dataset)

# Number of samples reserved for validation (rounded down)
split = int(np.floor(validation_split * dataset_size))

# All dataset indices, optionally shuffled with a fixed seed
indices = list(range(dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

# The first `split` indices go to validation, the remainder to training
val_indices = indices[:split]
train_indices = indices[split:]

# One sampler per subset; both loaders wrap the same underlying dataset
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, sampler=phase_sampler)
    for phase, phase_sampler in (("train", train_sampler), ("val", valid_sampler))
}

#### Training Loop

In [0]:
CHECKPOINT_PATH = "model/autoencoder.pt"

# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# initialize the NN
model = ConvAutoencoder()
model = model.to(device)

# specify loss function
criterion = nn.MSELoss()

# specify optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# number of epochs to train the model
n_epochs = 50
# any finite validation loss beats this, so the first epoch always checkpoints
best_val_loss = float('inf')

# NOTE(review): the model is never switched between model.train() and
# model.eval(), so BatchNorm uses batch statistics (and updates its running
# statistics) during validation too. Because ConvAutoencoder reuses bn1/bn2 at
# several depths, enabling eval mode here would normalize with blended running
# statistics -- fix the model's BatchNorm sharing before adding the switch.

for epoch in range(1, n_epochs+1):
    # per-sample average loss of each phase for this epoch
    losses = {'train': 0, 'val': 0}

    for phase in ['train', 'val']:
        is_training = phase == 'train'
        running_loss = 0.0
        n_samples = 0
        for data in dataloaders[phase]:
            # the low-res input is cast to float32 explicitly; the target is moved as-is
            inputs = data['low_res'].to(device=device, dtype=torch.float)
            originals = data['original'].to(device)
            if is_training:
                # clear the gradients of all optimized variables
                optimizer.zero_grad()
            # gradients are only tracked during the training phase
            with torch.set_grad_enabled(is_training):
                # forward pass: compute predicted outputs by passing inputs to the model
                outputs = model(inputs)
                # calculate the loss
                loss = criterion(outputs, originals)
                if is_training:
                    # backward pass: compute gradient of the loss with respect to model parameters
                    loss.backward()
                    # perform a single optimization step (parameter update)
                    optimizer.step()
            # loss.item() is the batch mean, so weight it by the batch size ...
            running_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
        # ... and divide by the number of samples seen (not the number of
        # batches) to obtain the true per-sample average for the phase.
        losses[phase] = running_loss / n_samples

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch,
        losses['train'],
        losses['val']
        ))

    # keep only the checkpoint with the lowest validation loss so far
    if losses['val'] < best_val_loss:
        best_val_loss = losses['val']

        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': losses['train'],
                'val_loss': losses['val'],

                }, CHECKPOINT_PATH)