In [1]:
import video_data_classes as vid
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd.profiler as profiler
import torch
import h5py
import utils
from time import perf_counter
from skimage import io
from importlib import reload
from IPython.core.debugger import set_trace
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

%matplotlib inline

In [2]:
class SuperResNetwork(nn.Module):
    """Recurrent convolutional network applied frame by frame to a video clip.

    At every time step the current frame is concatenated, along the channel
    dimension, with the network's own output for the previous frame. The
    result goes through a 3x3 convolution, a ReLU, a 3x3 transposed
    convolution and a sigmoid. The first frame is paired with zeros because
    no previous output exists yet.

    With the default stride and padding, the convolution trims one pixel from
    each border and the transposed convolution adds it back, so every output
    frame has the same height and width as the corresponding input frame.
    """

    def __init__(self):
        super(SuperResNetwork, self).__init__()

        # 6 input channels = 3 (current frame) + 3 (previous output frame).
        self.input_layer = nn.Conv2d(6, 8, (3, 3))
        self.output_layer = nn.ConvTranspose2d(8, 3, (3, 3))
        self.activation = nn.ReLU()

    def _step(self, frame, previous):
        """Compute one output frame from the current frame and the previous output."""
        y = self.input_layer(torch.cat([frame, previous], dim=1))
        y = self.activation(y)
        y = self.output_layer(y)
        return torch.sigmoid(y)

    def forward(self, x):
        """Run the recurrence over a whole clip.

        Args:
            x: 5-D tensor indexed as [time, batch, channel, height, width].
                The channel dimension must be 3 to match ``input_layer``.

        Returns:
            Tensor with the same shape as ``x`` holding one output frame per
            input frame, with values in [0, 1] (sigmoid output).

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
            IndexError: if the time dimension is empty.
        """
        if x.dim() != 5:
            raise ValueError(
                'expected a 5-D tensor [time, batch, channel, height, width], '
                'got {} dimension(s)'.format(x.dim())
            )

        # The first step has no previous output, so it is paired with zeros.
        previous = torch.zeros_like(x[0])

        # Collect the frames and stack once at the end; concatenating onto the
        # running output inside the loop would re-copy all earlier frames on
        # every step.
        frames = []
        for frame in x.unbind(dim=0):
            previous = self._step(frame, previous)
            frames.append(previous)

        return torch.stack(frames, dim=0)

In [7]:
# Build the training objective, the model (placed on the first CUDA device),
# and the optimizer that updates the model's parameters.

criterion = nn.MSELoss()
network = SuperResNetwork().to('cuda:0')  # Module.to() returns the module itself
optimizer = optim.Rprop(network.parameters())

In [2]:
# Build `dataset`, which represents a collection of short video samples read
# from the directory below.

video_directory = r'C:\Users\John\PythonVenvs\VideoSuperResolution\Scripts\Raw Data\Raw, Half-Size, and PNGs'
clip_setting = 10  # second VideoDataset argument; presumably frames per sample — TODO confirm
dataset = vid.VideoDataset(video_directory, clip_setting)

In [None]:
# Creates an hdf5 file of video data for fast i/o during training.
# NOTE(review): this cell is unguarded, so every run calls generate_hdf5 again.
# Presumably it writes the file that the training cell opens — confirm in utils.

utils.generate_hdf5(dataset)

196


In [13]:
# Single pass over the HDF5 data: each batch of `total_batch_size` clips is split
# into a training part (first `train_batch_size` clips, used for the optimizer
# step) and a validation part (the remaining clips, only evaluated).
#
# NOTE(review): SuperResNetwork as defined in this notebook returns frames with
# the same height/width as its input (64x64 here), while the targets are
# 256x256, so `criterion` looks like it receives mismatched shapes — confirm.
#
# Both the HDF5 file and the TensorBoard writer are managed by the `with`
# statement so they are closed (and pending events flushed) even if a step fails.
with h5py.File(r'E:\hdf5_file.hdf5', 'r') as data, SummaryWriter() as writer:
    train_batch_size = 4
    val_batch_size = 2
    total_batch_size = train_batch_size + val_batch_size
    batch_indices = utils.make_batches(len(dataset), total_batch_size)

    # GPU buffers allocated once and overwritten in place for every batch.
    x = torch.empty(dataset.sequence_length, total_batch_size, 3, 64, 64, device='cuda:0')
    y = torch.empty(dataset.sequence_length, total_batch_size, 3, 256, 256, device='cuda:0')

    for i in range(batch_indices.shape[0]):
        # transpose(1, 0, ...) swaps the first two axes so the buffers are filled as
        # [dataset.sequence_length, total_batch_size, c, h, w].
        # Dividing by 255 presumably rescales 8-bit pixel values to [0, 1] — TODO confirm.
        x[...] = torch.tensor(data['X'][batch_indices[i, :], ...].transpose(1, 0, 2, 3, 4), device='cuda:0') / 255.0
        y[...] = torch.tensor(data['Y'][batch_indices[i, :], ...].transpose(1, 0, 2, 3, 4), device='cuda:0') / 255.0

        # Training step on the first `train_batch_size` clips of the batch.
        optimizer.zero_grad()
        output = network(x[:, :train_batch_size, ...])
        loss = criterion(output, y[:, :train_batch_size, ...])
        loss.backward()
        optimizer.step()

        # Validation on the remaining clips; no gradients are needed, so skip
        # building the autograd graph for this forward pass.
        with torch.no_grad():
            val_output = network(x[:, train_batch_size:, ...])
            val_loss = criterion(val_output, y[:, train_batch_size:, ...])

        # Log plain Python floats rather than tensors.
        writer.add_scalar('Loss/Train', loss.item(), i)
        writer.add_scalar('Loss/Validation', val_loss.item(), i)



In [12]:
# Development helper: re-import `utils` so edits to utils.py take effect
# without restarting the kernel.
reload(utils)

<module 'utils' from 'C:\\Users\\John\\PythonVenvs\\VideoSuperResolution\\Scripts\\VideoSuperResolution\\utils.py'>