In [None]:
# Environment sanity check: print the python3 found on the shell's PATH.
!which python3

In [None]:
import cv2
import torchinfo
import numpy as np
from torch.optim import Adam
from torch.utils.data import DataLoader

In [None]:
import io
import imageio
from tqdm import tqdm
from avi_r import AVIReader
import matplotlib.pyplot as plt
from ipywidgets import widgets, HBox

In [None]:
import os
import torch
import torch.nn as nn

In [None]:
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM cell with peephole (Hadamard-product) connections.

    One Conv2d over the channel-wise concatenation of the input frame and the previous
    hidden state produces the four gate pre-activations. The peephole weights
    W_ci / W_cf / W_co are multiplied element-wise with the cell state.

    Args:
        in_channels: channels of the input frame X.
        out_channels: channels of the hidden and cell states.
        kernel_size, padding: passed to nn.Conv2d. They must preserve the spatial size,
            because the peephole weights have shape (out_channels, *frame_size).
        activation: "tanh" or "relu". Applied to the candidate cell state and to the
            cell state before the output gate.
        frame_size: (height, width) of the frames.

    Raises:
        ValueError: if `activation` is not "tanh" or "relu".
    """

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()  

        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        else:
            # Fail fast here instead of with an AttributeError on the first forward pass.
            raise ValueError(f"activation must be 'tanh' or 'relu', got {activation!r}")
        
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # A single convolution computes all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels, 
            out_channels=4 * out_channels, 
            kernel_size=kernel_size, 
            padding=padding)           

        # Weights for the Hadamard (peephole) products. torch.Tensor(shape) returns
        # uninitialized memory that may contain NaN/Inf, so start from zeros and let
        # training learn them. Registration order is kept, so checkpoints stay compatible.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: input frames, (batch, in_channels, height, width).
            H_prev: previous hidden state, (batch, out_channels, height, width).
            C_prev: previous cell state, same shape as H_prev.

        Returns:
            (H, C): the new hidden state and cell state, shaped like H_prev and C_prev.
        """
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split the 4*out_channels output into input, forget, candidate and output parts.
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current cell state
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        # The output-gate peephole looks at the *updated* cell state.
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current hidden state
        H = output_gate * self.activation(C)

        return H, C

In [None]:
class ConvLSTM(nn.Module):
    """Unroll one ConvLSTMCell over the time axis of a frame sequence.

    Constructor arguments are forwarded unchanged to ConvLSTMCell.

    forward(X) takes X of shape (batch, channels, seq_len, height, width) and returns the
    hidden state at every step, shaped (batch, out_channels, seq_len, height, width).
    """

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

    def forward(self, X):
        # X is a frame sequence (batch_size, num_channels, seq_len, height, width)
        batch_size, _, seq_len, height, width = X.size()

        # Allocate on X's device/dtype. Using a notebook-global `device` breaks if the
        # module is moved elsewhere or `device` is not defined yet.
        output = torch.zeros(batch_size, self.out_channels, seq_len,
                             height, width, device=X.device, dtype=X.dtype)

        # Zero initial hidden state and cell state
        H = torch.zeros(batch_size, self.out_channels,
                        height, width, device=X.device, dtype=X.dtype)
        C = torch.zeros_like(H)

        # Unroll over time steps, feeding each step's (H, C) into the next
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H

        return output

