In [None]:
# Environment check: print the path of the `python3` found on the shell's PATH.
!which python3

In [None]:
import cv2
import torchinfo
import numpy as np
from torch.optim import Adam
from torch.utils.data import DataLoader

In [None]:
import io
import imageio
from tqdm import tqdm
from avi_r import AVIReader
import matplotlib.pyplot as plt
from ipywidgets import widgets, HBox

In [None]:
import os
import torch
import torch.nn as nn

In [None]:
class ConvLSTMCell(nn.Module):
    """One time step of a convolutional LSTM with peephole connections.

    A single convolution over the channel-wise concatenation of the input and
    the previous hidden state produces the pre-activations of the input gate,
    forget gate, candidate cell state and output gate. The three peephole
    weights (``W_ci``, ``W_cf``, ``W_co``) are multiplied element-wise with
    the cell state.

    Args:
        in_channels: Number of channels of the input ``X``.
        out_channels: Number of channels of the hidden and cell states.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``. It must keep the spatial
            size equal to ``frame_size`` for the peephole products to
            broadcast.
        activation: Either ``"tanh"`` or ``"relu"``.
        frame_size: Spatial size ``(height, width)`` of the peephole weights.

    Raises:
        ValueError: If ``activation`` is not ``"tanh"`` or ``"relu"``.
    """

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()  

        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        else:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass.
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'tanh' or 'relu'")

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # One convolution yields all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels, 
            out_channels=4 * out_channels, 
            kernel_size=kernel_size, 
            padding=padding)           

        # Peephole weights for the Hadamard products. They start at zero:
        # torch.Tensor(...) would hand back uninitialised memory, which can
        # contain NaN/Inf and makes a freshly built model non-deterministic.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: Input for this step; concatenated with ``H_prev`` along dim 1,
                so all other dimensions must match.
            H_prev: Hidden state of the previous step.
            C_prev: Cell state of the previous step; must broadcast against
                the ``(out_channels, *frame_size)`` peephole weights.

        Returns:
            Tuple ``(H, C)`` with the new hidden state and the new cell state.
        """

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split the channels into input, forget, candidate and output parts.
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev )
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev )

        # Current cell state
        C = forget_gate*C_prev + input_gate * self.activation(C_conv)

        # The output gate peeks at the *updated* cell state.
        output_gate = torch.sigmoid(o_conv + self.W_co * C )

        # Current hidden state
        H = output_gate * self.activation(C)

        return H, C

In [None]:
class ConvLSTM(nn.Module):
    """Unrolls a single ``ConvLSTMCell`` over the time axis of a sequence.

    Args:
        in_channels: Number of channels of each input frame.
        out_channels: Number of channels of the hidden/cell state and of the
            returned sequence.
        kernel_size, padding, activation, frame_size: Forwarded unchanged to
            ``ConvLSTMCell``.
    """

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every time step of ``X``.

        Args:
            X: Frame sequence of shape
                ``(batch_size, num_channels, seq_len, height, width)``.

        Returns:
            Tensor of shape ``(batch_size, out_channels, seq_len, height,
            width)`` holding the hidden state after every time step.
        """

        # Get the dimensions
        batch_size, _, seq_len, height, width = X.size()

        # Allocate on the input's device rather than on a notebook-level
        # `device` global, so the layer works wherever the data lives and
        # does not depend on a cell that runs later in the notebook.
        target_device = X.device

        # Initialize output
        output = torch.zeros(batch_size, self.out_channels, seq_len, 
        height, width, device=target_device)
        
        # Initialize hidden state
        H = torch.zeros(batch_size, self.out_channels, 
        height, width, device=target_device)

        # Initialize cell state
        C = torch.zeros(batch_size, self.out_channels, 
        height, width, device=target_device)

        # Unroll over time steps
        for time_step in range(seq_len):

            H, C = self.convLSTMcell(X[:,:,time_step], H, C)

            output[:,:,time_step] = H

        return output

