In [1]:
import torch
import torchvision
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
import time
import math
import UNet
import data_loader

# Network Details:
![Architecture](img/Arch.png)
## <font color='red' >Flow Computation Network:</font>
* U-Net Architecture (in_channels = 6, out_Channels = 4)
* input I0 , I1
* output F0->1 , F1->0
* Takes two input images I0 and I1 and jointly predicts the forward optical flow F0→1 and the backward optical flow F1→0 between them.
        
## <font color='red' >Arbitrary-time flow interpolation:</font>
* U-Net Architecture (in_channels = 20, out_Channels = 5)
* inputs I0 , I1 , F0->1 , F1->0 , Ft->1 , Ft->0 , g(I1,Ft->1) , g(I0,Ft->0)  (3+3+2+2+2+2+3+3 = 20 channels)
* outputs ▲Ft->0 , ▲Ft->1 , Vt<-0  (2+2+1 = 5 channels); Vt<-1 is derived as 1 - Vt<-0

### I(t) is computed from the arbitrary-time flow interpolation outputs

# <font color='red' >Loss Function:</font>

## <center><font color='blue' > L = λr lr + λp lp + λw lw + λs ls </font></center>
* lr: Reconstruction loss, which measures how good the reconstruction of the intermediate frames is
* lp: Perceptual loss, which preserves details of the predictions and makes the interpolated frames sharper
* lw: Warping loss, which measures the quality of the computed optical flow
* ls: Smoothness loss, which encourages neighboring pixels to have similar flow values
* λr = 0.8 , λp = 0.005 , λw = 0.4 , λs = 1 

    

In [None]:
# Global configuration: dataset locations, preprocessing and data loaders.

# Dataset root paths (left empty here; fill in before running).
train_set_path = ""
test_set_path = ""
validation_set_path = ""

# Seven evenly spaced values in [0.125, 0.875]; the training loop indexes this
# array with `frame_index` to obtain the interpolation time t of a sample.
image_seq_weights = np.linspace(0.125, 0.875, num=7)

batch_size = 6

# Per-channel normalisation statistics. With std = 1 the Normalize step only
# subtracts the channel means.
mean = [0.429, 0.431, 0.397]
std  = [1, 1, 1]