In [None]:
class Seq2Seq(nn.Module):
    """Stack of ConvLSTM layers, each followed by BatchNorm3d, plus a Conv2d frame head.

    forward(X) maps a frame sequence (batch, num_channels, seq_len, height, width) to one
    predicted frame (batch, num_channels, height, width) with values in (0, 1). Only the
    last time step's features are used for the prediction.

    The submodule names (convlstm{l}, batchnorm{l}, conv) form the state_dict keys, so
    they must stay stable for saved checkpoints to load.
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding, 
    activation, frame_size, num_layers):

        super(Seq2Seq, self).__init__()

        self.sequential = nn.Sequential()

        # First layer: consumes the raw input channels
        self.sequential.add_module(
            "convlstm1", ConvLSTM(
                in_channels=num_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding, 
                activation=activation, frame_size=frame_size)
        )
        self.sequential.add_module(
            "batchnorm1", nn.BatchNorm3d(num_features=num_kernels)
        ) 

        # Remaining layers: num_kernels -> num_kernels
        for l in range(2, num_layers+1):
            self.sequential.add_module(
                f"convlstm{l}", ConvLSTM(
                    in_channels=num_kernels, out_channels=num_kernels,
                    kernel_size=kernel_size, padding=padding, 
                    activation=activation, frame_size=frame_size)
                )
            self.sequential.add_module(
                f"batchnorm{l}", nn.BatchNorm3d(num_features=num_kernels)
                ) 

        # Convolutional head mapping the last hidden state back to frame channels
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        # (batch, num_kernels, seq_len, H, W) after all ConvLSTM + BatchNorm layers
        output = self.sequential(X)

        # Predict a single frame from the last time step only
        output = self.conv(output[:, :, -1])

        # torch.sigmoid avoids instantiating a new nn.Sigmoid module on every call
        return torch.sigmoid(output)

In [None]:
def rgb2gray(rgb):
    """Convert RGB(A) pixel data to grayscale luma using ITU-R BT.601 weights.

    Args:
        rgb: array whose last axis holds the colour channels. Only the first three
            entries (R, G, B) are used, so an alpha channel is ignored.

    Returns:
        Float array with the channel axis reduced away.
    """
    # BT.601 luma weights. They sum to 1.0, so white stays white (blue is 0.114, not 0.144).
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Show which device the model and the data batches will be placed on.
device

In [None]:
# Load the pre-extracted video frames. Later cells index this array as
# (video, frame, height, width) and scale pixels by 1/255 -- presumably grayscale uint8
# frames of 120x160 (matching the model's frame_size); TODO confirm.
WalkingDataset = np.load('train_on_beach.npy')
print(WalkingDataset.shape)

In [None]:
# Shuffle the videos (first axis) in place with a fixed seed so the train/test split
# below is reproducible across kernel restarts.
# NOTE: re-running only this cell shuffles the already-shuffled array again.
SHUFFLE_SEED = 0
np.random.default_rng(SHUFFLE_SEED).shuffle(WalkingDataset)

In [None]:
# Split the shuffled videos. `tr_data` videos starting at `tr_idx` are used for training,
# and every video after that slice is used for testing. Only the first `num_frames`
# frames of each video are kept.
tr_idx = 0
tr_data = 25
num_frames = 100

train_data = WalkingDataset[tr_idx : tr_idx + tr_data, :num_frames]
test_data = WalkingDataset[tr_idx + tr_data :, :num_frames]

print(f"Train Data shape: {train_data.shape}")
print(f"Test Data shape: {test_data.shape}")

In [None]:
# Shell out to nvidia-smi to inspect GPU state (utilization / memory) before training.
!nvidia-smi

In [None]:
def collate(batch):
    """Collate training videos into (input frames, target frame) tensors.

    Args:
        batch: list of per-video frame arrays from the DataLoader, each indexed
            (frame, height, width). Pixels are presumably uint8 in [0, 255]; TODO confirm.

    Returns:
        (inputs, target), both scaled to [0, 1] and placed on the global `device`:
        inputs is (batch, 1, 40, H, W) with frames 15..54;
        target is (batch, 1, H, W) with frame 55.
    """
    # Stack into one ndarray first; building a tensor from a list of ndarrays is very slow.
    frames = torch.from_numpy(np.stack(batch)).to(device=device, dtype=torch.float32)
    # Add the channel dim -> (batch, 1, frames, H, W), then scale pixels to [0, 1].
    frames = frames.unsqueeze(1) / 255.0
    return frames[:, :, 15:55], frames[:, :, 55]

In [None]:
# Training Data Loader
train_loader = DataLoader(train_data, shuffle=True, 
                        batch_size=1, collate_fn=collate)

In [None]:
# Preview the input frames of one training batch as animated GIFs.
fps = 20

# Get a batch
inp, _ = next(iter(train_loader))

# Reverse process before displaying (undo the /255 scaling applied in collate)
inp = inp.cpu().numpy() * 255.0     

# squeeze(1) drops the channel dim -> (batch, frames, H, W); show each clip
for video in inp.squeeze(1)[:]:          # Loop over videos
    with io.BytesIO() as gif:
        # NOTE(review): astype(np.uint8) truncates, so e.g. 254.9999 becomes 254
        imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=fps)
        display(HBox([widgets.Image(value=gif.getvalue())]))

In [None]:
# The input video frames are grayscale, thus single channel
model = Seq2Seq(num_channels = 1, num_kernels = 64, 
                kernel_size = (3, 3), padding = (1, 1), activation="relu", 
                frame_size = (120, 160), num_layers = 3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCELoss(reduction = 'sum')

In [None]:
# Train for `num_epochs` epochs. tr_loss collects each epoch's mean per-video training loss.
num_epochs = 50
tr_loss = []

for epoch in range(1, num_epochs+1):
    model.train()
    train_loss = 0
    # The context manager closes each epoch's progress bar (it was previously never closed).
    with tqdm(total=len(train_loader), position=0, leave=True,
              bar_format='{l_bar}{bar:60}{r_bar}{bar:-10b}') as pbar:
        for inp, target in train_loader:
            output = model(inp)
            # criterion uses reduction='sum': total BCE over every pixel in the batch
            loss = criterion(output.flatten(), target.flatten())
            loss.backward()
            optim.step()
            optim.zero_grad()
            train_loss += loss.item()
            pbar.update(1)
    # Average over videos (not batches), since each batch loss is a sum
    train_loss /= len(train_loader.dataset)

    print("Epoch:{} Training Loss:{:.2f}\n".format(epoch, train_loss))
    tr_loss.append(train_loss)

In [None]:
# Plot the per-epoch training loss. The x-axis comes from tr_loss itself, so the plot
# still works if training stopped before num_epochs epochs completed.
epochs_run = np.arange(1, len(tr_loss) + 1)

plt.figure(figsize = (12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_run, tr_loss)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.show()

In [None]:
# Save the trained weights (state_dict only) to "<model_dir>/tmp.pth". With model_dir = ""
# this resolves to "tmp.pth" in the kernel's current working directory.
model_dir = ""
model_path = os.path.join(model_dir, "tmp"+".pth")
torch.save(model.state_dict(), model_path)

In [None]:
# To reload previously saved weights instead of retraining, uncomment:
# model.load_state_dict(torch.load(model_path))

In [None]:
def collate_test(batch):
    """Collate test videos into (context frames tensor, raw target frames array).

    Args:
        batch: list of per-video frame arrays from the DataLoader, each indexed
            (frame, height, width).

    Returns:
        frames: (batch, 1, 55, H, W) float32 tensor on the global `device`, holding
            frames 5..59 scaled to [0, 1].
        target: (batch, 15, H, W) ndarray holding the unscaled frames 45..59.

    Index i of `frames` is video frame 5 + i. So the 40-frame window frames[:, :, t:t+40]
    covers frames 5+t..44+t and immediately precedes target[:, t] (frame 45+t).
    """
    videos = np.stack(batch)

    # Frames 45..59 (15 frames) are the prediction targets; kept unscaled for display
    target = videos[:, 45 : 60]

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    frames = torch.from_numpy(videos).to(dtype=torch.float32).unsqueeze(1) / 255.0
    return frames[:, :, 5:60].to(device), target

In [None]:
# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, batch_size=1, collate_fn=collate_test)

In [None]:
# Predict the held-out videos with a sliding window: target frame t is predicted from the
# 40 ground-truth frames that precede it (see collate_test for the frame indexing).
final_targets = []
final_outputs = []

# Inference mode: BatchNorm uses its running statistics and does not update them from test data.
model.eval()
with torch.no_grad():  # prediction only, no autograd graph needed
    # Iterate the loader once so every test video is evaluated exactly once.
    # Calling next(iter(test_loader)) each step re-creates a shuffled iterator,
    # which samples videos with repetition.
    for i, (batch, target) in enumerate(test_loader):
        print(i + 1, target.shape)
        # Initialize output sequence (uint8 pixels, same shape as the target frames)
        output = np.zeros(target.shape, dtype = np.uint8)

        # Loop over timesteps
        for timestep in range(target.shape[1]):
            inp = batch[:, :, timestep : 40 + timestep]
            output[:, timestep] = model(inp).squeeze(1).cpu().numpy().astype(float) * 255.0
        final_targets.append(target)
        final_outputs.append(output)

In [None]:
# Show each test video's ground-truth frames (left) next to the model's predictions (right) as GIFs.
fps = 20
for video_num, (target, output) in enumerate(zip(final_targets, final_outputs), start=1):

    # Drop the singleton batch dimension so imageio receives a (frames, H, W) stack
    target = np.array(target, dtype = 'uint8').squeeze()
    output = np.array(output, dtype = 'uint8').squeeze()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, target, "GIF", fps = fps)
        target_gif = gif.getvalue()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, output, "GIF", fps = fps)
        output_gif = gif.getvalue()

    print(f"\nTest video: {video_num}")
    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

In [None]:
# Run the same sliding-window prediction on the training videos, to compare fit vs. generalization.
test_on_train_loader = DataLoader(train_data,shuffle=True, 
                         batch_size=1, collate_fn=collate_test)

train_final_targets = []
train_final_outputs = []

# Inference mode: BatchNorm uses running statistics; no autograd graph is built.
model.eval()
with torch.no_grad():
    # Iterate the loader once so every training video is evaluated exactly once
    # (next(iter(loader)) per step would resample with repetition).
    for i, (batch, target) in enumerate(test_on_train_loader):
        print(i + 1, target.shape)
        # Initialize output sequence (uint8 pixels, same shape as the target frames)
        output = np.zeros(target.shape, dtype = np.uint8)

        # Loop over timesteps; convert to numpy explicitly before writing into the uint8 buffer
        for timestep in range(target.shape[1]):
            inp = batch[:, :, timestep : 40 + timestep]
            output[:, timestep] = model(inp).squeeze(1).cpu().numpy().astype(float) * 255.0
        train_final_targets.append(target)
        train_final_outputs.append(output)

In [None]:
# Show each training video's ground-truth frames (left) next to the model's predictions (right).
fps = 20
for video_num, (target, output) in enumerate(zip(train_final_targets, train_final_outputs), start=1):

    # Drop the singleton batch dimension so imageio receives a (frames, H, W) stack
    target = np.array(target, dtype = 'uint8').squeeze()
    output = np.array(output, dtype = 'uint8').squeeze()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, target, "GIF", fps = fps)
        target_gif = gif.getvalue()

    with io.BytesIO() as gif:
        imageio.mimsave(gif, output, "GIF", fps = fps)
        output_gif = gif.getvalue()

    # These are training videos, so label them as such (previously mislabelled "Test video")
    print(f"\nTrain video: {video_num}")
    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

In [None]:
# Optional: persist the predictions and targets (train + test) for offline comparison.
# With batch_size=1 each list entry is (1, 15, H, W), so squeeze(1) gives (num_videos, 15, H, W).
# target_array = np.array(train_final_targets+final_targets).squeeze(1)
# output_array = np.array(train_final_outputs+final_outputs).squeeze(1)

# np.save('output_model_1.npy', output_array)
# np.save('target_model_1.npy', target_array)