In [0]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# NOTE(review): later cells read videos through the relative path
# 'drive/My Drive/...', which presumably relies on the working directory being
# /content - verify when running outside the default Colab setup.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import torch
import cv2
import os
import collections
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still executes on a runtime without an accelerator.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Let cuDNN benchmark convolution algorithms and make sure cuDNN itself is on.
# The switch is called `enabled`; assigning to `enable` only created an unused
# attribute and left the real flag untouched.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Set `generate` to False to skip frame extraction when the frames already exist.
generate = True
frameFolderPath = './video_frames/'

# NOTE(review): a later cell reads its test video from this same folder, so that
# video's frames also end up in the training frames - confirm this is intended.
videos = os.listdir('drive/My Drive/training_videos')

if generate:
    # cv2.imwrite does not create missing directories, so without this every
    # write would fail and no frame would be saved.
    os.makedirs(frameFolderPath, exist_ok = True)

    # `count` runs across all videos so every frame gets a unique file name.
    count = 0
    for video in videos:
        vidcap = cv2.VideoCapture('drive/My Drive/training_videos/' + video)
        try:
            success,image = vidcap.read()
            # False here means the file could not be opened/decoded at all.
            print(success)
            while success:
                cv2.imwrite(frameFolderPath + "frame%d.png" % count, image)
                success,image = vidcap.read()
                count += 1
        finally:
            # Free the decoder/file handle before moving on to the next video.
            vidcap.release()

In [0]:
# Dump every frame of the held-out test video into ./test_frames as
# frame0.png, frame1.png, ...
# cv2.imwrite does not create missing directories, so create the target first.
os.makedirs('test_frames', exist_ok = True)

count = 0
vidcap = cv2.VideoCapture('drive/My Drive/training_videos/all_test11346.avi')
try:
    success,image = vidcap.read()
    # False here means the video could not be opened/decoded at all.
    print(success)
    while success:
        cv2.imwrite('test_frames/frame%d.png' % count, image)
        success,image = vidcap.read()
        count += 1
finally:
    # Free the decoder/file handle once all frames are written.
    vidcap.release()

