# **IMPORTS**

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from tqdm import tqdm
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
import random
from torchvision import transforms
import matplotlib.pyplot as plt
import time
import gzip
import cv2
import math
import os
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
import argparse
!pip install lpips
import codecs
import lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
     |████████████████████████████████| 53 kB 437 kB/s            
Installing collected packages: lpips
Successfully installed lpips-0.1.4


# **UTILITY**

In [2]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Parameters whose ``requires_grad`` flag is False (frozen) are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [3]:
def schedule_sampling(eta, itr, channel, batch_size):
    """Decay the scheduled-sampling probability and draw a ground-truth mask.

    Configuration is read from the notebook-global ``args`` (and
    ``configs.verbose`` for logging).

    Args:
        eta: current probability of feeding a ground-truth frame.
        itr: current training iteration.
        channel: number of channels of each mask frame.
        batch_size: number of sequences in the batch.

    Returns:
        ``(eta, mask)``. ``mask`` has shape
        ``(batch_size, total_length - input_length - 1, img_height, img_width, channel)``
        (float64). Each frame is entirely 1.0 with probability ``eta`` and
        entirely 0.0 otherwise. When ``args.scheduled_sampling`` is falsy the
        result is ``(0.0, all-zeros mask)`` and ``eta`` is not decayed.
    """
    n_pred = args.total_length - args.input_length - 1
    frame_shape = (args.img_height, args.img_width, channel)
    if not args.scheduled_sampling:
        return 0.0, np.zeros((batch_size, n_pred) + frame_shape)

    # Linear decay until sampling_stop_iter, then ground truth is never used.
    if itr < args.sampling_stop_iter:
        eta -= args.sampling_changing_rate
    else:
        eta = 0.0

    if configs.verbose or itr % 100 == 0:
        print('ETA: ', eta)

    # One Bernoulli(eta) draw per (sequence, predicted frame).
    random_flip = np.random.random_sample((batch_size, n_pred))
    true_token = random_flip < eta
    # Broadcast every per-frame decision across the whole frame in one step
    # instead of appending a full-size ones/zeros array per frame in Python.
    mask = np.broadcast_to(
        true_token[:, :, np.newaxis, np.newaxis, np.newaxis],
        (batch_size, n_pred) + frame_shape,
    ).astype(np.float64)
    return eta, mask

In [4]:
class Resize2(object):
    """Resize every frame of a 4-D ``(frames, H, W, C)`` clip to a fixed size.

    The output size defaults to height 640 and width 480, the previously
    hard-coded values. Pass ``height``/``width`` to target another resolution.
    Channels are preserved. The result is a float64 array built from
    skimage's ``resize``.

    NOTE(review): ``Resize2`` is defined a second time in the data-loader
    cell; whichever cell runs last wins.
    """

    def __init__(self, height=640, width=480):
        self.height = height
        self.width = width

    def __call__(self, sample):
        imgs_out = np.zeros((sample.shape[0], self.height, self.width, sample.shape[3]))
        # Resize frame by frame so the frame axis is never interpolated.
        for i in range(sample.shape[0]):
            imgs_out[i, :, :, :] = resize(sample[i, :, :, :], imgs_out.shape[1:])
        return imgs_out

In [5]:
def _plot_prediction_grid(target, predictions, out_len):
    """Plot the first ``out_len`` frames of the first sequence: ground truth on
    the top row, generated frames on the bottom row."""
    fig, ax = plt.subplots(2, out_len, figsize=(25, 7), squeeze=False)
    rows = ((target, 'V Ground Truth'), (predictions, 'V Generated'))
    for row, (frames, title) in enumerate(rows):
        for col in range(out_len):
            ax[row][col].imshow(frames[0][col])
            ax[row][col].set_title(title)
            ax[row][col].axis('off')
    plt.show()
    plt.close(fig)


def evaluation_proper(model, test_loader, configs, out_len=10):
    """Run ``model.test`` over ``test_loader`` and print image-quality metrics.

    For each batch, the first ``out_len`` predicted frames are compared with
    the ground truth. Metrics are MSE, MAE, SSIM (first sequence of the batch
    only), PSNR (derived from the per-frame MSE, assuming a peak value of 1)
    and LPIPS (AlexNet, x100). Predictions are clamped in place: values
    < 0.10 become 0 and values > 0.99 become 1. Every 50th batch a sample
    grid is plotted.

    Args:
        model: object exposing ``test(batch, real_input_flag)`` that returns a
            numpy array indexed ``(B, T, C, H, W)`` (it is transposed to
            channels-last here).
        test_loader: iterable of clip tensors ``(B, T, C, H, W)`` with ``len()``.
        configs: needs ``device``, ``total_length``, ``input_length``,
            ``img_height``, ``img_width``, ``patch_size`` and ``img_channel``.
        out_len: number of predicted frames to score.

    Returns:
        The mean MSE over all batches and scored frames.
    """
    print('Evaluating...')

    loss_fn = lpips.LPIPS(net='alex', spatial=True).to(configs.device)
    n_batches = len(test_loader)
    # NaN-initialised so any row that is not written shows up instead of
    # silently contributing uninitialised memory to the averages.
    mse_list = np.full((n_batches, out_len), np.nan)
    mae_list = np.full((n_batches, out_len), np.nan)
    ssim_list = np.full((n_batches, out_len), np.nan)
    psnr_list = np.full((n_batches, out_len), np.nan)
    lpips_list = np.full((n_batches, out_len), np.nan)

    total_mse = 0
    total_mae = 0
    output_length = min(configs.total_length - configs.input_length, configs.total_length - 1)

    # NOTE(review): model.eval() is not called, so any dropout/batch-norm
    # layers the model may contain stay in training mode during evaluation.
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_loader, 0), total=n_batches):
            batch_size = batch.shape[0]
            # All-zero flag passed to model.test; presumably "never substitute
            # ground truth for predicted frames" -- TODO confirm against model.test.
            real_input_flag = np.zeros(
                (batch_size,
                 configs.total_length - configs.input_length - 1,
                 configs.img_height // configs.patch_size,
                 configs.img_width // configs.patch_size,
                 configs.patch_size ** 2 * configs.img_channel))

            img_gen = model.test(batch, real_input_flag)
            img_gen = img_gen.transpose(0, 1, 3, 4, 2)
            target = batch[:, configs.input_length:, :].detach().cpu().numpy().transpose(0, 1, 3, 4, 2)
            predictions = img_gen[:, -output_length:, :]

            predictions[predictions < 0.10] = 0
            predictions[predictions > 0.99] = 1

            if (batch_idx + 1) % 50 == 0:
                print(target[0, 1, 40:42, 40:42, 0])
                print(predictions[0, 1, 40:42, 40:42, 0])
                # Plotting lives in a helper so its loop variables cannot
                # clobber batch_idx (the metric row index below).
                _plot_prediction_grid(target, predictions, out_len)

            # Per-pixel mean over batch, time and channel, summed over pixels.
            total_mse += np.mean((predictions - target) ** 2, axis=(0, 1, 4)).sum()
            total_mae += np.mean(np.abs(predictions - target), axis=(0, 1, 4)).sum()

            for j in range(out_len):
                pred_j = predictions[:, j, :, :, :]
                target_j = target[:, j, :, :, :]
                mse_list[batch_idx][j] = np.square(pred_j - target_j).mean()
                mae_list[batch_idx][j] = np.abs(pred_j - target_j).mean()
                # NOTE(review): newer scikit-image releases replace
                # ``multichannel`` with ``channel_axis``.
                ssim_list[batch_idx][j] = ssim(target[0, j, :, :, :], predictions[0, j, :, :, :], multichannel=True)
                psnr_list[batch_idx][j] = 20 * np.log10(1 / np.sqrt(mse_list[batch_idx][j]))
                # LPIPS expects NCHW inputs scaled to [-1, 1].
                t1 = torch.from_numpy((pred_j - 0.5) / 0.5).to(configs.device).permute((0, 3, 1, 2))
                t2 = torch.from_numpy((target_j - 0.5) / 0.5).to(configs.device).permute((0, 3, 1, 2))
                d = loss_fn.forward(t1, t2)
                lpips_list[batch_idx][j] = d.mean().detach().cpu().numpy() * 100

    avg_mse_frame = mse_list.mean(axis=0)
    avg_mae_frame = mae_list.mean(axis=0)
    avg_ssim_frame = ssim_list.mean(axis=0)
    avg_psnr_frame = psnr_list.mean(axis=0)
    avg_lpips_frame = lpips_list.mean(axis=0)

    avg_mse = mse_list.mean()
    avg_mae = mae_list.mean()
    avg_ssim = ssim_list.mean()
    avg_psnr = psnr_list.mean()
    avg_lpips = lpips_list.mean()

    print('Eval MSE: ', total_mse/len(test_loader))
    print('Eval MAE: ', total_mae/len(test_loader))

    print(f'Avg-MSE: {avg_mse}\nMSE/Frame: {avg_mse_frame}')
    print(f'Avg-MAE: {avg_mae}\nMAE/Frame: {avg_mae_frame}')
    print(f'Avg-SSIM: {avg_ssim}\nSSIM/Frame: {avg_ssim_frame}')
    print(f'Avg-PSNR: {avg_psnr}\nPSNR/Frame: {avg_psnr_frame}')
    print(f'Avg-LPIPS: {avg_lpips}\nLPIPS/Frame: {avg_lpips_frame}')

    return avg_mse

