In [None]:
# Install the third-party packages imported below: kornia (SSIM building blocks)
# and pafy with its youtube_dl backend.
!pip install kornia
!pip install pafy youtube_dl

Collecting kornia
[?25l  Downloading https://files.pythonhosted.org/packages/ae/b2/8a968f1d7fb1d651a77c1ad7ffce9fc7b4dbd250eecaa9e2f21714fcfb2e/kornia-0.5.0-py2.py3-none-any.whl (271kB)
[K     |█▏                              | 10kB 15.9MB/s eta 0:00:01[K     |██▍                             | 20kB 12.0MB/s eta 0:00:01[K     |███▋                            | 30kB 9.0MB/s eta 0:00:01[K     |████▉                           | 40kB 8.2MB/s eta 0:00:01[K     |██████                          | 51kB 5.4MB/s eta 0:00:01[K     |███████▎                        | 61kB 5.8MB/s eta 0:00:01[K     |████████▌                       | 71kB 6.1MB/s eta 0:00:01[K     |█████████▋                      | 81kB 6.7MB/s eta 0:00:01[K     |██████████▉                     | 92kB 6.5MB/s eta 0:00:01[K     |████████████                    | 102kB 6.6MB/s eta 0:00:01[K     |█████████████▎                  | 112kB 6.6MB/s eta 0:00:01[K     |██████████████▌                 | 122kB 6.6MB/s eta

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torchvision.models.resnet import ResNet, BasicBlock
from torch.autograd import Variable
import torch.optim as optim
from torch.utils import model_zoo
import torch.utils.data as data
from torch.utils.data import random_split

import numpy as np
from math import exp, ceil
import os
import os.path
import gzip
from six.moves import urllib
import time

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image, ImageOps

import pafy
import cv2 as cv
import random
from typing import Dict, Tuple, Optional

from kornia.filters import get_gaussian_kernel2d, filter2D

In [None]:
def seed_all(seed):
  """
  Seed every random number generator used in this notebook.

  param seed: integer seed handed to Python's `random`, NumPy and PyTorch
              (CPU generator, current CUDA device and all CUDA devices).
  """
  seeders = (
      random.seed,
      np.random.seed,
      torch.manual_seed,
      torch.cuda.manual_seed,
      torch.cuda.manual_seed_all,
  )
  for apply_seed in seeders:
    apply_seed(seed)

seed_all(42)

In [None]:
# Run-mode switches for the whole notebook.
MOUNT_GDRIVE = True # Use GoogleDrive or the current dir as working dir
TRAIN_MODE = False # Train or Test mode
TEST_EPOCH = 160 # Use saved model at given epoch for testing

# When both TRAIN_MODE and PROG_IMG_RESIZING are set, adjust_config(TRAIN_STEP)
# overrides the input size, batch size and epoch counts for the selected step.
PROG_IMG_RESIZING = True # Progressive image resizing mode for training
TRAIN_STEP = 1 # Step: 1, 2, 3, [4]

# Read by MovingMnist.__getitem__: when True, training samples start at a random frame.
RANDOMIZE = True

In [None]:
# Hyper-parameters shared by the whole notebook. init_config() and adjust_config()
# add further keys (in_size, batch_size, n_workers, ...) and overwrite some values below.
config = {
    'learning_rate' : 1e-4,
    'weight_decay' : 1e-5,
    'optimizer': optim.Adam,
    'dataset_name' : 'Robot', # 'Mnist' or 'Robot'
    'n_frames' : 3, # number of input/target frames
    'mse_gain' : 0.3, # mse_loss coefficient in range [0, 1]; overwritten per dataset by init_config()
    'ssim_win_size' : 11,
    'n_epochs' : 100, # number of epochs to train in a single train step
    'save_step' : 10, # save the model after every `save_step` epochs
    'start_epoch': 1, # first epoch of this run; the checkpoint of epoch start_epoch-1 is loaded before training
    'loc_aware': True
}

In [None]:
def init_config():
  """
  Add the dataset-specific defaults to the global `config`.

  Sets 'in_size', 'batch_size', 'n_workers' and 'train_split', and overwrites
  'mse_gain', depending on config['dataset_name'] ('Mnist' or 'Robot').
  Any other dataset name fails the assertion before `config` is touched.
  """
  dataset_name = config['dataset_name']
  if dataset_name == 'Mnist':
    defaults = {
        'in_size': (64, 64),
        'batch_size': 64,
        'n_workers': 2,
        'train_split': 9000,
        'mse_gain': 0.35,
    }
  else:
    assert dataset_name == 'Robot'
    defaults = {
        'in_size': (224, 320),
        'batch_size': 4,
        'n_workers': 0,
        'train_split': None,
        'mse_gain': 0.1,
    }
  config.update(defaults)

In [None]:
def adjust_config(train_step):
  """
  Override `config` with the settings of one progressive-resizing step.

  param train_step: 1-based index of the training step (1..3 for 'Mnist', 1..4 for 'Robot').

  Writes 'n_epochs_list', 'in_size', 'batch_size', 'n_epochs', 'start_epoch' and
  'loc_aware'; for the 'Robot' dataset it also sets 'learning_rate'.
  An unknown step raises KeyError from the per-step lookup tables.
  """
  base = 32  # every input side length is a multiple of this
  if config['dataset_name'] == 'Mnist':
    config['n_epochs_list'] = [50, 50, 70]
    size_factors = {1: (2, 2), 2: (4, 4), 3: (7, 7)}
    batch_sizes = {1: 64, 2: 16, 3: 8}
  else:
    assert config['dataset_name'] == 'Robot'
    config['n_epochs_list'] = [40, 40, 40, 40]
    size_factors = {1: (2, 3), 2: (4, 6), 3: (5, 8), 4: (7, 10)}
    batch_sizes = {1: 64, 2: 16, 3: 8, 4: 4}
    # larger learning rate for every step except the last one
    is_last_step = train_step == len(config['n_epochs_list'])
    config['learning_rate'] = 1e-4 if is_last_step else 5*1e-4

  rows, cols = size_factors[train_step]
  config['in_size'] = (base*rows, base*cols)
  config['batch_size'] = batch_sizes[train_step]
  config['n_epochs'] = config['n_epochs_list'][train_step-1]
  assert config['n_epochs'] % config['save_step'] == 0
  # epochs are numbered continuously across steps
  epochs_before = sum(config['n_epochs_list'][:train_step-1])
  config['start_epoch'] = epochs_before + 1

  # add LocationAwareConv only at the last step
  config['loc_aware'] = (train_step == len(config['n_epochs_list']))

In [None]:
# Apply the dataset defaults first; when training with progressive resizing,
# override them with the settings of the selected training step.
init_config()
if TRAIN_MODE and PROG_IMG_RESIZING:
  adjust_config(TRAIN_STEP)

In [None]:
# Working directory prefix: empty (current directory) unless Google Drive is mounted.
root_dir = ''
if MOUNT_GDRIVE:
  # imported here so the Colab module is only needed when the drive is actually mounted
  from google.colab import drive
  root_dir = '/content/Drive/'
  drive.mount(root_dir)
  # models, data and results are stored below this project folder on the drive
  root_dir += 'MyDrive/VideoPredictionProject/'

Mounted at /content/Drive/


In [None]:
def create_dir(dir):
  """Create directory `dir` together with missing parents; do nothing if it already exists."""
  os.makedirs(name=dir, exist_ok=True)

def create_dir_and_subdir(dir):
  """
  Create `dir` and, inside it, a sub-directory named after the lower-cased
  config['dataset_name'].

  Returns the path of the sub-directory.
  """
  subdir = os.path.join(dir, config['dataset_name'].lower())
  for path in (dir, subdir):
    create_dir(path)
  return subdir

In [None]:
# Per-dataset output folders below root_dir (created on demand):
# checkpoints, dataset files, and plots/metrics/animations respectively.
save_dir = create_dir_and_subdir(os.path.join(root_dir, 'models'))
data_dir = create_dir_and_subdir(os.path.join(root_dir, 'data'))
res_dir = create_dir_and_subdir(os.path.join(root_dir, 'results'))

In [None]:
# Device used for the model and every batch: the GPU when PyTorch can see one, else the CPU.
avDev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(avDev)

**Network implementation**

In [None]:
# Location Dependent Convolution
# Source: https://github.com/AIS-Bonn/LocDepVideoPrediction/blob/master/vlnOrig.ipynb

class LocationAwareConv2d(torch.nn.Conv2d):
  """
  Conv2d whose output additionally receives a learned, location-dependent bias.

  The bias is `locationBias * locationEncode` summed over the last axis, where
  `locationBias` is a trainable (h, w, 3) parameter and `locationEncode` is a fixed
  (h, w, 3) encoding: channel 0 holds the normalized column coordinate, channel 1 the
  normalized row coordinate and channel 2 is constant 1 (channels 0 and 1 are also
  constant 1 when `gradient` is False).

  param locationAware: add the location-dependent bias when True; plain convolution otherwise.
  param gradient: fill the coordinate channels with a 0..1 ramp instead of ones.
  param in_size: (height, width) of the output feature map; the convolution result is
                 bilinearly resized to this size whenever it differs.
  The remaining parameters are forwarded unchanged to torch.nn.Conv2d.

  NOTE: earlier revisions filled the column channel with a loop over `range(h)`, which
  raised IndexError for h > w and left part of the channel at 1 for h < w. Checkpoints
  trained on non-square inputs with that revision saw a slightly different encoding.
  """
  def __init__(self, locationAware, gradient, in_size, in_channels, out_channels,
               kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
    h, w = in_size
    if locationAware:
      self.locationBias = torch.nn.Parameter(torch.zeros(h, w, 3))
      encode = torch.ones(h, w, 3)
      if gradient:
        # Each axis gets its own ramp, normalized by its own extent. The division is
        # done in float64 and cast on assignment, matching the original i/float(h-1).
        cols = torch.arange(w, dtype=torch.float64) / max(w - 1, 1)
        rows = torch.arange(h, dtype=torch.float64) / max(h - 1, 1)
        encode[:, :, 0] = cols.view(1, w)
        encode[:, :, 1] = rows.view(h, 1)
      # A non-persistent buffer follows the module across devices without adding a
      # key to the state_dict, so existing checkpoints still load.
      self.register_buffer('locationEncode', encode, persistent=False)

    self.up = torch.nn.Upsample(size=(h, w), mode='bilinear', align_corners=False)
    self.h = h
    self.w = w
    self.locationAware = locationAware

  def forward(self, inputs):
    convRes = super().forward(inputs)
    # resize if EITHER spatial dimension differs from the target size
    if convRes.shape[2] != self.h or convRes.shape[3] != self.w:
      convRes = self.up(convRes)
    if not self.locationAware:
      return convRes
    b = self.locationBias * self.locationEncode
    # each (h, w) slice broadcasts over the batch and channel axes
    return convRes + b[:,:,0] + b[:,:,1] + b[:,:,2]

In [None]:
# Convolutional GRU
# Source: https://github.com/happyjin/ConvGRU-pytorch/blob/master/convGRU.py

class ConvGRUCell(nn.Module):
  """
  Convolutional GRU cell that keeps its hidden state between calls.

  The hidden state lives in `self.h_cur`; call `init_hidden` before the first
  `forward` of every sequence, otherwise `forward` fails its assertion.

  param in_channels: number of channels of the input tensor.
  param hid_channels: number of channels of the hidden state (and of the output).
  param kernel_size: (height, width) of the convolution kernels; padding is half of it.
  param bias: whether the two convolutions carry a bias term.
  """
  def __init__(self, in_channels, hid_channels, kernel_size, bias=True):
    super().__init__()
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]
    self.padding = kernel_h // 2, kernel_w // 2
    self.hid_channels = hid_channels
    self.bias = bias
    self.h_cur = None

    # both convolutions read the input concatenated with the hidden state
    shared_kwargs = dict(in_channels=in_channels + hid_channels,
                         kernel_size=kernel_size,
                         padding=self.padding,
                         bias=self.bias)
    # one convolution yields both gates (first half: reset, second half: update)
    self.conv_gates = nn.Conv2d(out_channels=2*self.hid_channels, **shared_kwargs)
    # candidate hidden state
    self.conv_can = nn.Conv2d(out_channels=self.hid_channels, **shared_kwargs)

  def forward(self, input):
    """Advance the cell by one step and return the new hidden state."""
    assert self.h_cur is not None
    h_prev = self.h_cur.to(input.device)
    self.h_cur = h_prev

    gate_logits = self.conv_gates(torch.cat([input, h_prev], dim=1))
    reset_logits, update_logits = torch.split(gate_logits, self.hid_channels, dim=1)
    reset_gate = torch.sigmoid(reset_logits)
    update_gate = torch.sigmoid(update_logits)

    candidate_in = torch.cat([input, reset_gate*h_prev], dim=1)
    candidate = torch.tanh(self.conv_can(candidate_in))

    h_next = (1 - update_gate) * h_prev + update_gate * candidate
    self.h_cur = h_next
    return h_next

  def init_hidden(self, batch_size, in_size):
    """Reset the hidden state to zeros of shape (batch_size, hid_channels, in_size[0], in_size[1])."""
    height, width = in_size[0], in_size[1]
    self.h_cur = torch.zeros(batch_size, self.hid_channels, height, width)

In [None]:
def Conv1x1(in_channels: int, out_channels: int, stride: int=1):
  """Return a biased 1x1 convolution with the given stride."""
  return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True)

def Conv3x3(in_channels: int, out_channels: int):
  """Return a biased 3x3 convolution; padding 1 keeps the spatial size."""
  return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)

def ConvTrans2x2(in_channels: int, out_channels: int):
  """Return a biased 2x2 transposed convolution with stride 2 (doubles height and width)."""
  return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True)

def ReLU_BN(in_channels: int):
  """Return ReLU followed by BatchNorm2d over `in_channels` channels."""
  return nn.Sequential(nn.ReLU(), nn.BatchNorm2d(in_channels))

def Conv_PixelShuffle_Conv(in_channels: int, out_channels: int,
                           hid_channels: int=1024, upscale_factor: int=2):
  """
  Return an upsampling block: 3x3 conv -> PixelShuffle -> 3x3 conv.

  param in_channels: number of input channels.
  param out_channels: number of output channels.
  param hid_channels: channels produced by the first convolution (default 1024, the
                      value that used to be hard-coded); must be divisible by upscale_factor**2.
  param upscale_factor: spatial upscaling factor of the PixelShuffle layer (default 2).
  raises ValueError: if hid_channels is not divisible by upscale_factor**2.
  """
  # PixelShuffle turns upscale_factor**2 channels into one channel of larger spatial size
  shuffled_channels, leftover = divmod(hid_channels, upscale_factor**2)
  if leftover != 0:
    raise ValueError(
        f'hid_channels ({hid_channels}) must be divisible by '
        f'upscale_factor**2 ({upscale_factor**2})')
  return nn.Sequential(
            Conv3x3(in_channels, hid_channels),
            nn.PixelShuffle(upscale_factor),
            Conv3x3(shuffled_channels, out_channels)
         )

In [None]:
class BridgeBlock(nn.Module):
  """
  Recurrent block placed between an encoder stage and the decoder.

  The input runs through four parallel branches - a 1x1 convolution and three ConvGRU
  cells with 3x3, 5x5 and 7x7 kernels - and the branch outputs are concatenated along
  the channel axis, giving 4 * hid_channels output channels.
  """
  def __init__(self, in_size: int, in_channels: int,
               hid_channels: int, loc_aware: bool=False):
    """
    param in_size: spatial size of the incoming feature map. Although annotated as int,
                   it is used as a (height, width) pair by init_hidden and LocationAwareConv2d.
    param in_channels: number of input channels.
    param hid_channels: number of channels produced by each branch.
    param loc_aware: use a location-dependent 1x1 convolution instead of a plain one.
    """
    super().__init__()
    self._in_size, self._in_channels, self._hid_channels = in_size, in_channels, hid_channels

    self.conv1x1 = (
        LocationAwareConv2d(True, True, in_size, in_channels, hid_channels, 1)
        if loc_aware else Conv1x1(in_channels, hid_channels)
    )

    # registered as conv_gru_1..3 (state_dict keys) and collected for iteration
    self.gru_cells = []
    for idx, kernel in enumerate((3, 5, 7), start=1):
      cell = ConvGRUCell(in_channels, hid_channels, kernel_size=(kernel, kernel))
      setattr(self, f'conv_gru_{idx}', cell)
      self.gru_cells.append(cell)

  def forward(self, input: torch.Tensor):
    """Apply all branches to `input` and concatenate their outputs on the channel axis."""
    branch_outputs = [self.conv1x1(input)]
    branch_outputs.extend(cell(input) for cell in self.gru_cells)
    return torch.cat(branch_outputs, dim=1)

  def init_hidden(self, batch_size: int):
    """
    Reset the hidden state of every ConvGRU cell for a batch of the given size.
    """
    for cell in self.gru_cells:
      cell.init_hidden(batch_size, self._in_size)

  def insert_loc_dep_conv(self):
    """
    Replace the 1x1 branch by a freshly initialized 1x1 LocationAwareConv2d layer.
    """
    self.conv1x1 = LocationAwareConv2d(
        True, True, self._in_size, self._in_channels, self._hid_channels, 1)

In [None]:
class NextFrameNet(ResNet):
  """
  Predicts the next frame from the current frame and the recurrent state of its bridge blocks.

  Encoder: pretrained ResNet-18 without global pooling and the classifier.
  Bridge: one recurrent BridgeBlock per skip connection.
  Decoder: PixelShuffle upsampling stages followed by a transposed convolution.
  """
  def __init__(self, in_size: Tuple[int, int]=(224, 224),
               loc_aware: bool=False, freeze_encoder: bool=True):
    """
    param in_size: input image frame size as (height, width)
    param loc_aware: use location-dependent 1x1 convolutions in the two deepest bridge blocks
    param freeze_encoder: specifies whether to freeze the ResNet-18 layers or not
    """
    super().__init__(BasicBlock, [2,2,2,2])

    # Encoder: load the pretrained ResNet-18 weights, then drop the GAP and FC layers
    resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    self.load_state_dict(model_zoo.load_url(resnet18_url))
    del self.avgpool
    del self.fc

    if freeze_encoder:
      # only encoder parameters exist at this point
      for weight in self.parameters():
        weight.requires_grad = False

    h, w = in_size
    width = 64  # channels of the first encoder stage and of every bridge/decoder branch

    # Recurrent bridge blocks, one per skip connection (shallow to deep)
    self.bridge_block1 = BridgeBlock((h//2, w//2), width, width)
    self.bridge_block2 = BridgeBlock((h//4, w//4), width, width)
    self.bridge_block3 = BridgeBlock((h//8, w//8), width*2, width, loc_aware)
    self.bridge_block4 = BridgeBlock((h//16, w//16), width*4, width, loc_aware)

    self.bridge_blocks = [self.bridge_block1, self.bridge_block2,
                          self.bridge_block3, self.bridge_block4]

    # Decoder: the first stage reads the deepest encoder features (8*width channels)
    self.relu_1 = nn.ReLU()
    self.pixel_shuffle_block1 = Conv_PixelShuffle_Conv(width*8, width)

    # Later stages read a bridge output (4*width) concatenated with the previous stage (width)
    fused_channels = width*5
    for stage in (2, 3, 4):
      setattr(self, f'relu_bn_{stage}', ReLU_BN(fused_channels))
      setattr(self, f'pixel_shuffle_block{stage}',
              Conv_PixelShuffle_Conv(fused_channels, width))

    self.conv_trans_2x2 = ConvTrans2x2(fused_channels, 3)

  def forward(self, input: torch.Tensor):
    """Run one frame through encoder, bridge blocks and decoder; returns a 3-channel frame."""
    stem = self.conv1(input)
    pooled = self.maxpool(self.relu(self.bn1(stem)))

    enc1 = self.layer1(pooled)
    enc2 = self.layer2(enc1)
    enc3 = self.layer3(enc2)
    enc4 = self.layer4(enc3)

    # skip connections: raw stem output plus the first three residual stages
    skip0, skip1, skip2, skip3 = [
        bridge(features)
        for bridge, features in zip(self.bridge_blocks, (stem, enc1, enc2, enc3))
    ]

    decoded = self.pixel_shuffle_block1(self.relu_1(enc4))

    decoder_stages = (
        (skip3, self.relu_bn_2, self.pixel_shuffle_block2),
        (skip2, self.relu_bn_3, self.pixel_shuffle_block3),
        (skip1, self.relu_bn_4, self.pixel_shuffle_block4),
    )
    for skip, relu_bn, upsample in decoder_stages:
      decoded = upsample(relu_bn(torch.cat([skip, decoded], dim=1)))

    return self.conv_trans_2x2(torch.cat([skip0, decoded], dim=1))

  def init_hidden(self, batch_size: int):
    """
    Reset the hidden states of the ConvGRU cells in every bridge block.
    """
    for bridge in self.bridge_blocks:
      bridge.init_hidden(batch_size)

  def insert_loc_dep_conv(self):
    """
    Swap the plain 1x1 convolutions of the two deepest bridge blocks for
    location-aware ones (meant to be called after loading a pretrained model).
    """
    for bridge in (self.bridge_block3, self.bridge_block4):
      bridge.insert_loc_dep_conv()

In [None]:
class AutoregressiveModel(nn.Module):
  """
  Feeds a sequence of frames to NextFrameNet and predicts as many future frames.
  After the input frames are consumed, every further prediction is generated
  from the previously predicted frame.
  """

  def __init__(self, in_size: Tuple[int, int]=(224, 224),
               loc_aware: bool=False, freeze_encoder: bool=True):
    """
    param in_size: input image frame size
    param loc_aware: specifies whether to use location-dependent conv. or not
    param freeze_encoder: specifies whether to freeze the ResNet-18 layers or not
    """
    super().__init__()
    self.net = NextFrameNet(in_size, loc_aware, freeze_encoder)

  def forward(self, input: torch.Tensor):
    """
    :param input: input sequence of frames (BxFxCxHxW).
    Returns the F predicted frames stacked along dimension 1 (BxFxCxHxW).
    """
    batch_size, n_frames = input.size(0), input.size(1)
    self.net.init_hidden(batch_size=batch_size)

    # warm-up: only the prediction made after the last input frame is kept
    for t in range(n_frames):
      prediction = self.net(input[:, t])
    predictions = [prediction]

    # roll-out: feed each prediction back in until F frames are collected
    while len(predictions) < n_frames:
      prediction = self.net(prediction)
      predictions.append(prediction)

    return torch.stack(predictions, dim=1)

  def insert_loc_dep_conv(self):
    """
    Replace simple 1x1 conv. layers by location aware conv.
    (after loading pretrained model).
    """
    self.net.insert_loc_dep_conv()

**Dataset**

In [None]:
class MovingMnist(data.Dataset):
  """
  Custom Dataset class for downloading and working with MovingMnist dataset.
  """
  def __init__(self, data_dir: str, download: bool=True, train: bool=True,
               train_split: int=9000, transform: Optional[transforms.Compose]=None,
               convert_rgb: bool=True, n_frames: int=3):
    """
    param data_dir: directory path to store data files.
    param download: specifies whether to download the .gz file or use the existing one.
    param train: train or test dataset.
    param train_split: number of samples to be used for training.
    param transform: transformation to be applied to each frame.
    param convert_rgb: specifies whether to convert the frames to RGB by duplicating the channel.
    param n_frames: number of input/target frames in each sample.
    """
    self.train = train
    self.transform = transform
    self.convert_rgb = convert_rgb
    self.n_frames = n_frames

    npy_file_name = 'MovingMnist.npy'
    npy_file = os.path.join(data_dir, npy_file_name)
    if download:
      url = 'https://github.com/tychovdo/MovingMNIST/raw/master/mnist_test_seq.npy.gz'
      gz_file = self.__download(url, data_dir)
      # unzip the downloaded .gz data file
      with open(npy_file, 'wb') as out_f:
        with gzip.GzipFile(gz_file) as zip_f:
          out_f.write(zip_f.read())
    else:
      assert os.path.isfile(npy_file), f'{npy_file} data file not found!'

    data_npy = np.load(npy_file)
    # swap the first two axes so that axis 0 indexes samples and axis 1 indexes frames
    data_npy = data_npy.swapaxes(0, 1)

    assert train_split != 0
    # the first `train_split` samples are the train set, the rest the test set
    if self.train:
      self.train_data = torch.from_numpy(data_npy[:train_split])
    else:
      self.test_data = torch.from_numpy(data_npy[train_split:])

  def __len__(self) -> int:
    """
    Returns the dataset length (train or test).
    """
    if self.train:
      return len(self.train_data)
    else:
      return len(self.test_data)

  def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns a tuple of input and target frame sequences at the given sample index.

    Train samples start at a random frame when the global RANDOMIZE flag is set
    (frame 0 otherwise); test samples always start at frame 0.
    """
    n_frames = self.n_frames
    if self.train:
      # randint is inclusive, so the last valid start still leaves 2*n_frames frames
      fidx = random.randint(0, self.train_data.size(1)-2*n_frames) if RANDOMIZE else 0
      input, target = self.train_data[index, fidx:fidx+n_frames], self.train_data[index, fidx+n_frames:fidx+2*n_frames]
    else:
      input, target = self.test_data[index, :n_frames], self.test_data[index, n_frames:2*n_frames]

    input = self.__transform(input)
    target = self.__transform(target)
    return input, target

  def __transform(self, data: torch.Tensor) -> torch.Tensor:
    """
    Transforms the given sequence of frames.

    Every frame is converted to a PIL image (optionally to RGB) and passed through
    `self.transform`, or through ToTensor when no transform was given. The transformed
    frames are stacked along a new leading dimension.
    """
    frame_seq = []
    for i in range(data.size(0)):
      frame = data[i].numpy()
      frame = Image.fromarray(frame)
      if self.convert_rgb:
        frame = frame.convert('RGB')
      if self.transform is not None:
        frame = self.transform(frame)
      else:
        frame = transforms.ToTensor()(frame)
      frame_seq.append(frame)
    frame_seq = torch.stack(frame_seq, dim=0)
    return frame_seq

  def __download(self, url: str, data_dir: str) -> str:
    """
    Downloads the data file from the specified url and stores it in data_dir.

    Returns the path of the downloaded file. An already existing file is reused
    without contacting the server.
    """
    filename = url.rpartition('/')[2]
    file_path = os.path.join(data_dir, filename)
    if os.path.exists(file_path):
      return file_path
    create_dir(data_dir)
    # Download into a temporary sibling file and rename it afterwards, so an
    # interrupted transfer never leaves a truncated archive that a later run
    # would take for a complete one. The response is closed by the `with` block.
    partial_path = file_path + '.part'
    with urllib.request.urlopen(url) as response, open(partial_path, 'wb') as f:
      f.write(response.read())
    os.replace(partial_path, file_path)
    return file_path

In [None]:
class UnNormalize(object):
  """
  Inverse of a per-channel normalization: channel * std + mean.

  The call works in place on a sequence of frames (frames along dimension 0,
  channels along dimension 1) and returns the same tensor object.
  """
  def __init__(self, mean, std):
    self.mean, self.std = mean, std

  @torch.no_grad()
  def __call__(self, tensor):
    n_frames = tensor.size(0)
    for frame_idx in range(n_frames):
      channels = tensor[frame_idx]
      # zip stops at the shortest of (channels, mean, std)
      for channel, shift, scale in zip(channels, self.mean, self.std):
        channel.mul_(scale).add_(shift)
    return tensor

# Per-channel statistics used to normalize the frames (presumably the ImageNet values
# expected by the pretrained ResNet-18 encoder - confirm if the encoder changes).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Frame preprocessing: resize to the configured input size, convert to a tensor, normalize.
# config['in_size'] is read once here, so this cell has to run after the config cells.
transform = transforms.Compose([
                                  transforms.Resize(config['in_size']),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean, std)
                               ])
# Inverse of the normalization above; applied before frames are visualized.
unorm = UnNormalize(mean, std)

In [None]:
def get_datasets(dataset_name):
  """
  Build the train and test datasets for `dataset_name` ('Mnist' or 'Robot').

  Returns the tuple (train_data, test_data). The 'Robot' branch relies on
  VideoFrameDataset, train_urls/train_ivals and test_urls/test_ivals, which have
  to be defined elsewhere in the notebook before this function is called.
  """
  train_split = config['train_split']
  n_frames = config['n_frames']
  if dataset_name == 'Mnist':
    shared = dict(train_split=train_split, transform=transform,
                  convert_rgb=True, n_frames=n_frames)
    # the train set downloads/extracts the file; the test set reuses it
    train_data = MovingMnist(data_dir, download=True, train=True, **shared)
    test_data = MovingMnist(data_dir, download=False, train=False, **shared)
  else:
    assert dataset_name == 'Robot'
    shared = dict(data_dir=data_dir, transform=transform, sample_step=2, n_frames=n_frames)

    train_skip_step = 7 # subsample to get around 10700 samples
    train_data = VideoFrameDataset(train_urls, train_ivals, train=True,
                                   skip_step=train_skip_step, **shared)

    test_skip_step = 3 # subsample to get around 1700 samples
    test_data = VideoFrameDataset(test_urls, test_ivals, train=False,
                                  skip_step=test_skip_step, **shared)
  return train_data, test_data

In [None]:
# Build the datasets and report how long it took.
start_time = time.time()
train_data, test_data = get_datasets(config['dataset_name'])
end_time = time.time()
print('Time elapsed: {} mins'.format(round((end_time-start_time)/60, 1)))

dataset_size = len(train_data)
print('Train dataset size:', dataset_size)
print('Test dataset size:', len(test_data))

# Hold out 20% of the training data for validation. Both subsets wrap the same
# train-mode dataset object, so dataset-level settings apply to both of them.
trainset_size = int(dataset_size*0.8)
valset_size = dataset_size - trainset_size
trainset, valset = random_split(train_data, [trainset_size, valset_size])

train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=config['batch_size'],
                                           shuffle=True, num_workers=config['n_workers'])
val_loader = torch.utils.data.DataLoader(dataset=valset, batch_size=config['batch_size'],
                                         shuffle=False, num_workers=config['n_workers'])

# Keyed by phase name, as expected by train_model.
dataloaders = {
    'train': train_loader,
    'val': val_loader
}

**Loss functions**

In [None]:
# Source: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/ssim.html

def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int,
         max_val: float = 1.0, eps: float = 1e-12) -> torch.Tensor:
    """
    Compute the structural similarity (SSIM) index map between two image batches.

    param img1: first image batch of shape BxCxHxW.
    param img2: second image batch with the same shape as img1.
    param window_size: side length of the square gaussian window (sigma 1.5).
    param max_val: dynamic range of the images; must be a float.
    param eps: small constant added to the denominator for numerical stability.
    Returns the SSIM map, computed per channel with the same shape as the inputs.
    Raises TypeError for wrong argument types and ValueError for wrong shapes.
    """
    # argument validation: types first, then shapes
    if isinstance(img1, torch.Tensor) is False:
        raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                        .format(type(img1)))
    if isinstance(img2, torch.Tensor) is False:
        raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                        .format(type(img2)))
    if isinstance(max_val, float) is False:
        raise TypeError(f"Input max_val type is not a float. Got {type(max_val)}")
    if len(img1.shape) != 4:
        raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                         .format(img1.shape))
    if len(img2.shape) != 4:
        raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                         .format(img2.shape))
    if img1.shape != img2.shape:
        raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}"
                         .format(img1.shape, img2.shape))

    # gaussian window shared by all local statistics
    window = (window_size, window_size)
    kernel: torch.Tensor = get_gaussian_kernel2d(window, (1.5, 1.5)).unsqueeze(0)

    def local_mean(image: torch.Tensor) -> torch.Tensor:
        return filter2D(image, kernel)

    # stabilizing constants
    C1: float = (0.01 * max_val) ** 2
    C2: float = (0.03 * max_val) ** 2

    # local means per channel
    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 ** 2, mu2 ** 2, mu1 * mu2

    # local variances and covariance per channel
    sigma1_sq = local_mean(img1 ** 2) - mu1_sq
    sigma2_sq = local_mean(img2 ** 2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # similarity index map
    numerator = (2. * mu1_mu2 + C1) * (2. * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return numerator / (denominator + eps)

def ssim_loss(img1: torch.Tensor, img2: torch.Tensor, window_size: int,
              max_val: float = 1.0, eps: float = 1e-12, reduction: str = 'mean') -> torch.Tensor:
    """
    Structural dissimilarity loss: clamp((1 - SSIM) / 2, 0, 1), optionally reduced.

    param img1: first image batch of shape BxCxHxW.
    param img2: second image batch with the same shape as img1.
    param window_size: side length of the gaussian window used by ssim.
    param max_val: dynamic range of the images.
    param eps: small constant for numerical stability, forwarded to ssim.
    param reduction: 'mean', 'sum' or 'none' (return the per-pixel loss map).
    Raises ValueError for any other reduction; previously such a value silently
    behaved like 'none'.
    """
    # reject typos up front instead of silently returning the unreduced map
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            "Invalid reduction option. Expected 'mean', 'sum' or 'none'. Got: {}"
            .format(reduction))

    # compute the ssim map
    ssim_map: torch.Tensor = ssim(img1, img2, window_size, max_val, eps)

    # compute and reduce the loss
    loss = torch.clamp((1. - ssim_map) / 2, min=0, max=1)

    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    return loss

class SSIMLoss(nn.Module):
    """
    Module wrapper around ssim_loss that stores its settings.

    param window_size: side length of the gaussian window.
    param max_val: dynamic range of the images.
    param eps: small constant for numerical stability.
    param reduction: reduction applied to the loss map ('mean', 'sum' or 'none').
    """
    def __init__(self, window_size: int, max_val: float = 1.0,
                 eps: float = 1e-12, reduction: str = 'mean') -> None:
        super().__init__()
        self.window_size, self.max_val = window_size, max_val
        self.eps, self.reduction = eps, reduction

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Return the SSIM-based loss between img1 and img2 (both BxCxHxW)."""
        return ssim_loss(img1, img2,
                         window_size=self.window_size,
                         max_val=self.max_val,
                         eps=self.eps,
                         reduction=self.reduction)

In [None]:
# Shared criterion instances used by MSE_DSSIM_Loss.
SSIMloss = SSIMLoss(window_size=config['ssim_win_size'])
MSEloss = nn.MSELoss()

def MSE_DSSIM_Loss(pred, target):
  """
  Weighted sum of the frame-averaged MSE and DSSIM losses.

  param pred: predicted sequence of frames of size BxFxCxHxW
  param target: target sequence of frames of size BxFxCxHxW
  Returns (loss, mse_loss, dssim_loss) with
  loss = alpha * mse_loss + (1 - alpha) * dssim_loss and alpha = config['mse_gain'].
  """
  alpha = config['mse_gain']
  n_frames = pred.size(1)
  frame_pairs = [(pred[:, t], target[:, t]) for t in range(n_frames)]

  # average each criterion over the frame axis
  mse_loss = sum((MSEloss(p, q) for p, q in frame_pairs), 0.0) / n_frames
  dssim_loss = sum((SSIMloss(p, q) for p, q in frame_pairs), 0.0) / n_frames

  loss = alpha * mse_loss + (1 - alpha) * dssim_loss
  return loss, mse_loss, dssim_loss

**Training**

In [None]:
def save_checkpoint(model, epoch):
  """Write the model's state_dict to '<save_dir>/model_<epoch>.pth'."""
  weights = model.state_dict()
  torch.save(weights, f'{save_dir}/model_{epoch}.pth')

def load_from_checkpoint(model, epoch):
  """
  Load the weights saved for `epoch` into `model` (tensors are mapped to the CPU)
  and switch the model to evaluation mode.
  """
  weights = torch.load(f'{save_dir}/model_{epoch}.pth', map_location='cpu')
  model.load_state_dict(weights)
  model.eval()

In [None]:
def plot_losses(train_losses, val_losses, start_epoch, end_epoch):
  """
  Plot train and validation loss curves over the epochs start_epoch..end_epoch,
  save the figure to '<res_dir>/losses_<end_epoch>.png' and show it.

  Both loss sequences must hold one value per epoch in that range.
  """
  epochs = np.linspace(start_epoch, end_epoch, end_epoch - start_epoch + 1)
  curves = (('train', train_losses), ('val', val_losses))
  for label, values in curves:
    plt.plot(epochs, values, label=label)

  # epochs are whole numbers, so only allow integer ticks on the x axis
  axis = plt.gca()
  axis.xaxis.get_major_locator().set_params(integer=True)
  plt.xlabel("#epochs")
  plt.ylabel('Loss')
  plt.legend()

  plt.savefig(f'{res_dir}/losses_{end_epoch}.png')
  plt.show()
  plt.clf()

def plot_combined_losses():
  """
  Load the loss curves saved by every progressive-resizing step and plot them
  as one continuous curve from epoch 1 to the last epoch.

  Expects '<res_dir>/train_<start>_<end>.npy' and '<res_dir>/val_<start>_<end>.npy'
  for each step listed in config['n_epochs_list'].
  """
  train_parts, val_parts = [], []
  end_epoch = 0
  for step_epochs in config['n_epochs_list']:
    # epoch ranges of consecutive steps are contiguous
    start_epoch, end_epoch = end_epoch + 1, end_epoch + step_epochs
    train_parts.append(np.load(f'{res_dir}/train_{start_epoch}_{end_epoch}.npy'))
    val_parts.append(np.load(f'{res_dir}/val_{start_epoch}_{end_epoch}.npy'))

  plot_losses(np.concatenate(train_parts), np.concatenate(val_parts), 1, end_epoch)

In [None]:
def train_model(model, criterion, optimizer, scheduler, dataloaders, start_epoch):
  """
  Train and validate `model` for config['n_epochs'] epochs, starting at `start_epoch`.

  param model: network mapping an input frame sequence to a predicted sequence.
  param criterion: callable returning (loss, mse, dssim) for (prediction, target).
  param optimizer: optimizer updating the model parameters.
  param scheduler: optional learning-rate scheduler stepped once per epoch (may be None).
  param dataloaders: dict with the data loaders under the keys 'train' and 'val'.
  param start_epoch: number of the first epoch (used for logging and checkpoint names).
  Returns (train_losses, val_losses), one average loss per epoch each.

  Every config['save_step'] epochs a checkpoint is saved and the loss curves are plotted.
  """
  # The dataset reads the module-level RANDOMIZE flag. Without this declaration the
  # assignments below would only create a local variable and have no effect on it.
  global RANDOMIZE

  n_epochs = config['n_epochs'] + start_epoch - 1
  save_step = config['save_step']
  n_frames = config['n_frames']

  train_losses = []
  val_losses = []
  for epoch in range(start_epoch, n_epochs+1):
    start_time = time.time()
    print('Epoch {}/{}'.format(epoch, n_epochs))
    print('-' * 10)

    for phase in ['train', 'val']:
      if phase == 'train':
        model.train()
        RANDOMIZE = True
      else:
        model.eval()
        RANDOMIZE = False

      running_loss = running_mse = running_dssim = 0.0
      dataset_size = 0
      dataloader = dataloaders[phase]
      for batch1, batch2 in dataloader:
        batch1 = batch1.to(avDev)
        batch2 = batch2.to(avDev)

        optimizer.zero_grad()
        # gradients are only tracked in the training phase
        with torch.set_grad_enabled(phase == 'train'):
          preds = model(batch1.float())
          loss, mse, dssim = criterion(preds.float(), batch2.float())

          if phase == 'train':
            loss.backward()
            optimizer.step()

        # weight the batch averages by the batch size (the last batch may be smaller)
        running_loss += loss.item() * batch1.size(0)
        running_mse += mse.item() * batch1.size(0)
        running_dssim += dssim.item() * batch1.size(0)
        dataset_size += batch1.size(0)

      epoch_loss = running_loss/dataset_size
      print('{} Loss: {:.4f}'.format(phase, epoch_loss))
      print('MSE : {:.4f}'.format(running_mse/dataset_size))
      print('DSSIM: {:.4f}'.format(running_dssim/dataset_size))

      if phase == 'train':
        train_losses.append(epoch_loss)
      else:
        val_losses.append(epoch_loss)

    if scheduler is not None:
      scheduler.step()

    if epoch % save_step == 0:
      save_checkpoint(model, epoch)
      plot_losses(train_losses, val_losses, start_epoch, epoch)
    end_time = time.time()
    print('Time elapsed: {} mins'.format(round((end_time-start_time)/60, 1)))
  return train_losses, val_losses

In [None]:
# create/load the model

load_loc_aware = config['loc_aware']
in_size = config['in_size']

if TRAIN_MODE:
  # resume from the checkpoint written at the end of the previous epoch (0 = start fresh)
  load_epoch = config['start_epoch'] - 1
  if PROG_IMG_RESIZING:
    # build the model without location-aware convs first; they are inserted
    # below, after the checkpoint has been loaded
    load_loc_aware = False
else:
  load_epoch = TEST_EPOCH

model = AutoregressiveModel(in_size, load_loc_aware, freeze_encoder=False)
if load_epoch > 0:
  # load_from_checkpoint reads the file itself, so no separate torch.load is needed
  load_from_checkpoint(model, load_epoch)
  if TRAIN_MODE and PROG_IMG_RESIZING and config['loc_aware']:
    model.insert_loc_dep_conv()
model = model.to(avDev)

In [None]:
# Train the model (train mode only), then plot and persist the loss curves of this run.
if TRAIN_MODE:
  Optimizer = config['optimizer']
  optimizer = Optimizer(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
  train_losses, val_losses = train_model(model, MSE_DSSIM_Loss, optimizer, scheduler=None,
                                         dataloaders=dataloaders, start_epoch=config['start_epoch'])
  # plot final learning curves
  start_epoch = config['start_epoch']
  end_epoch = config['n_epochs'] + start_epoch - 1
  
  plot_losses(train_losses, val_losses, start_epoch, end_epoch)

  # np.save appends '.npy'; plot_combined_losses reads these files back
  np.save(f'{res_dir}/train_{start_epoch}_{end_epoch}', train_losses)
  np.save(f'{res_dir}/val_{start_epoch}_{end_epoch}', val_losses)

  # after the last progressive-resizing step, plot the curves of all steps together
  if PROG_IMG_RESIZING and TRAIN_STEP == len(config['n_epochs_list']):
    plot_combined_losses()

**Testing**

In [None]:
# Helper functions for reporting the performance metrics
@torch.no_grad()
def test_metrics(model, data_loader):
  """
  Evaluate `model` on every batch of `data_loader`.

  Returns (mse, dssim): the sample-weighted averages of the two loss terms
  reported by MSE_DSSIM_Loss, as Python floats.
  """
  model.eval()
  mse_total = dssim_total = 0.0
  n_samples = 0.0
  for input, target in data_loader:
    input, target = input.to(avDev), target.to(avDev)
    batch_size = input.size(0)
    _, mse, dssim = MSE_DSSIM_Loss(model(input.float()), target)
    # weight the batch averages by the batch size
    mse_total += mse.item() * batch_size
    dssim_total += dssim.item() * batch_size
    n_samples += batch_size

  return float(mse_total/n_samples), float(dssim_total/n_samples)

def save_test_metrics(model, data_loader, dataset_name, fname, mode='w'):
  """
  Compute MSE and DSSIM of `model` on `data_loader`, write them to `fname` and print them.

  param model: model to evaluate.
  param data_loader: loader yielding (input, target) batches.
  param dataset_name: label used in the report lines (e.g. 'train', 'val', 'test').
  param fname: path of the text file receiving the metrics.
  param mode: file open mode; 'w' overwrites the file, 'a' appends to it.
  """
  mse, dssim = test_metrics(model, data_loader)
  mse = round(mse, 4)
  dssim = round(dssim, 4)

  # the context manager closes the file even if a write fails
  with open(fname, mode) as f_test_metrics:
    f_test_metrics.write(f'MSE on {dataset_name} data:{mse}\n')
    f_test_metrics.write(f'DSSIM on {dataset_name} data:{dssim}\n\n')

  print(f'MSE on {dataset_name} data:', mse)
  print(f'DSSIM on {dataset_name} data:', dssim)

In [None]:
# Helper functions for plotting the predicted frames
def save_animation(seq1, seq2, fname='anim.gif'):
  """
  Save the frames of `seq1` followed by the frames of `seq2` as an animated GIF.

  param seq1: first frame sequence; size(0) is the number of frames and size(1) the
              number of channels. Its frames are drawn with a red border.
  param seq2: second frame sequence of the same size; drawn with a green border.
  param fname: path of the GIF file to write.

  The figure used for rendering is closed afterwards, so repeated calls do not
  accumulate open matplotlib figures.
  """
  assert seq1.size() == seq2.size(), "Input and target sizes must match!"
  n_frames = seq1.size(0)
  n_channels = seq1.size(1)
  to_pil_img = transforms.ToPILImage()
  fig = plt.figure()
  try:
    ims = []
    # same rendering for both sequences; only the border colour differs
    for seq, border_color in ((seq1, 'red'), (seq2, 'green')):
      for i in range(n_frames):
        frame = seq[i].cpu()
        plt.axis('off')
        frame = to_pil_img(frame)
        if n_channels != 3:
          frame = frame.convert('RGB')
        frame_with_border = ImageOps.expand(frame, border=5, fill=border_color)
        im = plt.imshow(frame_with_border)
        ims.append([im])

    anim = animation.ArtistAnimation(fig, ims, interval=300, repeat_delay=1000)
    anim.save(fname, writer=animation.PillowWriter(fps=60))
  finally:
    plt.close(fig)


def save_grid(frame_list, fname):
  """
  Draw the given frames side by side in a single row, save the figure to `fname`
  and show it.

  param frame_list: sequence of frames; shape[1] and shape[2] of the first frame
                    are used as height and width to size the figure.
  param fname: path of the image file to write.
  """
  n_frames = len(frame_list)
  first_frame = frame_list[0]
  aspect = first_frame.shape[2] / first_frame.shape[1]
  fig_h = 10
  fig_w = n_frames * fig_h * aspect

  fig = plt.figure(figsize=(fig_w, fig_h))
  grid = ImageGrid(fig, 111, nrows_ncols=(1, n_frames), axes_pad=0.1)
  to_pil_img = transforms.ToPILImage()
  for ax, img in zip(grid, frame_list):
    ax.axis('off')
    ax.imshow(to_pil_img(img))
  plt.savefig(fname)
  plt.show()

@torch.no_grad()
def save_predictions(model, input, target, dataset_name, id):
  """
  Predict the continuation of the first sample in `input` and save animations and
  image grids comparing the ground truth with the prediction.

  param model: model producing the predicted frame sequence.
  param input: batch of input sequences; only sample 0 is visualized.
  param target: batch of target sequences; only sample 0 is visualized.
  param dataset_name: label used in the output file names.
  param id: sample identifier used in the output file names.
  """
  model.eval()
  preds = model(input.float())

  # unorm works in place, so the caller's tensors are cloned first
  clip1 = unorm(input[0].clone())
  clip2 = unorm(target[0].clone())
  # post-processing: keep the prediction inside the valid image range
  pred2 = unorm(preds[0]).clamp(0, 1)

  save_animation(clip1, clip2, fname=f'{res_dir}/original_{dataset_name}_{id}.gif')
  save_animation(clip1, pred2, fname=f'{res_dir}/predicted_{dataset_name}_{id}.gif')

  original_frames = torch.cat([clip1.cpu(), clip2.cpu()], dim=0)
  predicted_frames = torch.cat([clip1.cpu(), pred2.cpu()], dim=0)
  save_grid(original_frames, fname=f'{res_dir}/original_{dataset_name}_{id}.png')
  save_grid(predicted_frames, fname=f'{res_dir}/predicted_{dataset_name}_{id}.png')

In [None]:
# save generations on train-data
# Evaluate with a fixed start frame so the reported metrics are reproducible.
RANDOMIZE = False
# The metrics file is created (mode 'w') for the train split and appended to afterwards.
f_test_metrics = f'{res_dir}/test_metrics.txt'
save_test_metrics(model, train_loader, 'train', fname=f_test_metrics, mode='w')
save_test_metrics(model, val_loader, 'val', fname=f_test_metrics, mode='a')

# Visualize predictions for up to 10 samples of one training batch.
batch1, batch2 = next(iter(train_loader))
batch1 = batch1.to(avDev)
batch2 = batch2.to(avDev)
for i in range(min(10, batch1.size(0))):
  # slicing with i:i+1 keeps the batch dimension expected by the model
  save_predictions(model, batch1[i:i+1], batch2[i:i+1], 'train', i)

In [None]:
# save generations on test-data
# Shuffled loader over the held-out test set; metrics are appended to the file
# named by f_test_metrics, which is defined in the train-data evaluation cell.
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=config['batch_size'],
                                          shuffle=True, num_workers=config['n_workers'])
save_test_metrics(model, test_loader, 'test', fname=f_test_metrics, mode='a')

# Visualize predictions for up to 10 samples of one test batch.
batch1, batch2 = next(iter(test_loader))
batch1 = batch1.to(avDev)
batch2 = batch2.to(avDev)
for i in range(min(10, batch1.size(0))):
  # slicing with i:i+1 keeps the batch dimension expected by the model
  save_predictions(model, batch1[i:i+1], batch2[i:i+1], 'test', i)