In [0]:
class FrameLoader(Dataset):
    """Dataset of frame triplets for frame interpolation.

    The image files in ``root_dir`` are put into playback order and split into
    consecutive, non-overlapping triplets ``(3i, 3i + 1, 3i + 2)``. The two outer
    frames are the network input and the middle frame is the target.

    Args:
        root_dir: directory that contains the frame images (with or without a
            trailing path separator).
        transform: optional callable applied to the two outer frames. It has to
            return tensors because the results are joined with ``torch.cat``.
        target_size: ``(height, width)`` the middle frame is resized to when a
            ``transform`` is given.
            NOTE(review): this presumably has to match the size ``transform``
            produces for the outer frames - confirm against the transform used.
    """

    def __init__(self, root_dir, transform = None, target_size = (256, 448)):
        self.root_dir = root_dir
        self.transform = transform
        self.target_size = target_size
        self.frame_list = os.listdir(root_dir)
        self.frame_list.sort(key = self._playback_key)

    def _frame_path(self, file_name):
        """Return the path of ``file_name`` inside ``root_dir``."""
        return os.path.join(self.root_dir, file_name)

    def _playback_key(self, file_name):
        """Sort key that orders frames by the number at the end of their name.

        Creation time alone is fragile: it changes when the folder is copied and
        can tie for files written in quick succession. The numeric suffix of
        names such as ``frame12.png`` is used first; creation time only breaks
        ties and orders files that carry no number (those are placed last).
        """
        stem = os.path.splitext(file_name)[0]
        digits = stem[len(stem.rstrip('0123456789')):]
        ctime = os.stat(self._frame_path(file_name)).st_ctime
        if digits:
            return (0, int(digits), ctime)
        return (1, 0, ctime)

    def __len__(self):
        # Clamp at zero: with fewer than three frames the original expression
        # went negative, which makes len() raise.
        # NOTE(review): the "- 1" means the last complete triplet is never
        # served - confirm whether that is intentional.
        return max(len(self.frame_list) // 3 - 1, 0)

    def __getitem__(self, idx):
        """Return ``[input, target]`` for triplet ``idx``.

        ``input`` is the first and third frame concatenated along dimension 0,
        ``target`` is the middle frame converted with ``TF.to_tensor``.
        """
        frame1 = Image.open(self._frame_path(self.frame_list[3 * idx]))
        frame2 = Image.open(self._frame_path(self.frame_list[3 * idx + 2]))
        intermediate = Image.open(self._frame_path(self.frame_list[3 * idx + 1]))

        if self.transform is not None:
            frame1 = self.transform(frame1)
            frame2 = self.transform(frame2)
            # The target is not passed through `transform`; it is only resized
            # and converted to a tensor.
            intermediate = TF.resize(intermediate, self.target_size)
            intermediate = TF.to_tensor(intermediate)
        else:
            frame1 = TF.to_tensor(frame1)
            frame2 = TF.to_tensor(frame2)
            intermediate = TF.to_tensor(intermediate)

        return [torch.cat((frame1, frame2)), intermediate]

In [0]:
batch_size = 4

# No transform is passed, so FrameLoader converts the frames with TF.to_tensor
# at their stored resolution.
dataset = FrameLoader('./video_frames/')
testset = FrameLoader('./test_frames/')

# Training batches are reshuffled every epoch.
data_loader = DataLoader(dataset, batch_size = batch_size, num_workers = 4, shuffle = True, pin_memory = True)
# Evaluation only averages per-frame scores, so shuffling adds nothing; keeping
# the order fixed makes evaluation runs repeatable batch for batch.
test_loader = DataLoader(testset, batch_size = batch_size, num_workers = 4, shuffle = False, pin_memory = True)

In [0]:
import torch.nn as nn
import torch.nn.functional as F

def convolution_pack(channels_in, channels_out, kernel_size, padding_size):
    """Build one encoder stage.

    The stage is four ``Conv2d`` + ``ReLU`` pairs (the first maps
    ``channels_in`` to ``channels_out``, the rest keep ``channels_out``)
    followed by a 2x2 convolution with stride 2 and a final ``ReLU``.

    Args:
        channels_in: number of input channels.
        channels_out: number of channels produced by every convolution.
        kernel_size: kernel size of the four leading convolutions.
        padding_size: padding of the four leading convolutions.

    Returns:
        An ``nn.Sequential`` with the ten layers in the order described above.
    """
    layers = []
    current_channels = channels_in
    for _ in range(4):
        layers.append(nn.Conv2d(current_channels, channels_out, kernel_size, padding = padding_size))
        layers.append(nn.ReLU())
        current_channels = channels_out

    # Strided 2x2 convolution that reduces the spatial resolution.
    layers.append(nn.Conv2d(channels_out, channels_out, 2, stride = 2))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def transposed_convolution(channels_in, channels_out, kernel_size, padding_size):
    """Build one decoder stage.

    Despite the name, no ``ConvTranspose2d`` is involved: the stage upsamples by
    a factor of 2 with nearest-neighbour interpolation, applies a ``ReLU``, and
    then runs four ``Conv2d`` + ``ReLU`` pairs (the first maps ``channels_in``
    to ``channels_out``, the rest keep ``channels_out``).

    Args:
        channels_in: number of input channels.
        channels_out: number of channels produced by every convolution.
        kernel_size: kernel size of the convolutions.
        padding_size: padding of the convolutions.

    Returns:
        An ``nn.Sequential`` with the ten layers in the order described above.
    """
    stages = [nn.Upsample(scale_factor=2, mode = 'nearest'), nn.ReLU()]

    # Channel width before and after each of the four convolutions.
    widths = [channels_in] + [channels_out] * 4
    for width_before, width_after in zip(widths, widths[1:]):
        stages.append(nn.Conv2d(width_before, width_after, kernel_size, padding = padding_size))
        stages.append(nn.ReLU())

    return nn.Sequential(*stages)

class Net(nn.Module):
    """Encoder/decoder network with additive skip connections.

    Five ``convolution_pack`` stages (``pack1`` .. ``pack5``) take the 6-channel
    input to 512 channels; five ``transposed_convolution`` stages
    (``transpose2`` .. ``transpose6``) bring it back to 3 channels. The outputs
    of ``pack4`` .. ``pack1`` are added to the decoder activations before
    ``transpose3`` .. ``transpose6`` respectively.

    The additions need matching shapes, so after five stride-2 reductions the
    input height and width should be multiples of 32.
    """

    def __init__(self):
        super().__init__()

        # Encoder: pack1 .. pack5, created in this order.
        encoder_channels = (6, 32, 64, 128, 256, 512)
        for index, (c_in, c_out) in enumerate(zip(encoder_channels, encoder_channels[1:]), start = 1):
            setattr(self, "pack{}".format(index), convolution_pack(c_in, c_out, 3, 1))

        # Decoder: transpose2 .. transpose6, created in this order.
        decoder_channels = (512, 256, 128, 64, 32, 3)
        for index, (c_in, c_out) in enumerate(zip(decoder_channels, decoder_channels[1:]), start = 2):
            setattr(self, "transpose{}".format(index), transposed_convolution(c_in, c_out, 3, 1))

        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise the weight of every ``Conv2d`` with Kaiming-normal values.

        Biases keep the values assigned by the ``Conv2d`` constructor.
        """
        conv_layers = (module for module in self.modules() if isinstance(module, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight)

    def forward(self, x):
        """Map a 6-channel batch to a 3-channel batch."""
        skip1 = self.pack1(x)
        skip2 = self.pack2(skip1)
        skip3 = self.pack3(skip2)
        skip4 = self.pack4(skip3)
        bottleneck = self.pack5(skip4)

        decoded = self.transpose2(bottleneck)
        # Each remaining decoder stage receives the running activation plus the
        # encoder output with the same shape.
        for skip, stage in ((skip4, self.transpose3),
                            (skip3, self.transpose4),
                            (skip2, self.transpose5),
                            (skip1, self.transpose6)):
            decoded = stage(decoded + skip)

        return decoded
# Build the network and place its parameters on the selected device.
net = Net()
net = net.to(device)

In [0]:
import torch.optim as optim
from tqdm import tqdm
from pytorch_msssim import MS_SSIM

# MS-SSIM term of the loss. The targets are produced by TF.to_tensor and
# therefore lie in [0, 1], so the data range must be 1.0; with 255 the
# stabilising constants dominate and the score stays close to 1 regardless of
# the prediction.
ms_ssim_module = MS_SSIM(data_range=1.0, size_average=True, channel=3)

optimizer = optim.Adamax(net.parameters(), lr = 1e-4)
loss_function = nn.L1Loss()

epochs = 10
# Weight of the (1 - MS-SSIM) term; the L1 term gets (1 - alpha).
alpha = 0.84

# Loss of every optimisation step, in order, for later inspection/plotting.
training_loss = []

for epoch in range(epochs):
    for inputs, labels in tqdm(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = alpha * (1 - ms_ssim_module(outputs, labels)) + (1 - alpha) * loss_function(outputs, labels)
        training_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    # Report the loss of the last batch as a plain number rather than the
    # tensor repr (which includes device and grad_fn noise).
    print("Epoch: {} Loss: {}".format(epoch + 1, loss.item()))

In [0]:
from SSIM_PIL import compare_ssim
from tqdm import tqdm
# Average SSIM between predicted and ground-truth middle frames of the test set.
ssim_sum = 0
# Number of frames actually scored; the last batch can be smaller than
# batch_size, so batch_size * len(test_loader) would overstate the count.
ssim_frame_count = 0
to_pil_image = torchvision.transforms.ToPILImage()

with torch.no_grad():
    for sample, target in tqdm(test_loader):
        output = net(sample.to(device))
        for i, frame in enumerate(output):
            # Clamp to [0, 1] first: the conversion scales by 255 and casts to
            # bytes, so values above 1 would wrap around instead of saturating.
            image1 = to_pil_image(frame.cpu().clamp(0, 1))
            image2 = to_pil_image(target[i].cpu())
            ssim_sum += compare_ssim(image1, image2)
            ssim_frame_count += 1

# Avoid a ZeroDivisionError when the test set yields no samples.
ssim_average = ssim_sum / ssim_frame_count if ssim_frame_count else float('nan')
print("Average ssim score: {}".format(ssim_average))

In [0]:
import numpy 
import math
# Largest value a pixel can take in the 8-bit images being compared.
PIXEL_MAX = 255.0
def psnr(img1, img2):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: first image; anything ``numpy.asarray`` accepts (e.g. a PIL image
            or an array) with values on the 0..255 scale.
        img2: second image with the same shape as ``img1``.

    Returns:
        ``20 * log10(PIXEL_MAX) - 10 * log10(mse)``, or ``100`` when the images
        are identical (the formula is undefined for a zero error).
    """
    # Work in float64: the inputs are typically uint8, where the subtraction
    # and the squaring would silently wrap around.
    img1 = numpy.asarray(img1, dtype = numpy.float64)
    img2 = numpy.asarray(img2, dtype = numpy.float64)
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    return 20 * math.log10(PIXEL_MAX) - 10 * math.log10(mse)

# Average PSNR between predicted and ground-truth middle frames of the test set.
psnr_sum = 0
# Number of frames actually scored; the last batch can be smaller than
# batch_size, so batch_size * len(test_loader) would overstate the count.
psnr_frame_count = 0
to_pil_image = torchvision.transforms.ToPILImage()

with torch.no_grad():
    for sample, target in tqdm(test_loader):
        output = net(sample.to(device))
        for i, frame in enumerate(output):
            # Clamp to [0, 1] first: the conversion scales by 255 and casts to
            # bytes, so values above 1 would wrap around instead of saturating.
            image1 = to_pil_image(frame.cpu().clamp(0, 1))
            image2 = to_pil_image(target[i].cpu())
            psnr_sum += psnr(image1, image2)
            psnr_frame_count += 1

# Avoid a ZeroDivisionError when the test set yields no samples.
psnr_average = psnr_sum / psnr_frame_count if psnr_frame_count else float('nan')
print("Average psnr score: {}".format(psnr_average))