# **DATA LOADER**

In [6]:
class Norm(object):
    """Divide a clip by a constant; the default 255 maps 8-bit pixels to [0, 1]."""

    def __init__(self, max=255):
        # Parameter keeps the name ``max`` for caller compatibility.
        self.max = max

    def __call__(self, sample):
        return sample / self.max


class ToTensor(object):
    """Turn a channels-last ``(frames, H, W, C)`` numpy clip into a float32
    channels-first ``(frames, C, H, W)`` torch tensor (on a copied buffer)."""

    def __call__(self, sample):
        channels_first = np.array(np.transpose(sample, (0, 3, 1, 2)))
        return torch.from_numpy(channels_first).float()
    

class Resize(object):
    """Resize every frame of a 4-D ``(frames, H, W, C)`` clip to a target size.

    By default the target is ``(configs.img_height, configs.img_width)``,
    read from the notebook-global ``configs`` on each call. Pass ``height``
    and/or ``width`` to override either dimension without depending on
    ``configs``. Channels are preserved and the output is float64.
    """

    def __init__(self, height=None, width=None):
        self.height = height
        self.width = width

    def __call__(self, sample):
        height = configs.img_height if self.height is None else self.height
        width = configs.img_width if self.width is None else self.width
        imgs_out = np.zeros((sample.shape[0], height, width, sample.shape[3]))
        # Resize frame by frame so the frame axis is never interpolated.
        for i in range(sample.shape[0]):
            imgs_out[i, :, :, :] = resize(sample[i, :, :, :], imgs_out.shape[1:])
        return imgs_out
    
class Resize2(object):
    """Resize every frame of a 4-D ``(frames, H, W, C)`` clip to a fixed size.

    The output size defaults to height 640 and width 480, the previously
    hard-coded values. Pass ``height``/``width`` to target another resolution.
    Channels are preserved. The result is a float64 array built from
    skimage's ``resize``.

    NOTE(review): an identical ``Resize2`` is also defined in the utility
    cell; whichever cell runs last wins.
    """

    def __init__(self, height=640, width=480):
        self.height = height
        self.width = width

    def __call__(self, sample):
        imgs_out = np.zeros((sample.shape[0], self.height, self.width, sample.shape[3]))
        # Resize frame by frame so the frame axis is never interpolated.
        for i in range(sample.shape[0]):
            imgs_out[i, :, :, :] = resize(sample[i, :, :, :], imgs_out.shape[1:])
        return imgs_out

In [7]:
class TimeSeriesDatasetNpz(data.Dataset):
    """Dataset of clips stored in one ``.npy`` file.

    The file holds a 4-D array (presumably laid out ``(T, N, H, W)`` --
    TODO confirm). It is loaded eagerly and rearranged to ``(N, T, 1, H, W)``.
    Each item is one clip as a float32 tensor divided by 255.
    ``n_frames_input``/``n_frames_output`` are stored but not used here.
    """

    def __init__(self, root_dir, n_frames_input=10, n_frames_output=10):
        self.n_frames_in = n_frames_input
        self.n_frames_out = n_frames_output
        raw = np.load(root_dir)
        # Swap the first two axes, then insert a singleton channel axis at dim 2.
        self.file = raw.transpose(1, 0, 2, 3)[:, :, np.newaxis]

    def __len__(self):
        return len(self.file)

    def __getitem__(self, index):
        clip = torch.from_numpy(self.file[index]).type(torch.float32)
        return clip / 255