In [None]:
class Seq2Seq(nn.Module):
    """Stacked ConvLSTM network that predicts a single frame.

    The body is an ``nn.Sequential`` of ``convlstm{k}`` / ``batchnorm{k}``
    pairs (k = 1, 2, ...). A ``Conv2d`` head is applied to the last time step
    of the body's output and squashed with a sigmoid.

    Args:
        num_channels: Channels of the input frames and of the predicted frame.
        num_kernels: Channels produced by every ConvLSTM layer.
        kernel_size, padding: Used by every ConvLSTM layer and by the head.
        activation, frame_size: Forwarded to every ``ConvLSTM`` layer.
        num_layers: Number of ConvLSTM/BatchNorm pairs; the first pair is
            always created.
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
                 activation, frame_size, num_layers):
        super().__init__()

        # Only the first layer sees the raw frames; all later layers consume
        # the num_kernels feature maps produced by their predecessor.
        layer_inputs = [num_channels] + [num_kernels] * max(num_layers - 1, 0)

        stack = nn.Sequential()
        for index, layer_in in enumerate(layer_inputs, start=1):
            recurrent = ConvLSTM(
                in_channels=layer_in, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding,
                activation=activation, frame_size=frame_size)
            stack.add_module(f"convlstm{index}", recurrent)
            stack.add_module(
                f"batchnorm{index}", nn.BatchNorm3d(num_features=num_kernels))
        self.sequential = stack

        # Convolutional head that maps the features back to frame channels
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        """Return the sigmoid-activated prediction for the frame after ``X``."""
        features = self.sequential(X)

        # Only the final time step is used for the prediction
        last_step = features[:, :, -1]

        return torch.sigmoid(self.conv(last_step))

In [None]:
def rgb2gray(rgb):
    """Convert RGB(A) pixel data to grayscale luma.

    Args:
        rgb: Array-like whose last axis holds at least the R, G and B
            components; any further components (e.g. alpha) are ignored.

    Returns:
        Array with the last axis removed, holding
        ``0.299 * R + 0.587 * G + 0.114 * B`` for every pixel.
    """
    # ITU-R BT.601 luma weights. They sum to 1, so a white pixel keeps its
    # value (the previous blue weight of 0.144 pushed white above full scale).
    luma_weights = [0.299, 0.587, 0.114]
    return np.dot(np.asarray(rgb)[..., :3], luma_weights)

In [None]:
# Use the GPU when CUDA reports one as available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Show which device was selected.
device

In [None]:
# Load the clip array from the working directory and report its shape.
# NOTE(review): layout is presumably (num_clips, num_frames, height, width) — confirm.
WalkingDataset = np.load('train_on_beach.npy')
print(WalkingDataset.shape)

In [None]:
# Shuffle the clips along the first axis with a fixed seed, so that a fresh
# "Restart & Run All" always evaluates (and saves) the same clips.
# The shuffle is in place: re-running only this cell permutes the already
# shuffled array again.
np.random.default_rng(42).shuffle(WalkingDataset)

In [None]:
# Keep indices 5..99 along axis 1 of every clip. Basic slicing returns a view,
# so `dataset` shares memory with `WalkingDataset`.
# NOTE(review): axis 1 is presumably the frame axis (the roll-out cell indexes
# it as frames); the reason for dropping the first five entries is not
# visible here — confirm.
dataset = WalkingDataset[:,5:100]

In [None]:
# Network hyper-parameters. The input video frames are grayscale, hence a
# single channel.
model_config = dict(
    num_channels=1,
    num_kernels=64,
    kernel_size=(3, 3),
    padding=(1, 1),
    activation="relu",
    frame_size=(120, 160),
    num_layers=3,
)
model = Seq2Seq(**model_config).to(device)

# Optimiser over all model parameters
optim = Adam(model.parameters(), lr=1e-4)

# Binary cross entropy summed over all elements; BCELoss expects targets in [0, 1]
criterion = nn.BCELoss(reduction='sum')

In [None]:
# Restore the trained weights. `map_location=device` remaps the stored tensors
# onto the device selected above, so a checkpoint written on a GPU machine
# also loads on a CPU-only one.
# TODO(review): this is an absolute path on one workstation; move it to a
# configurable location before sharing the notebook.
model.load_state_dict(
    torch.load(
        "/home/staditya/Desktop/Pushkal/IVP Project/ConvLSTM/tmp.pth",
        map_location=device,
    )
)

In [None]:
def collate_test(batch, num_input_frames=40, target_device=None):
    """Collate clips into a model input tensor and a target frame array.

    Args:
        batch: Sequence of equally shaped clips, each indexed as
            ``(frame, height, width)`` with pixel values on a 0-255 scale.
        num_input_frames: Number of leading frames of every clip that form
            the model input.
        target_device: Device the input tensor is moved to. ``None`` uses the
            notebook-level ``device``.

    Returns:
        Tuple ``(inputs, target)``: ``inputs`` is a float32 tensor of shape
        ``(batch, 1, <=num_input_frames, height, width)`` scaled to [0, 1] on
        ``target_device``; ``target`` is a NumPy array holding the unscaled
        last frame of every clip.
    """
    if target_device is None:
        target_device = device

    # Stack once and reuse: building a tensor straight from a list of
    # ndarrays is the slow path PyTorch warns about.
    stacked = np.array(batch)

    # The last frame of every clip is the target
    target = stacked[:, -1]

    # Add channel dim, scale pixels between 0 and 1
    frames = torch.as_tensor(stacked, dtype=torch.float32).unsqueeze(1) / 255.0

    return frames[:, :, :num_input_frames].to(target_device), target

In [None]:
# Autoregressive roll-out: for each evaluated clip the model predicts `horizon`
# frames one at a time, and every prediction is fed back in as input for the
# following steps, so errors compound over the horizon.
a, b, c, d = [], [], [], []

context_len = 40                     # frames fed to the model per prediction
horizon = 10                         # frames predicted per clip
num_clips = min(20, len(dataset))    # do not index past the available clips
fps = 5

# Inference only: eval() makes BatchNorm3d use (and stop updating) its running
# statistics, and no_grad() below avoids building an autograd graph.
model.eval()

for data_i in range(num_clips):
    outputs = []
    ground_truths = []

    with torch.no_grad():
        for i in range(horizon):
            # Sliding window: real frames i..context_len-1 followed by the
            # i frames predicted so far (always context_len frames in total).
            real_frames = dataset[data_i, i : context_len]

            if outputs:
                predicted_frames = np.array(outputs).squeeze(1)
                new_block = np.concatenate((real_frames, predicted_frames), axis = 0)
            else:
                new_block = real_frames

            target = dataset[data_i, context_len + i]
            ground_truths.append(target)

            # collate_test treats the last frame of a clip as the target, so
            # append it; it is not part of the model input.
            new_block = np.concatenate((new_block, [target]), axis = 0)

            # Equivalent to pulling one batch from a batch_size=1 DataLoader.
            batch, _ = collate_test([new_block])

            inp = batch[:, :, : context_len]

            # Back to the 0-255 scale so the prediction can be fed in again.
            output = (model(inp).squeeze(1).cpu().detach().numpy().astype(float)) * 255.0

            outputs.append(output)

    ground_truths = np.array(ground_truths, dtype = np.uint8)
    outputs = np.array(outputs, dtype = np.uint8).squeeze()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, ground_truths, "GIF", fps = fps)
        ground_truth_gif = gif.getvalue()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, outputs, "GIF", fps = fps)
        prediction_gif = gif.getvalue()

    # Ground truth on the left, prediction on the right.
    print(f"\nTest video: {data_i + 1}")
    display(HBox([widgets.Image(value=ground_truth_gif), 
                  widgets.Image(value=prediction_gif)]))

    # NOTE(review): these are overwritten on every pass, so after the loop
    # (a, b) hold the last even-indexed clip and (c, d) the last odd-indexed
    # one — confirm that keeping only two clips is intended.
    if data_i % 2 == 0:
        a = ground_truths
        b = outputs
    else:
        c = ground_truths
        d = outputs

In [None]:
# Write the kept ground-truth / prediction arrays to the working directory.
saved_arrays = (
    ('gt1.npy', a),
    ('gt2.npy', c),
    ('compound_effect_1.npy', b),
    ('compound_effect_2.npy', d),
)
for file_name, frames in saved_arrays:
    np.save(file_name, frames)