In [1]:
import torch
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler

In [4]:
class UNet(nn.Module):
    """U-Net style encoder/decoder with skip connections.

    Encoder: six stages of two stacked convolutions followed by a leaky ReLU
    (slope 0.1); the first five stages are each followed by a 2x average
    pool, so the bottleneck runs at 1/32 of the input resolution.

    Decoder: five stages, each doing a 2x bilinear upsample, a convolution,
    a channel-wise concatenation with the encoder feature map of the same
    resolution, a second convolution and a leaky ReLU (slope 0.1).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.

    NOTE(review): there is no non-linearity between the two convolutions of
    a pair (e.g. conv1 -> conv2); confirm this is intended.
    """

    def __init__(self,in_channels,out_channels):
        super(UNet,self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        #UNet Arch:
        #1)Encoder
        # Conv2d arguments are (in, out, kernel, stride, padding); every
        # kernel/padding pair below keeps the spatial size unchanged.
        self.conv1  = nn.Conv2d(in_channels,32,7,1,3)
        self.conv2  = nn.Conv2d(32,32,7,1,3)

        self.conv3  = nn.Conv2d(32,64,5,1,2)
        self.conv4  = nn.Conv2d(64,64,5,1,2)

        self.conv5  = nn.Conv2d(64,128,3,1,1)
        self.conv6  = nn.Conv2d(128,128,3,1,1)

        self.conv7  = nn.Conv2d(128,256,3,1,1)
        self.conv8  = nn.Conv2d(256,256,3,1,1)

        self.conv9  = nn.Conv2d(256,512,3,1,1)
        self.conv10 = nn.Conv2d(512,512,3,1,1)

        self.conv11 = nn.Conv2d(512,512,3,1,1)
        self.conv12 = nn.Conv2d(512,512,3,1,1)

        #2)Decoder
        # align_corners=False is the library default; it is spelled out to
        # make the interpolation convention explicit.
        self.up_sample = nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False)

        # In each pair the first conv runs on the upsampled features and the
        # second runs on [decoder features, skip features], hence 2*C inputs.
        self.conv13 = nn.Conv2d(512,512,3,1,1)
        self.conv14 = nn.Conv2d(2*512,512,3,1,1)

        self.conv15 = nn.Conv2d(512,256,3,1,1)
        self.conv16 = nn.Conv2d(2*256,256,3,1,1)

        self.conv17 = nn.Conv2d(256,128,3,1,1)
        self.conv18 = nn.Conv2d(2*128,128,3,1,1)

        self.conv19 = nn.Conv2d(128,64,3,1,1)
        self.conv20 = nn.Conv2d(2*64,64,3,1,1)

        self.conv21 = nn.Conv2d(64,32,3,1,1)
        self.conv22 = nn.Conv2d(2*32,out_channels,3,1,1)

    def forward(self,images):
        """Run the network on a batch of images.

        Args:
            images: tensor of shape (N, in_channels, H, W). H and W must be
                multiples of 32 so that the five pool/upsample round trips
                reproduce the skip-connection sizes exactly.

        Returns:
            Tensor of shape (N, out_channels, H, W).

        Raises:
            ValueError: if H or W is not a multiple of 32.
        """
        height, width = images.shape[-2:]
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(
                "UNet expects height and width to be multiples of 32, "
                "got {}x{}".format(height, width)
            )

        # Encoder: keep the pre-pooling activations for the skip connections.
        enc1 = F.leaky_relu(self.conv2(self.conv1(images)),negative_slope=0.1)                 # 32  @ H
        enc2 = F.leaky_relu(self.conv4(self.conv3(F.avg_pool2d(enc1,2))),negative_slope=0.1)   # 64  @ H/2
        enc3 = F.leaky_relu(self.conv6(self.conv5(F.avg_pool2d(enc2,2))),negative_slope=0.1)   # 128 @ H/4
        enc4 = F.leaky_relu(self.conv8(self.conv7(F.avg_pool2d(enc3,2))),negative_slope=0.1)   # 256 @ H/8
        enc5 = F.leaky_relu(self.conv10(self.conv9(F.avg_pool2d(enc4,2))),negative_slope=0.1)  # 512 @ H/16
        # Bottleneck: not pooled again, since there are only five upsampling
        # steps to undo the five pools above.
        bottleneck = F.leaky_relu(self.conv12(self.conv11(F.avg_pool2d(enc5,2))),negative_slope=0.1)  # 512 @ H/32

        # Decoder: each stage upsamples the PREVIOUS decoder output and fuses
        # it with the encoder map of matching resolution along dim 1.
        up = self.up_sample(bottleneck)
        dec5 = F.leaky_relu(self.conv14(torch.cat((self.conv13(up),enc5),1)),negative_slope=0.1)  # 512 @ H/16

        up = self.up_sample(dec5)
        dec4 = F.leaky_relu(self.conv16(torch.cat((self.conv15(up),enc4),1)),negative_slope=0.1)  # 256 @ H/8

        up = self.up_sample(dec4)
        dec3 = F.leaky_relu(self.conv18(torch.cat((self.conv17(up),enc3),1)),negative_slope=0.1)  # 128 @ H/4

        up = self.up_sample(dec3)
        dec2 = F.leaky_relu(self.conv20(torch.cat((self.conv19(up),enc2),1)),negative_slope=0.1)  # 64  @ H/2

        up = self.up_sample(dec2)
        dec1 = F.leaky_relu(self.conv22(torch.cat((self.conv21(up),enc1),1)),negative_slope=0.1)  # out_channels @ H
        return dec1

# Network Details:
![Architecture](img/Arch.png)
## <font color='red' >Flow Computation Network:</font>
* U-Net Architecture (in_channels = 6, out_channels = 4)
* input I0 , I1
* output F0->1 , F1->0
* Takes two input images I0 and I1 and jointly predicts the forward optical flow F0→1 and the backward optical flow F1→0 between them.
        
## <font color='red' >Arbitrary-time flow interpolation:</font>
* U-Net Architecture (in_channels = 20, out_channels = 5)
* inputs I1 , g(I1,Ft->1) , Ft->1 , Ft->0 , g(I0,Ft->0) , I0
* outputs I1 , Vt<-1 , ▲Ft->1 , ▲Ft->0 , Vt<-0 , I0

### I(t) is computed from the arbitrary-time flow interpolation outputs

# <font color='red' >Loss Function:</font>

## <center><font color='blue' > L = λr lr + λp lp + λw lw + λs ls </font></center>
* lr: Reconstruction loss, which models how good the reconstruction of the intermediate frames is
* lp: Perceptual loss to preserve details of the predictions and make interpolated frames sharper
* lw: Warping loss to model the quality of the computed optical flow
* ls: Smoothness loss to encourage neighboring pixels to have similar flow values
* λr = 0.8 , λp = 0.005 , λw = 0.4 , λs = 1 

    

In [2]:
# Criterion for the reconstruction term (lr): element-wise L1 distance,
# averaged over all elements.
recon_loss = nn.L1Loss(reduction='mean')
# Criterion for the perceptual term (lp): mean squared error.
# NOTE(review): presumably applied to feature maps rather than raw pixels;
# the call site is not visible here.
percep_loss = nn.MSELoss(reduction='mean')