In [8]:
class TimeSeriesDataset(data.Dataset):
    """Sliding-window dataset over per-frame image files in a nested directory tree.

    The directory walk expects ``root_dir/<subject>/<sequence>/<view_type>/<frame files>``.
    Frame files inside each view directory are sorted by name. Every run of
    ``n_frames_input + n_frames_output`` consecutive frames becomes one
    sample. Frames are loaded as uint8, resized to ``frame_size x frame_size``,
    scaled by 1/255 and returned as a float32 tensor ``(frames, 1, size, size)``.

    Args:
        root_dir: root of the directory tree.
        n_frames_input: number of conditioning frames per sample.
        n_frames_output: number of frames to predict per sample.
        view_type: view sub-directory name to keep (default ``'090'``).
        max_sequences: use only the first this-many (sorted) view directories;
            ``None`` uses all. Defaults to 50, the previously hard-coded limit.
        frame_size: side length frames are resized to (default 64).
    """

    def __init__(self, root_dir, n_frames_input=10, n_frames_output=10,
                 view_type='090', max_sequences=50, frame_size=64):
        # Seeds Python's global RNG (kept for reproducibility with earlier runs).
        random.seed(1000)
        self.n_frames_in = n_frames_input
        self.n_frames_out = n_frames_output
        self.frame_size = frame_size
        n_frames = n_frames_input + n_frames_output

        view_type_dirs = self._find_view_dirs(root_dir, view_type)
        print(len(view_type_dirs))

        self.specific_view_files = [
            sorted(os.path.join(d, f) for f in os.listdir(d)) for d in view_type_dirs
        ]

        self.nframes_list = []
        for files in self.specific_view_files[:max_sequences]:
            # ``+ 1`` so the window ending on the last frame is included;
            # previously the final window of every sequence was dropped.
            for start in range(len(files) - n_frames + 1):
                self.nframes_list.append(files[start:start + n_frames])

        self.trans_norm = transforms.Compose([Norm()])
        self.trans_tensor = transforms.Compose([ToTensor()])
        self.trans_resize = transforms.Compose([Resize()])

        print(len(self.nframes_list))

    @staticmethod
    def _find_view_dirs(root_dir, view_type):
        """Return the sorted ``root/<subject>/<sequence>/<view_type>`` paths that exist."""
        subject_dirs = [os.path.join(root_dir, s) for s in os.listdir(root_dir)]
        seq_type_dirs = [os.path.join(d, s) for d in subject_dirs for s in os.listdir(d)]
        return sorted(os.path.join(d, v) for d in seq_type_dirs for v in os.listdir(d) if v == view_type)

    def __len__(self):
        return len(self.nframes_list)

    def __getitem__(self, index):
        window = self.nframes_list[index]
        size = self.frame_size
        data_seq = np.empty((len(window), size, size), dtype=np.uint8)

        for i, path in enumerate(window):
            # ``with`` guarantees PIL releases the file handle.
            # Assumes single-channel images; a multi-channel frame would not
            # fit the (size, size) slot.
            with Image.open(path) as img:
                data_seq[i, :] = cv2.resize(np.array(img), (size, size))

        data_seq = data_seq[..., np.newaxis]
        clip = self.trans_norm(data_seq)
        clip = self.trans_tensor(clip)
        return clip

## TEST CASE

In [9]:
# td = TimeSeriesDataset(root_dir='../../input/ethzzz/dataset_ETHZ/seq1', n_frames_input=10, n_frames_output=10)
# train_loader = torch.utils.data.DataLoader(dataset=td, batch_size=3, shuffle=True, num_workers=1)

In [10]:
# z = next(iter(train_loader))
# print(z.shape)
# print(torch.max(z), torch.min(z))

In [11]:
# plt.figure()
# plt.imshow(z[0][0].permute(1,2,0)) 
# plt.show()

# **DEPTH**

In [12]:
# !pip install gdown

In [13]:
# import gdown

# !git clone 'https://github.com/shariqfarooq123/AdaBins'

In [14]:
# %cd AdaBins

In [15]:
# !gdown https://drive.google.com/uc?id=1lvyZZbC9NLcS8a__YPcUP7rDiIpbRpoF
# !mkdir pretrained
# !mv AdaBins_nyu.pt pretrained/AdaBins_nyu.pt

In [16]:
# from infer import InferenceHelper

# infer_helper = InferenceHelper(dataset='nyu')

In [17]:
# [os.path.join('../../input/ethzzz/dataset_ETHZ/seq1', d) for d in os.listdir('../../input/ethzzz/dataset_ETHZ/seq1')][:5]

In [18]:
# !mkdir output_depth

In [19]:
#infer_helper.predict_dir("../../input/ethzzz/dataset_ETHZ/seq1/p001/", "../output_depth")

In [20]:
# example_rgb_batch = next(iter(train_loader)).to(configs.device)
# print(example_rgb_batch[:,0,:].shape)
# bin_centers, predicted_depth = infer_helper.predict(example_rgb_batch[:,0,:])

In [21]:
# plt.imshow(example_rgb_batch[0,0,:].cpu().permute(1, 2, 0))

In [22]:
# plt.imshow(predicted_depth[0][0], cmap='plasma')
# plt.show()

# **MODELS**

## MAU-CELL

In [23]:
# Module-level call counter shared by every MAUCell.forward call: each call
# increments it by one, and when it is a multiple of 9000 the cell prints its
# learned alpha_s / alpha_t mixing weights.
ix = 0