# Preprocessing applied to every frame: resize to 352x352, convert to a tensor,
# then normalise with the statistics above.
data_transform = transforms.Compose([
    transforms.Resize(size=(352, 352)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Training split.
train_dataset = data_loader.dataset_loader(train_set_path, data_transform)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation split.
validation_dataset = data_loader.dataset_loader(validation_set_path, data_transform)
validation_dataloader = torch.utils.data.DataLoader(
    dataset=validation_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split.
test_dataset = data_loader.dataset_loader(test_set_path, data_transform)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Compute device shared by the whole notebook. Later cells pass `device` to
# apply_flow(), so it has to be defined before the training loop runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flow computation network: 6 input channels (I0 and I1 stacked), 4 output
# channels (F0->1 and F1->0, two channels each).
flow_model = UNet.UNet(6,4).to(device)

# Arbitrary-time flow interpolation network: 20 input channels, 5 output
# channels (two flow residuals of two channels each plus one visibility map).
arb_time_flow = UNet.UNet(20,5).to(device)

In [None]:
# Criteria used by the training loop: mean absolute error and mean squared error.
recon_loss = nn.L1Loss()
percep_loss = nn.MSELoss()

# Frozen feature extractor for the perceptual loss: the first 22 layers of the
# first child module of a pretrained VGG16 (up to conv_4_3).
vgg16_model = torchvision.models.vgg16(pretrained=True)
vgg16_conv_4_3 = nn.Sequential(*next(vgg16_model.children())[:22])
for frozen_weight in vgg16_conv_4_3.parameters():
    # The extractor is only used to compare features, never trained.
    frozen_weight.requires_grad_(False)

# A single Adam optimizer updates both networks jointly.
learning_parameters = [*flow_model.parameters(), *arb_time_flow.parameters()]
optimizer = optim.Adam(learning_parameters, lr=1e-3)

In [None]:
def get_intermediate_flow(F0_1,F1_0,t):
    """
        Approximate the optical flow from an intermediate time t to both
        input frames by linearly combining the bidirectional flows:

            Ft_0 = -(1-t)*t * F0_1 + t^2 * F1_0
            Ft_1 = (1-t)^2 * F0_1 - (1-t)*t * F1_0

        F0_1 : flow from frame 0 to frame 1 (tensor)
        F1_0 : flow from frame 1 to frame 0 (tensor, same shape as F0_1)
        t    : interpolation time. Either a Python/NumPy scalar, or a
               per-sample array/tensor with one value per batch element;
               batched values are reshaped to (B, 1, 1, 1) so they broadcast
               over 4-D flow tensors.

        returns (Ft_0, Ft_1)
    """
    if not isinstance(t, (int, float)):
        # math.pow() only accepts scalars, so batched times are handled with
        # tensor arithmetic on the same dtype/device as the flows.
        t = torch.as_tensor(t, dtype=F0_1.dtype, device=F0_1.device)
        if t.dim() > 0:
            t = t.reshape(-1, 1, 1, 1)
    one_minus_t = 1 - t
    Ft_0 = (-one_minus_t * t * F0_1) + (t * t * F1_0)
    Ft_1 = (one_minus_t * one_minus_t * F0_1) - (one_minus_t * t * F1_0)
    return Ft_0, Ft_1

def apply_flow(img,flow,device=None):
    """
        Warp an image according to the optical flow.

        img   : [B, C, H, W] ex: torch.Size([3, 3, 352, 352])
        flow  : [B, 2, H, W] ex: torch.Size([3, 2, 352, 352])
                channel 0 is the displacement along x (width), channel 1 the
                displacement along y (height). The flow is added directly to
                a sampling grid spanning [-1, 1], so it is interpreted in
                those normalized units, not in pixels.
        device: device on which the sampling grid is built; defaults to the
                device of `flow` when omitted.

        F.grid_sample() : for input with shape (B, C, Hin, Win)
        and grid with shape (B, Hout, Wout, 2),
        the output will have shape (B, C, Hout, Wout).

        returns the warped image, [B, C, H, W]. A zero flow returns `img`
        unchanged.
    """
    B,C,H,W = img.size()
    if device is None:
        device = flow.device

    # Identity sampling grid in [B, 2, H, W] layout: x runs along the width
    # axis and y along the height axis, which also works when H != W.
    x_coords = torch.linspace(-1, 1, W, dtype=flow.dtype, device=device)
    y_coords = torch.linspace(-1, 1, H, dtype=flow.dtype, device=device)
    gridx = x_coords.view(1, 1, 1, W).expand(B, 1, H, W)
    gridy = y_coords.view(1, 1, H, 1).expand(B, 1, H, W)
    grid = torch.cat((gridx, gridy), dim=1)

    # grid_sample wants the (x, y) pair in the last dimension. permute() moves
    # the channel axis there; view() would only reinterpret memory and mix the
    # x and y coordinates together.
    sampling_grid = (flow + grid).permute(0, 2, 3, 1)

    # linspace(-1, 1, N) puts -1/+1 on the centres of the corner pixels, which
    # is the align_corners=True convention.
    new_img = F.grid_sample(img, sampling_grid, mode='bilinear', align_corners=True)
    return new_img

def get_intermediate_image(t,Vt_0,Vt_1,g_I_0,g_I_1):
    """
        Blend the two warped input frames into the intermediate frame:

            It = ((1-t)*Vt_0*g_I_0 + t*Vt_1*g_I_1) / Z
            Z  = (1-t)*Vt_0 + t*Vt_1

        t     : interpolation time. Either a Python/NumPy scalar, or a
                per-sample array/tensor with one value per batch element;
                batched values are reshaped to (B, 1, 1, 1) so they broadcast
                over 4-D image tensors.
        Vt_0  : visibility map weighting the frame warped from I0
        Vt_1  : visibility map weighting the frame warped from I1
        g_I_0 : I0 warped to time t
        g_I_1 : I1 warped to time t

        returns the blended frame It. Z is not guarded against zero, so the
        visibility maps should be positive.
    """
    if not isinstance(t, (int, float)):
        t = torch.as_tensor(t, dtype=g_I_0.dtype, device=g_I_0.device)
        if t.dim() > 0:
            t = t.reshape(-1, 1, 1, 1)
    weight_0 = (1 - t) * Vt_0
    weight_1 = t * Vt_1
    # Normalisation factor so the two blending weights sum to one.
    Z = weight_0 + weight_1
    It = (weight_0 * g_I_0 + weight_1 * g_I_1) / Z
    return It

In [None]:
epochs = 5
valid_loss_min = np.inf

# Weights of the total loss L = λr·lr + λp·lp + λw·lw + λs·ls.
lambda_r, lambda_p, lambda_w, lambda_s = 0.8, 0.005, 0.4, 1

# Run on whichever device the flow network already lives on, and make sure the
# second network and the frozen VGG feature extractor sit on that same device.
device = next(flow_model.parameters()).device
arb_time_flow.to(device)
vgg16_conv_4_3.to(device)
vgg16_conv_4_3.eval()


def _flow_smoothness(flow):
    """Mean absolute difference between horizontally and vertically adjacent flow values."""
    horizontal = torch.mean(torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:]))
    vertical = torch.mean(torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :]))
    return horizontal + vertical