In [24]:
class MAUCell(nn.Module):
    """MAU recurrent cell with a temporal stream ``T`` and a spatial stream ``S``.

    Each step does the following:
      * Fuses the temporal input with a softmax-weighted sum of the last
        ``tau`` temporal states. The weights are scaled dot products between
        the current spatial features and the last ``tau`` spatial states.
      * Gates the spatial input by ``sigmoid(s_pixel_att)``.
      * Produces each stream's new value as a learned mix (``alpha_t`` /
        ``alpha_s``) of two update rules.

    Args:
        in_channel: channels of ``T_t`` (input of ``conv_t`` / ``conv_t_next``).
        num_hidden: channels of ``S_t`` and of both outputs.
        height, width: spatial size, required by the LayerNorms.
        filter_size: ``(kh, kw)`` kernel size; padding is ``k // 2`` per axis.
        stride: convolution stride.
        tau: number of past states attended over.
        cell_mode: ``'residual'`` or ``'normal'``. It is validated here but
            not used by ``forward``.

    Raises:
        AssertionError: if ``cell_mode`` is not one of the supported modes.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, tau, cell_mode):
        super(MAUCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = (filter_size[0] // 2, filter_size[1] // 2)
        self.cell_mode = cell_mode
        # Normaliser for the attention dot products (feature count per state).
        self.d = num_hidden * height * width
        self.tau = tau
        self.states = ['residual', 'normal']
        if self.cell_mode not in self.states:
            raise AssertionError(f'cell_mode must be one of {self.states}, got {cell_mode!r}')
        self.conv_t = nn.Sequential(
            nn.Conv2d(in_channel, 4 * num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding),
            nn.LayerNorm([4 * num_hidden, height, width])
        )
        self.conv_t_next = nn.Sequential(
            nn.Conv2d(in_channel, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding),
            nn.LayerNorm([num_hidden, height, width])
        )

        self.conv_s = nn.Sequential(
            nn.Conv2d(num_hidden, 4 * num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding),
            nn.LayerNorm([4 * num_hidden, height, width])
        )

        self.conv_s_next = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding),
            nn.LayerNorm([num_hidden, height, width])
        )

        # Not used by forward(); kept so parameter names / state_dict keys are unchanged.
        self.conv_t_i = nn.Sequential(
            nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding),
            nn.LayerNorm([num_hidden, height, width])
        )

        # Learned scalars mixing the two update rules of each stream.
        self.alpha_s = nn.Parameter(torch.randn(1))
        self.alpha_t = nn.Parameter(torch.randn(1))

        self.softmax = nn.Softmax(dim=0)

    def forward(self, T_t, S_t, t_att, s_att, s_pixel_att):
        """Run one recurrent step.

        Args:
            T_t: temporal state ``(B, in_channel, H, W)``.
            S_t: spatial features ``(B, num_hidden, H, W)``.
            t_att: last ``tau`` temporal states stacked on dim 0 (``(tau, B, ...)``).
            s_att: last ``tau`` spatial states stacked on dim 0, each shaped like ``S_t``.
            s_pixel_att: tensor broadcastable to ``S_t``; ``S_t`` is gated by its sigmoid.

        Returns:
            ``(out_T, out_S)``: new temporal and spatial states.
        """
        global ix
        s_next = self.conv_s_next(S_t)
        t_next = self.conv_t_next(T_t)

        # Scaled dot-product score of each past spatial state against the
        # current features -> (tau, B), reshaped to (tau, B, 1, 1, 1) and
        # softmax-normalised over tau.
        weights_list = []
        for i in range(self.tau):
            weights_list.append((s_att[i] * s_next).sum(dim=(1, 2, 3)) / math.sqrt(self.d))
        weights_list = torch.stack(weights_list, dim=0)
        weights_list = torch.reshape(weights_list, (*weights_list.shape, 1, 1, 1))
        weights_list = self.softmax(weights_list)

        T_trend = (t_att * weights_list).sum(dim=0)
        t_att_gate = torch.sigmoid(t_next)

        T_fusion = T_t * t_att_gate + (1 - t_att_gate) * T_trend
        S_fusion = S_t * torch.sigmoid(s_pixel_att)

        T_concat = self.conv_t(T_fusion)
        S_concat = self.conv_s(S_fusion)

        # t_t and s_s are produced by the convolutions but not used below.
        t_i, t_r, t_t, t_s = torch.split(T_concat, self.num_hidden, dim=1)
        s_i, s_r, s_t, s_s = torch.split(S_concat, self.num_hidden, dim=1)

        T_i = torch.tanh(t_i)
        S_i = torch.tanh(s_i)

        T_r = torch.sigmoid(t_r)
        # Spatial gate from the spatial logits (previously reused t_r by
        # mistake, leaving s_r unused).
        S_r = torch.sigmoid(s_r)

        T_s = torch.sigmoid(t_s)
        S_t = torch.sigmoid(s_t)

        # Rule 1: gated candidate plus cross-stream-gated carry of the fused input.
        T_new_1 = T_r * T_i + S_t * T_fusion
        S_new_1 = S_r * S_i + T_s * S_fusion

        # Rule 2: gate between the stream's raw candidate and the other stream's logits.
        T_new_2 = T_r * t_i + (1 - T_r) * s_t
        S_new_2 = S_r * s_i + (1 - S_r) * t_s

        if ix % 9000 == 0:
            print(self.alpha_s.item(), self.alpha_t.item())

        out_S = self.alpha_s * S_new_1 + (1 - self.alpha_s) * S_new_2
        out_T = self.alpha_t * T_new_1 + (1 - self.alpha_t) * T_new_2

        ix += 1
        return out_T, out_S


## MAU

In [25]:
class RNN(nn.Module):
    """Video predictor: convolutional encoder -> stacked MAUCells -> transposed-conv decoder.

    The encoder maps each frame to ``num_hidden[0]`` channels with a 1x1
    conv, then downsamples by 2 ``log2(sr_size)`` times. The MAUCell stack
    runs at ``img_height // sr_size`` x ``img_width // sr_size``. The decoder
    upsamples back with ``log2(sr_size)`` stride-2 transposed convs, and
    ``srcnn`` (1x1 conv) maps to ``img_channel``. In ``model_mode == 'recall'``
    encoder features are added to the decoder outputs (U-Net-style skips).
    """
    def __init__(self, num_layers, num_hidden, configs):
        super(RNN, self).__init__()
        # print(configs.srcnn_tf)
        self.configs = configs
        self.frame_channel = configs.img_channel
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.tau = configs.tau
        self.cell_mode = configs.cell_mode
        self.states = ['recall', 'normal']
        if not self.configs.model_mode in self.states:
            raise AssertionError
        # self.time = 2
        cell_list = []

        # Spatial size at which the recurrent cells operate.
        width = configs.img_width // configs.sr_size
        height = configs.img_height // configs.sr_size
        # print(width)

        for i in range(num_layers):
            # NOTE(review): for i == 0 this reads num_hidden[-1], while forward()
            # sizes layer 0's history with num_hidden[0]; consistent only when all
            # hidden sizes are equal.
            in_channel = num_hidden[i - 1]
            cell_list.append(
                MAUCell(in_channel, num_hidden[i], height, width, configs.filter_size,
                        configs.stride, self.tau, self.cell_mode)
            )
        self.cell_list = nn.ModuleList(cell_list)

        # Encoder
        # n stride-2 stages: total downsampling factor 2**n == sr_size when sr_size is a power of 2.
        n = int(math.log2(configs.sr_size))
        encoders = []
        encoder = nn.Sequential()
        encoder.add_module(name='encoder_t_conv{0}'.format(-1),
                           module=nn.Conv2d(in_channels=self.frame_channel,
                                            out_channels=self.num_hidden[0],
                                            stride=1,
                                            padding=0,
                                            kernel_size=1))
        encoder.add_module(name='relu_t_{0}'.format(-1),
                           module=nn.LeakyReLU(0.2))
        encoders.append(encoder)
        for i in range(n):
            encoder = nn.Sequential()
            encoder.add_module(name='encoder_t{0}'.format(i),
                               module=nn.Conv2d(in_channels=self.num_hidden[0],
                                                out_channels=self.num_hidden[0],
                                                stride=(2, 2),
                                                padding=(1, 1),
                                                kernel_size=(3, 3)
                                                ))
            # self.encoder_t.add_module(name='gn_t{0}'.format(i),
            #                           module=nn.GroupNorm(4, self.frame_channel))
            encoder.add_module(name='encoder_t_relu{0}'.format(i),
                               module=nn.LeakyReLU(0.2))
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        # Decoder
        # n - 1 upsampling stages with LeakyReLU, plus a final one without activation.
        decoders = []

        for i in range(n - 1):
            decoder = nn.Sequential()
            decoder.add_module(name='c_decoder{0}'.format(i),
                               module=nn.ConvTranspose2d(in_channels=self.num_hidden[-1],
                                                         out_channels=self.num_hidden[-1],
                                                         stride=(2, 2),
                                                         padding=(1, 1),
                                                         kernel_size=(3, 3),
                                                         output_padding=(1, 1)
                                                         ))
            # self.decoder_s.add_module(name='gn_decoder_s{0}'.format(i),
            #                           module=nn.GroupNorm(4, self.frame_channel))
            decoder.add_module(name='c_decoder_relu{0}'.format(i),
                               module=nn.LeakyReLU(0.2))
            decoders.append(decoder)

        if n > 0:
            decoder = nn.Sequential()
            decoder.add_module(name='c_decoder{0}'.format(n - 1),
                               module=nn.ConvTranspose2d(in_channels=self.num_hidden[-1],
                                                         out_channels=self.num_hidden[-1],
                                                         stride=(2, 2),
                                                         padding=(1, 1),
                                                         kernel_size=(3, 3),
                                                         output_padding=(1, 1)
                                                         ))
            decoders.append(decoder)
        self.decoders = nn.ModuleList(decoders)

        # 1x1 projection from hidden features back to image channels.
        self.srcnn = nn.Sequential(
            nn.Conv2d(self.num_hidden[-1], self.frame_channel, kernel_size=1, stride=1, padding=0)
        )
        # NOTE(review): merge and conv_last_sr are not used in forward() (their
        # parameters are still registered and counted).
        self.merge = nn.Conv2d(self.num_hidden[-1] * 2, self.num_hidden[-1], kernel_size=1, stride=1, padding=0)
        self.conv_last_sr = nn.Conv2d(self.frame_channel * 2, self.frame_channel, kernel_size=1, stride=1, padding=0)


    def forward(self, frames, mask_true, verbose=False):
        """Predict ``total_length - 1`` next frames.

        Args:
            frames: clip tensor indexed ``frames[:, t]`` with channels on dim 2
                and height/width on dims 3/4, i.e. ``(B, T, C, H, W)``.
            mask_true: channels-last mask (permuted to channels-first below).
                For steps ``t >= input_length``, it blends the ground-truth frame
                (mask 1) with the previous prediction (mask 0).
            verbose: print intermediate shapes.

        Returns:
            Tensor ``(B, total_length - 1, C, H, W)`` of generated frames.
        """
        # print('ok')
        mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
        if(verbose):
            print('MT Permuted to: ',  mask_true.shape)
        batch_size = frames.shape[0]
        height = frames.shape[3] // self.configs.sr_size
        width = frames.shape[4] // self.configs.sr_size
        frame_channels = frames.shape[2]
        next_frames = []
        T_t = []
        T_pre = []
        S_pre = []
        # H_t = []
        x_gen = None
        if(verbose):
            print('Num Layers: ', self.num_layers)
            print('Num Hidden: ', self.num_hidden)
            print('TAU: ', self.tau)
        # Seed each layer's temporal/spatial history with tau zero tensors so the
        # first steps have something to attend over.
        for layer_idx in range(self.num_layers):
            tmp_t = []
            tmp_s = []
            if layer_idx == 0:
                in_channel = self.num_hidden[layer_idx]
            else:
                in_channel = self.num_hidden[layer_idx - 1]
            for i in range(self.tau):
                if(verbose):
                    if i==2:
                        print('tmp_t[1]', tmp_t[1].shape)
                tmp_t.append(torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device))
                tmp_s.append(torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device))
            T_pre.append(tmp_t)
            S_pre.append(tmp_s)

        if(verbose):
            print('len T_pre', len(T_pre))
            print('len T_pre[1]', len(T_pre[1]))

        # Previous step's encoded features; uses in_channel left over from the loop above.
        S_t_previ = torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device)
            
        for t in range(self.configs.total_length - 1):
            # Context frames are always ground truth; afterwards the mask picks
            # between ground truth and the previous prediction.
            if t < self.configs.input_length:
                net = frames[:, t]
                if(verbose):
                    print('Net frames[:, t]', frames[:, t].shape)
            else:
                time_diff = t - self.configs.input_length
                net = mask_true[:, time_diff] * frames[:, t] + (1 - mask_true[:, time_diff]) * x_gen
                if(verbose):
                    print('Net', net.shape)

            frames_feature = net
            # Keep every encoder stage's output for the 'recall' skip connections.
            frames_feature_encoded = []
            if(verbose):
                print('Len Encoders', len(self.encoders))
            for i in range(len(self.encoders)):
                frames_feature = self.encoders[i](frames_feature)
                if(verbose):
                    print('Frames_feature', i, frames_feature.shape)
                frames_feature_encoded.append(frames_feature)
            if t == 0:
                for i in range(self.num_layers):
                    zeros = torch.zeros([batch_size, self.num_hidden[i], height, width]).to(self.configs.device)
                    T_t.append(zeros)
                    # print('ok')
            S_t = frames_feature
            if(verbose):
                print('S_t in', S_t.shape)
            
            # NOTE(review): this is a single scalar summed over the whole batch
            # (it mixes samples). Being >= 0, sigmoid() of it inside MAUCell is
            # >= 0.5 and approaches 1 for large frame differences.
            S_pixel_att = torch.sum((S_t - S_t_previ)**2)
            S_t_previ = S_t
            
                
            for i in range(self.num_layers):
                # Attend over the last tau states of this layer. The spatial
                # history stores the layer INPUT; the temporal history stores the
                # layer OUTPUT.
                t_att = T_pre[i][-self.tau:]
                t_att = torch.stack(t_att, dim=0)
                s_att = S_pre[i][-self.tau:]
                s_att = torch.stack(s_att, dim=0)
                S_pre[i].append(S_t)
                if i < 2 and verbose:
                    print('T_t', len(T_t))
                    print('T_t[i], S_t, t_att, s_att', T_t[i].shape, S_t.shape, t_att.shape, s_att.shape)
                T_t[i], S_t = self.cell_list[i](T_t[i], S_t, t_att, s_att, S_pixel_att)
                T_pre[i].append(T_t[i])
                # S_pre[i].append(S_t)
            out = S_t
            if(verbose):
                print('S_t out', S_t.shape)
            # out = self.merge(torch.cat([T_t[-1], S_t], dim=1))
            frames_feature_decoded = []
            for i in range(len(self.decoders)):
                out = self.decoders[i](out)
                if(verbose):
                    print('S_t out', i, out.shape)
                # print("ok")
                if self.configs.model_mode == 'recall':
                    # print('unet')
                    # Skip connection from the encoder stage at the matching resolution.
                    out = out + frames_feature_encoded[-2 - i]
                    if(verbose):
                        print('S_t out unet', i, out.shape)
            # out = self.decoder(out)

            x_gen = self.srcnn(out)
            next_frames.append(x_gen)
            if(verbose):
                print('x_gen', x_gen.shape)
                print('len next_frames', len(next_frames))
        if(verbose):
            print('len next_frames FULL', len(next_frames))
        # (T-1, B, C, H, W) -> (B, T-1, C, H, W)
        next_frames = torch.stack(next_frames, dim=0)
        if(verbose):
            print('next_frames Tensor', next_frames.shape)
        next_frames = next_frames.permute(1, 0, 2, 3, 4).contiguous()
        if(verbose):
            print('next_frames Tensor Permuted', next_frames.shape)
        return next_frames

## DISCRIMINATOR

In [26]:
class FDU(nn.Module):
    """Encoder + stacked MAUCells that turn a clip into per-step feature maps.

    It uses the same encoder and MAUCell stack as ``RNN`` but has no decoder.
    For each of the first ``total_length - 1`` frames it emits the top
    layer's spatial output concatenated with its temporal state.
    """

    def __init__(self, num_layers, num_hidden, configs):
        super(FDU, self).__init__()

        self.configs = configs
        self.frame_channel = configs.img_channel
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.tau = configs.tau
        self.cell_mode = configs.cell_mode

        cell_list = []

        # Spatial size at which the recurrent cells operate.
        width = configs.img_width // configs.sr_size
        height = configs.img_height // configs.sr_size

        for i in range(num_layers):
            # NOTE(review): for i == 0 this reads num_hidden[-1]; consistent with
            # forward() only when all hidden sizes are equal.
            in_channel = num_hidden[i - 1]
            cell_list.append(
                MAUCell(in_channel, num_hidden[i], height, width, configs.filter_size,
                        configs.stride, self.tau, self.cell_mode)
            )
        self.cell_list = nn.ModuleList(cell_list)

        # Encoder: 1x1 projection to num_hidden[0], then n stride-2 downsamplings.
        n = int(math.log2(configs.sr_size))
        encoders = []
        encoder = nn.Sequential()
        encoder.add_module(name='encoder_t_conv{0}'.format(-1),
                           module=nn.Conv2d(in_channels=self.frame_channel,
                                            out_channels=self.num_hidden[0],
                                            stride=1,
                                            padding=0,
                                            kernel_size=1))
        encoder.add_module(name='relu_t_{0}'.format(-1),
                           module=nn.LeakyReLU(0.2))
        encoders.append(encoder)
        for i in range(n):
            encoder = nn.Sequential()
            encoder.add_module(name='encoder_t{0}'.format(i),
                               module=nn.Conv2d(in_channels=self.num_hidden[0],
                                                out_channels=self.num_hidden[0],
                                                stride=(2, 2),
                                                padding=(1, 1),
                                                kernel_size=(3, 3)
                                                ))
            encoder.add_module(name='encoder_t_relu{0}'.format(i),
                               module=nn.LeakyReLU(0.2))
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

    def forward(self, frames, verbose=False):
        """Encode ``frames`` step by step and return the stacked cell outputs.

        Args:
            frames: clip tensor indexed ``frames[:, t]`` with height/width on
                dims 3/4, i.e. ``(B, T, C, H, W)``.
            verbose: print intermediate shapes.

        Returns:
            Tensor ``(B, total_length - 1, C_S + C_T, H // sr_size, W // sr_size)``.
            It concatenates, along the channel axis, the top layer's spatial
            output and the last layer's temporal state at every step.
        """
        if verbose:
            # FDU takes no mask; the old code printed an undefined ``mask_true``
            # here and raised NameError whenever verbose=True.
            print('Frames: ', frames.shape)
        batch_size = frames.shape[0]
        height = frames.shape[3] // self.configs.sr_size
        width = frames.shape[4] // self.configs.sr_size
        next_frames = []
        next_memory = []
        T_t = []
        T_pre = []
        S_pre = []
        if verbose:
            print('Num Layers: ', self.num_layers)
            print('Num Hidden: ', self.num_hidden)
            print('TAU: ', self.tau)
        # Seed each layer's temporal/spatial history with tau zero tensors.
        for layer_idx in range(self.num_layers):
            tmp_t = []
            tmp_s = []
            if layer_idx == 0:
                in_channel = self.num_hidden[layer_idx]
            else:
                in_channel = self.num_hidden[layer_idx - 1]
            for i in range(self.tau):
                if verbose and i == 2:
                    print('tmp_t[1]', tmp_t[1].shape)
                tmp_t.append(torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device))
                tmp_s.append(torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device))
            T_pre.append(tmp_t)
            S_pre.append(tmp_s)

        if verbose:
            print('len T_pre', len(T_pre))
            print('len T_pre[1]', len(T_pre[1]))

        # Previous step's encoded features; uses in_channel left over from the loop above.
        S_t_previ = torch.zeros([batch_size, in_channel, height, width]).to(self.configs.device)

        for t in range(self.configs.total_length - 1):
            frames_feature = frames[:, t]

            if verbose:
                print('Len Encoders', len(self.encoders))

            for i in range(len(self.encoders)):
                frames_feature = self.encoders[i](frames_feature)
                if verbose:
                    print('Frames_feature', i, frames_feature.shape)
            if t == 0:
                for i in range(self.num_layers):
                    zeros = torch.zeros([batch_size, self.num_hidden[i], height, width]).to(self.configs.device)
                    T_t.append(zeros)

            S_t = frames_feature
            if verbose:
                print('S_t in', S_t.shape)

            # NOTE(review): scalar summed over the whole batch (mixes samples).
            S_pixel_att = torch.sum((S_t - S_t_previ)**2)
            S_t_previ = S_t

            for i in range(self.num_layers):
                # Spatial history stores the layer input; temporal history the layer output.
                t_att = torch.stack(T_pre[i][-self.tau:], dim=0)
                s_att = torch.stack(S_pre[i][-self.tau:], dim=0)
                S_pre[i].append(S_t)
                if i < 2 and verbose:
                    print('T_t', len(T_t))
                    print('T_t[i], S_t, t_att, s_att', T_t[i].shape, S_t.shape, t_att.shape, s_att.shape)
                T_t[i], S_t = self.cell_list[i](T_t[i], S_t, t_att, s_att, S_pixel_att)
                T_pre[i].append(T_t[i])

            next_frames.append(S_t)
            next_memory.append(T_t[-1])

        # (T-1, B, C, H, W) -> (B, T-1, C, H, W), then join along channels.
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)
        next_memory = torch.stack(next_memory, dim=0).permute(1, 0, 2, 3, 4)
        next_all = torch.cat([next_frames, next_memory], dim=2)
        return next_all

# **TRAIN TEST WRAPPER**

In [27]:
def train_wrapper(model, validation_split=0.3, random_seed=1000, validate_from_epoch=8):
    """Train `model` on a shuffled train/validation split of TimeSeriesDataset.

    Args:
        model: a ``Model`` instance. Its ``train`` step returns
            ``(next_frames, discriminator_loss, generator_loss)``; the two
            losses are logged and plotted here.
        validation_split: fraction of the dataset held out for validation.
        random_seed: seed passed to the *global* numpy RNG before shuffling
            the split. This also makes later global-numpy draws reproducible,
            such as the ones in ``schedule_sampling``.
        validate_from_epoch: ``evaluation_proper`` is run after every epoch
            whose index is >= this value (plus once before training starts).
    """
    begin = 0

    if args.pretrained_model_g:
        model.load(args.pretrained_model_g, args.pretrained_model_d)

    # DATASET
    dataset = TimeSeriesDataset(root_dir=configs.data_train_path, n_frames_input=10, n_frames_output=10)

    # Deterministic shuffled train/validation split over dataset indices.
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_loader = torch.utils.data.DataLoader(
        dataset, args.batch_size, sampler=SubsetRandomSampler(train_indices),
        num_workers=2, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, sampler=SubsetRandomSampler(val_indices))

    # model.train returns (frames, discriminator loss, generator loss).
    losses_d = []
    losses_g = []

    eta = args.sampling_start_value - begin * args.sampling_changing_rate
    itr = begin

    # Baseline validation before any training step.
    evaluation_proper(model, valid_loader, args, args.output_length)

    # The validation batch is only visualised by model.train on plot steps,
    # so fetch a fresh one there instead of re-creating the loader iterator
    # (and loading a batch) on every training step.
    val_batch = None
    for epoch in range(args.max_epoches):
        for x in tqdm(train_loader, total=len(train_loader)):
            batch_size = x.shape[0]  # (bs, frames, c, h, w)
            eta, mask = schedule_sampling(eta, itr, args.img_channel, batch_size)

            plot_step = itr % configs.plot_interval == 0
            if plot_step or val_batch is None:
                val_batch = next(iter(valid_loader))

            _, loss_d, loss_g = model.train(x, mask, itr, val_batch, epoch)

            if itr % configs.display_interval == 0:
                print('Step: ' + str(itr), 'T D loss: ' + str(loss_d), 'T G loss: ' + str(loss_g))

            losses_d.append(loss_d)
            losses_g.append(loss_g)

            if plot_step:
                fig, (ax_d, ax_g) = plt.subplots(2, 1, figsize=(13, 5))
                ax_d.plot(losses_d, 'r')
                ax_d.set_title('Loss (D)')
                ax_g.plot(losses_g, 'r')
                ax_g.set_title('Loss (G)')
                plt.show()
                plt.close(fig)

            if itr % args.snapshot_interval == 0 and itr > begin:
                model.save(itr)
            itr += 1

        if epoch >= validate_from_epoch:
            print('Validate:')
            evaluation_proper(model, valid_loader, args, args.output_length)


def test_wrapper(model, valid_loader):
    """Load pretrained generator/discriminator weights into `model` and evaluate it.

    ``Model.load`` takes a generator and a discriminator checkpoint path, so
    both are passed, matching ``train_wrapper``. Evaluation goes through
    ``evaluation_proper``, the same routine ``train_wrapper`` uses. The
    previous body called ``trainer.test(..., itr)``, but neither ``trainer``
    nor ``itr`` is defined in any visible cell.

    Returns:
        Whatever ``evaluation_proper`` returns.
    """
    model.load(args.pretrained_model_g, args.pretrained_model_d)
    return evaluation_proper(model, valid_loader, args, args.output_length)

# **MODEL FACTORY**

In [28]:
class Model(object):
    """Training/inference wrapper around the predictive network.

    Builds the generator selected by ``configs.model_name`` (only ``'mau'`` ->
    ``RNN`` is registered). When the notebook-level ``args.gan`` flag is set,
    it also builds an ``FDU`` discriminator with its own Adam optimizer.
    Provides checkpointing (``save``/``load``), one optimisation step
    (``train``) and inference (``test``).

    NOTE(review): ``__init__``/``train`` read the notebook-level ``args`` and
    ``configs`` globals in addition to ``self.configs``. Presumably these hold
    the same settings; confirm before reusing the class outside this notebook.
    """

    def __init__(self, configs):
        self.configs = configs
        self.patch_height = configs.img_height
        self.patch_width = configs.img_width
        self.patch_channel = configs.img_channel
        self.num_layers = configs.num_layers
        networks_map = {'mau': RNN}
        self.num_hidden = [configs.num_hidden] * configs.num_layers
        if configs.model_name not in networks_map:
            raise ValueError('Name of network unknown %s' % configs.model_name)
        network_cls = networks_map[configs.model_name]
        self.network = network_cls(self.num_layers, self.num_hidden, configs).to(configs.device)

        # Per-epoch weights looked up by epoch index in train() (clamped to the
        # last entry). They are currently only printed, not used in the loss.
        # Defined unconditionally so train() also works without the GAN.
        self.alpha = [1, 1, 1, 10, 100, 100, 100]
        self.beta = [1, 1, 1, 0.1, 0.001, 0.001, 0.001]

        # The discriminator only exists for GAN training. It is None otherwise,
        # so save()/load() can skip it.
        self.disc = None
        self.disc_optimizer = None
        if args.gan:
            self.disc = FDU(self.num_layers, self.num_hidden, configs).to(configs.device)
            self.disc_optimizer = torch.optim.Adam(self.disc.parameters(), lr=configs.lr)

        self.optimizer = Adam(self.network.parameters(), lr=configs.lr)
        # Stepped by train() every `delay_interval` iterations once
        # `sampling_stop_iter` is reached. The 0.9 fallback is an assumed
        # value; set configs.lr_decay to control the decay factor.
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=getattr(configs, 'lr_decay', 0.9))
        self.MSE_criterion = nn.MSELoss()
        self.L1_loss = nn.L1Loss()
        self.loss_bce = nn.BCEWithLogitsLoss()

    def save(self, itr):
        """Write ``model_g.ckpt-<itr>`` (and ``model_d.ckpt-<itr>`` if a
        discriminator exists) into ``configs.save_dir``, creating it if needed.

        Each file holds ``{'net_param': state_dict}``.
        """
        save_dir = self.configs.save_dir
        os.makedirs(save_dir, exist_ok=True)
        torch.save({'net_param': self.network.state_dict()},
                   os.path.join(save_dir, 'model_g.ckpt-' + str(itr)))
        disc = getattr(self, 'disc', None)
        if disc is not None:
            torch.save({'net_param': disc.state_dict()},
                       os.path.join(save_dir, 'model_d.ckpt-' + str(itr)))

    def load(self, pm_checkpoint_path_g, pm_checkpoint_path_d=None):
        """Restore generator (and, optionally, discriminator) weights.

        Args:
            pm_checkpoint_path_g: checkpoint written by ``save`` for the generator.
            pm_checkpoint_path_d: discriminator checkpoint. Loaded only when
                given and a discriminator exists (GAN mode).
        """
        device = torch.device(self.configs.device)
        print('Load predictive model:', pm_checkpoint_path_g)
        stats = torch.load(pm_checkpoint_path_g, map_location=device)
        self.network.load_state_dict(stats['net_param'])
        disc = getattr(self, 'disc', None)
        if pm_checkpoint_path_d and disc is not None:
            stats = torch.load(pm_checkpoint_path_d, map_location=device)
            disc.load_state_dict(stats['net_param'])

    @staticmethod
    def _plot_rollout(inputs, mask, truth, generated, n_cols, prefix):
        """Draw a 4 x n_cols grid: input, mask, ground truth, generated frames.

        Frames are assumed to be channel-first (C, H, W) tensors; they are
        permuted to (H, W, C) for ``imshow``. Titles are prefixed with
        `prefix` ('T' for the training batch, 'V' for validation). The last
        mask column is left blank.
        """
        fig, ax = plt.subplots(4, n_cols, figsize=(25, 10))
        for j in range(n_cols):
            ax[0][j].imshow(inputs[j].to('cpu').permute(1, 2, 0), cmap='gray')
            ax[0][j].set_title(prefix + ' Input')
            ax[0][j].axis('off')
        for j in range(n_cols):
            if j < n_cols - 1:
                ax[1][j].imshow(mask[j].to('cpu'))
                ax[1][j].set_title(prefix + ' Mask')
            ax[1][j].axis('off')
        for j in range(n_cols):
            ax[2][j].imshow(truth[j].to('cpu').permute(1, 2, 0), cmap='gray')
            ax[2][j].set_title(prefix + ' Ground Truth')
            ax[2][j].axis('off')
        for j in range(n_cols):
            ax[3][j].imshow(generated[j].detach().to('cpu').permute(1, 2, 0), cmap='gray')
            ax[3][j].set_title(prefix + ' Generated')
            ax[3][j].axis('off')
        return fig

    def train(self, data, mask, itr, val, ei):
        """Run one optimisation step (plus a discriminator step in GAN mode).

        Args:
            data: training batch of frame sequences.
            mask: scheduled-sampling mask passed to the network.
            itr: global iteration counter (drives plotting and LR decay).
            val: validation batch; only converted and visualised on plot steps.
            ei: epoch index (selects the alpha/beta entry that gets printed).

        Returns:
            ``(next_frames, loss_d, loss_gen)``. ``loss_d`` is the
            discriminator loss as a float (0.0 without GAN). ``loss_gen`` is
            the total generator loss as a float.
        """
        verbose = self.configs.verbose
        self.network.train()
        frames_tensor = torch.FloatTensor(data).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        if verbose:
            print('FT', frames_tensor.shape)
            print('MT', mask_tensor.shape)

        next_frames = self.network(frames_tensor, mask_tensor)
        if verbose:
            print('Next Frames', next_frames.shape)

        # The network predicts frames 1..T-1 from frames 0..T-2.
        ground_truth = frames_tensor[:, 1:]
        if verbose:
            print('Ground', ground_truth.shape)

        loss_d = 0.0
        loss_g = 0.0
        if args.gan:
            d_fake = self.disc(next_frames.detach())
            d_real = self.disc(ground_truth)
            loss_d = (self.loss_bce(d_real, torch.ones_like(d_real))
                      + self.loss_bce(d_fake, torch.zeros_like(d_fake)))

        plot_step = itr % configs.plot_interval == 0
        if plot_step:
            print('Epoch:', ei)
            n_in = configs.input_length
            with torch.no_grad():
                self.network.eval()
                self._plot_rollout(frames_tensor[0][0:n_in], mask_tensor[0],
                                   frames_tensor[0][n_in:], next_frames[0][n_in - 1:],
                                   n_in, 'T')
                # Validation rollout with an all-zero mask (no ground-truth frames fed).
                val_tensor = torch.FloatTensor(val).to(self.configs.device)
                val_mask = torch.zeros_like(mask_tensor[0]).unsqueeze(0).to(configs.device)
                val_pred = self.network(val_tensor, val_mask)
                self._plot_rollout(val_tensor[0][0:n_in], val_mask[0],
                                   val_tensor[0][n_in:], val_pred[0][n_in - 1:],
                                   n_in, 'V')
                self.network.train()

        if args.gan:
            self.disc_optimizer.zero_grad()
            loss_d.backward()
            self.disc_optimizer.step()
            # Adversarial term: the generator wants the discriminator to say "real".
            d_fake = self.disc(next_frames)
            loss_g = self.loss_bce(d_fake, torch.ones_like(d_fake))

        loss_l1 = self.L1_loss(next_frames, ground_truth)
        loss_l2 = self.MSE_criterion(next_frames, ground_truth)

        alp = self.alpha[min(len(self.alpha) - 1, ei)]
        bet = self.beta[min(len(self.beta) - 1, ei)]
        if plot_step:
            print('Aplha: ', alp)
            print('Beta: ', bet)

        if args.gan:
            loss_gen = loss_g * 1e-6 + loss_l2 + loss_l1 * 0.1
        else:
            loss_gen = loss_l2

        self.optimizer.zero_grad()
        loss_gen.backward()
        self.optimizer.step()

        if itr >= self.configs.sampling_stop_iter and itr % self.configs.delay_interval == 0:
            self.scheduler.step()
            print('LR decay to:%.8f' % self.optimizer.param_groups[0]['lr'])

        loss_d_value = loss_d.item() if torch.is_tensor(loss_d) else float(loss_d)
        return next_frames, loss_d_value, loss_gen.item()

    def test(self, data, mask):
        """Run the network in eval mode without autograd; return a numpy array."""
        self.network.eval()
        frames_tensor = torch.FloatTensor(data).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        with torch.no_grad():
            next_frames = self.network(frames_tensor, mask_tensor)
        return next_frames.detach().cpu().numpy()