def _interpolation_loss(frames, frame_index):
    """
        Run both networks on one batch and return the total weighted loss.

        frames      : (I0, It, I1) batch as yielded by the data loaders
        frame_index : index (or batch of indices) into image_seq_weights
                      selecting the interpolation time of each sample
                      (assumes integer indices; verify against data_loader)
    """
    I0, It, I1 = (frame.to(device) for frame in frames)

    # Per-sample interpolation time, shaped (B, 1, 1, 1) so it broadcasts over
    # [B, C, H, W] tensors.
    time_weights = torch.as_tensor(image_seq_weights, dtype=I0.dtype, device=device)
    t = time_weights[torch.as_tensor(frame_index, device=device).long()].view(-1, 1, 1, 1)

    # Bidirectional flow between the two input frames.
    flow_out = flow_model(torch.cat((I0, I1), dim=1))
    f0_1 = flow_out[:, :2, :, :]
    f1_0 = flow_out[:, 2:, :, :]

    # Approximate flows from time t, and the inputs warped with them.
    ft_0_aprox, ft_1_aprox = get_intermediate_flow(f0_1, f1_0, t)
    g_I_0 = apply_flow(I0, ft_0_aprox, device)
    g_I_1 = apply_flow(I1, ft_1_aprox, device)

    arbitary_time_flow = arb_time_flow(
        torch.cat((I0, I1, f0_1, f1_0, ft_1_aprox, ft_0_aprox, g_I_1, g_I_0), dim=1))

    # The second network predicts residuals on top of the approximate flows
    # plus one visibility map, squashed to (0, 1) so that vt_1 = 1 - vt_0 is
    # also a valid weight and the blend normaliser stays positive.
    ft_0 = arbitary_time_flow[:, :2, :, :] + ft_0_aprox
    ft_1 = arbitary_time_flow[:, 2:4, :, :] + ft_1_aprox
    vt_0 = torch.sigmoid(arbitary_time_flow[:, 4:5, :, :])
    vt_1 = 1 - vt_0

    # The predicted frame is blended from the inputs warped with the refined
    # flows; otherwise the residual outputs would never receive a gradient.
    It_predicted = get_intermediate_image(
        t, vt_0, vt_1, apply_flow(I0, ft_0, device), apply_flow(I1, ft_1, device))

    Lr = recon_loss(It, It_predicted)
    Lp = percep_loss(vgg16_conv_4_3(It_predicted), vgg16_conv_4_3(It))

    g_I1_f01 = apply_flow(I1, f0_1, device)
    g_I0_f10 = apply_flow(I0, f1_0, device)
    Lw = (recon_loss(I0, g_I1_f01) + recon_loss(I1, g_I0_f10)
          + percep_loss(It, g_I_0) + percep_loss(It, g_I_1))

    Ls = _flow_smoothness(f1_0) + _flow_smoothness(f0_1)

    return lambda_r*Lr + lambda_p*Lp + lambda_w*Lw + lambda_s*Ls


#training loop
for i in range(epochs):
    train_loss = 0
    valid_loss = 0
    start_time = time.time()

    flow_model.train()
    arb_time_flow.train()
    for train_data, frame_index in train_dataloader:
        optimizer.zero_grad()
        L = _interpolation_loss(train_data, frame_index)
        # Backpropagate
        L.backward()
        optimizer.step()
        train_loss += L.item()

    flow_model.eval()
    arb_time_flow.eval()
    # No gradients are needed for validation, including the flow network's
    # forward pass.
    with torch.no_grad():
        for valid_data, frame_index in validation_dataloader:
            valid_loss += _interpolation_loss(valid_data, frame_index).item()

    elabsed_time = time.time() - start_time
    # Losses are sums over batches. No accuracy is reported: this is a
    # regression task and nothing counts "correct" predictions.
    print('Epoch: {} \tElabsed Time : {} Seconds\nTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
                i+1, elabsed_time, train_loss, valid_loss))

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        valid_loss_min = valid_loss
        torch.save(flow_model.state_dict(), 'flow_model.pt')
        torch.save(arb_time_flow.state_dict(), 'arb_time_flow.pt')