In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import io
import imageio
from ipywidgets import widgets, HBox

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class ConvLSTMCell(nn.Module):

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()  

        if activation == "tanh":
            self.activation = torch.tanh 
        elif activation == "relu":
            self.activation = torch.relu
        
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels, 
            out_channels=4 * out_channels, 
            kernel_size=kernel_size, 
            padding=padding)           

        # Initialize weights for Hadamard Products
        self.W_ci = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.Tensor(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.Tensor(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev )
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev )

        # Current Cell output
        C = forget_gate*C_prev + input_gate * self.activation(C_conv)

        output_gate = torch.sigmoid(o_conv + self.W_co * C )

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C

In [5]:
class ConvLSTM(nn.Module):

    def __init__(self, in_channels, out_channels, 
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels, 
        kernel_size, padding, activation, frame_size)

    def forward(self, X):

        # X is a frame sequence (batch_size, num_channels, seq_len, height, width)

        # Get the dimensions
        batch_size, _, seq_len, height, width = X.size()

        # Initialize output
        output = torch.zeros(batch_size, self.out_channels, seq_len, 
        height, width, device=device)
        
        # Initialize Hidden State
        H = torch.zeros(batch_size, self.out_channels, 
        height, width, device=device)

        # Initialize Cell Input
        C = torch.zeros(batch_size,self.out_channels, 
        height, width, device=device)

        # Unroll over time steps
        for time_step in range(seq_len):

            H, C = self.convLSTMcell(X[:,:,time_step], H, C)

            output[:,:,time_step] = H

        return output

In [6]:
class Seq2Seq(nn.Module):

    def __init__(self, num_channels, num_kernels, kernel_size, padding, 
    activation, frame_size, num_layers):

        super(Seq2Seq, self).__init__()

        self.sequential = nn.Sequential()

        # Add First layer (Different in_channels than the rest)
        self.sequential.add_module(
            "convlstm1", ConvLSTM(
                in_channels=num_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding, 
                activation=activation, frame_size=frame_size)
        )

        self.sequential.add_module(
            "batchnorm1", nn.BatchNorm3d(num_features=num_kernels)
        ) 

        # Add rest of the layers
        for l in range(2, num_layers+1):

            self.sequential.add_module(
                f"convlstm{l}", ConvLSTM(
                    in_channels=num_kernels, out_channels=num_kernels,
                    kernel_size=kernel_size, padding=padding, 
                    activation=activation, frame_size=frame_size)
                )
                
            self.sequential.add_module(
                f"batchnorm{l}", nn.BatchNorm3d(num_features=num_kernels)
                ) 

        # Add Convolutional Layer to predict output frame
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):

        # Forward propagation through all the layers
        output = self.sequential(X)

        # Return only the last output frame
        output = self.conv(output[:,:,-1])
        
        return nn.Sigmoid()(output)

DataLoaders


In [9]:
# Load Data as Numpy Array
#!wget http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
MovingMNIST = np.load('mnist_test_seq.npy').transpose(1, 0, 2, 3)

# Shuffle Data
np.random.shuffle(MovingMNIST)

# Train, Test, Validation splits
train_data = MovingMNIST[:8000]         
val_data = MovingMNIST[8000:9000]       
test_data = MovingMNIST[9000:10000]     

def collate(batch):

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)     
    batch = batch / 255.0                        
    batch = batch.to(device)                     

    # Randomly pick 10 frames as input, 11th frame is target
    rand = np.random.randint(10,20)                     
    return batch[:,:,rand-10:rand], batch[:,:,rand]     


# Training Data Loader
train_loader = DataLoader(train_data, shuffle=True, 
                        batch_size=16, collate_fn=collate)

# Validation Data Loader
val_loader = DataLoader(val_data, shuffle=True, 
                        batch_size=16, collate_fn=collate)

Visualize the data

In [10]:
# Get a batch
input, _ = next(iter(val_loader))

# Reverse process before displaying
input = input.cpu().numpy() * 255.0     

for video in input.squeeze(1)[:3]:          # Loop over videos
    with io.BytesIO() as gif:
        imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=240)
        display(HBox([widgets.Image(value=gif.getvalue())]))

  batch = torch.tensor(batch).unsqueeze(1)


HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x04\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x07\x07\x07\x08\x08\x08\n\n\n\x0b\x0b\x0…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x00\x00\x00\x01\x01\x01\x03\x03\x03\x04\x04\x04\x05\…

Instantiating the model, optimizer, and loss function

In [11]:

# The input video frames are grayscale, thus single channel
model = Seq2Seq(num_channels=1, num_kernels=64, 
kernel_size=(3, 3), padding=(1, 1), activation="relu", 
frame_size=(64, 64), num_layers=3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCEWithLogitsLoss(reduction='sum')

In [12]:
import torch
from tqdm import tqdm

num_epochs = 50
scaler = torch.cuda.amp.GradScaler()

for epoch in range(1, num_epochs+1):
    
    train_loss = 0                                                 
    model.train()                                                  
    for batch_num, (input, target) in enumerate(tqdm(train_loader), 1):  
        optim.zero_grad()  
        with torch.cuda.amp.autocast():                                 
            output = model(input)                                     
            loss = criterion(output.flatten(), target.flatten())       
        scaler.scale(loss).backward()  
        scaler.step(optim)
        scaler.update()                                      
        train_loss += loss.item()                                 
    train_loss /= len(train_loader.dataset)                       

    val_loss = 0                                                 
    model.eval()                                                   
    with torch.no_grad():                                          
        for input, target in tqdm(val_loader):                          
            output = model(input)                                   
            loss = criterion(output.flatten(), target.flatten())   
            val_loss += loss.item()                                
    val_loss /= len(val_loader.dataset)                            

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))

100%|██████████| 500/500 [06:55<00:00,  1.20it/s]
100%|██████████| 63/63 [00:19<00:00,  3.15it/s]


Epoch:1 Training Loss:2895.52 Validation Loss:2839.19



100%|██████████| 500/500 [06:21<00:00,  1.31it/s]
100%|██████████| 63/63 [00:17<00:00,  3.56it/s]


Epoch:2 Training Loss:2839.16 Validation Loss:2839.15



100%|██████████| 500/500 [05:57<00:00,  1.40it/s]
100%|██████████| 63/63 [00:19<00:00,  3.20it/s]


Epoch:3 Training Loss:2839.15 Validation Loss:2839.15



100%|██████████| 500/500 [06:23<00:00,  1.30it/s]
100%|██████████| 63/63 [00:19<00:00,  3.19it/s]


Epoch:4 Training Loss:2839.15 Validation Loss:2839.15



100%|██████████| 500/500 [06:19<00:00,  1.32it/s]
100%|██████████| 63/63 [00:18<00:00,  3.47it/s]


Epoch:5 Training Loss:2839.14 Validation Loss:2839.14



100%|██████████| 500/500 [06:14<00:00,  1.34it/s]
100%|██████████| 63/63 [00:17<00:00,  3.60it/s]


Epoch:6 Training Loss:2839.14 Validation Loss:2839.14



100%|██████████| 500/500 [05:57<00:00,  1.40it/s]
100%|██████████| 63/63 [00:17<00:00,  3.60it/s]


Epoch:7 Training Loss:2839.14 Validation Loss:2839.14



100%|██████████| 500/500 [06:11<00:00,  1.35it/s]
100%|██████████| 63/63 [00:18<00:00,  3.33it/s]


Epoch:8 Training Loss:2839.14 Validation Loss:2839.13



100%|██████████| 500/500 [06:34<00:00,  1.27it/s]
100%|██████████| 63/63 [00:19<00:00,  3.26it/s]


Epoch:9 Training Loss:2839.13 Validation Loss:2839.13



100%|██████████| 500/500 [06:24<00:00,  1.30it/s]
100%|██████████| 63/63 [00:20<00:00,  3.08it/s]


Epoch:10 Training Loss:2839.13 Validation Loss:2839.13



100%|██████████| 500/500 [06:52<00:00,  1.21it/s]
100%|██████████| 63/63 [00:19<00:00,  3.18it/s]


Epoch:11 Training Loss:2839.13 Validation Loss:2839.13



100%|██████████| 500/500 [06:54<00:00,  1.21it/s]
100%|██████████| 63/63 [00:19<00:00,  3.18it/s]


Epoch:12 Training Loss:2838.17 Validation Loss:2830.69



100%|██████████| 500/500 [06:12<00:00,  1.34it/s]
100%|██████████| 63/63 [00:17<00:00,  3.58it/s]


Epoch:13 Training Loss:2825.17 Validation Loss:2824.20



100%|██████████| 500/500 [05:55<00:00,  1.41it/s]
100%|██████████| 63/63 [00:17<00:00,  3.58it/s]


Epoch:14 Training Loss:2821.64 Validation Loss:2821.20



100%|██████████| 500/500 [06:01<00:00,  1.38it/s]
100%|██████████| 63/63 [00:17<00:00,  3.59it/s]


Epoch:15 Training Loss:2819.64 Validation Loss:2819.65



100%|██████████| 500/500 [05:56<00:00,  1.40it/s]
100%|██████████| 63/63 [00:17<00:00,  3.70it/s]


Epoch:16 Training Loss:2818.34 Validation Loss:2818.05



100%|██████████| 500/500 [05:43<00:00,  1.45it/s]
100%|██████████| 63/63 [00:17<00:00,  3.70it/s]


Epoch:17 Training Loss:2817.11 Validation Loss:2818.14



100%|██████████| 500/500 [05:43<00:00,  1.45it/s]
100%|██████████| 63/63 [00:17<00:00,  3.70it/s]


Epoch:18 Training Loss:2816.42 Validation Loss:2816.57



100%|██████████| 500/500 [05:43<00:00,  1.45it/s]
100%|██████████| 63/63 [00:17<00:00,  3.68it/s]


Epoch:19 Training Loss:2815.70 Validation Loss:2815.87



100%|██████████| 500/500 [05:42<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.73it/s]


Epoch:20 Training Loss:2815.32 Validation Loss:2814.79



100%|██████████| 500/500 [05:42<00:00,  1.46it/s]
100%|██████████| 63/63 [00:17<00:00,  3.70it/s]


Epoch:21 Training Loss:2814.21 Validation Loss:2814.73



100%|██████████| 500/500 [05:42<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.74it/s]


Epoch:22 Training Loss:2813.77 Validation Loss:2814.33



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:17<00:00,  3.70it/s]


Epoch:23 Training Loss:2813.33 Validation Loss:2813.99



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.75it/s]


Epoch:24 Training Loss:2812.55 Validation Loss:2813.86



100%|██████████| 500/500 [05:42<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.72it/s]


Epoch:25 Training Loss:2812.45 Validation Loss:2813.04



100%|██████████| 500/500 [05:43<00:00,  1.45it/s]
100%|██████████| 63/63 [00:16<00:00,  3.72it/s]


Epoch:26 Training Loss:2812.17 Validation Loss:2812.20



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.73it/s]


Epoch:27 Training Loss:2811.65 Validation Loss:2812.65



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.75it/s]


Epoch:28 Training Loss:2811.18 Validation Loss:2811.87



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.72it/s]


Epoch:29 Training Loss:2810.84 Validation Loss:2812.40



100%|██████████| 500/500 [05:42<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.72it/s]


Epoch:30 Training Loss:2810.94 Validation Loss:2811.18



100%|██████████| 500/500 [05:43<00:00,  1.46it/s]
100%|██████████| 63/63 [00:16<00:00,  3.71it/s]


Epoch:31 Training Loss:2810.43 Validation Loss:2812.02



100%|██████████| 500/500 [05:50<00:00,  1.43it/s]
100%|██████████| 63/63 [00:17<00:00,  3.65it/s]


Epoch:32 Training Loss:2810.24 Validation Loss:2811.70



100%|██████████| 500/500 [06:09<00:00,  1.35it/s]
100%|██████████| 63/63 [00:17<00:00,  3.51it/s]


Epoch:33 Training Loss:2810.11 Validation Loss:2811.36



100%|██████████| 500/500 [09:43<00:00,  1.17s/it]
100%|██████████| 63/63 [00:31<00:00,  1.98it/s]


Epoch:34 Training Loss:2809.89 Validation Loss:2811.27



100%|██████████| 500/500 [07:28<00:00,  1.12it/s]
100%|██████████| 63/63 [00:19<00:00,  3.26it/s]


Epoch:35 Training Loss:2809.51 Validation Loss:2810.31



100%|██████████| 500/500 [06:13<00:00,  1.34it/s]
100%|██████████| 63/63 [00:17<00:00,  3.62it/s]


Epoch:36 Training Loss:2809.48 Validation Loss:2810.33



100%|██████████| 500/500 [06:32<00:00,  1.27it/s]
100%|██████████| 63/63 [00:19<00:00,  3.20it/s]


Epoch:37 Training Loss:2809.42 Validation Loss:2810.25



100%|██████████| 500/500 [06:48<00:00,  1.22it/s]
100%|██████████| 63/63 [00:19<00:00,  3.23it/s]


Epoch:38 Training Loss:2808.92 Validation Loss:2810.60



100%|██████████| 500/500 [07:05<00:00,  1.17it/s]
100%|██████████| 63/63 [00:18<00:00,  3.34it/s]


Epoch:39 Training Loss:2808.90 Validation Loss:2810.11



100%|██████████| 500/500 [07:07<00:00,  1.17it/s]
100%|██████████| 63/63 [00:17<00:00,  3.61it/s]


Epoch:40 Training Loss:2808.72 Validation Loss:2809.46



100%|██████████| 500/500 [06:29<00:00,  1.28it/s]
100%|██████████| 63/63 [00:18<00:00,  3.46it/s]


Epoch:41 Training Loss:2808.58 Validation Loss:2810.08



100%|██████████| 500/500 [06:06<00:00,  1.36it/s]
100%|██████████| 63/63 [00:18<00:00,  3.47it/s]


Epoch:42 Training Loss:2808.66 Validation Loss:2809.65



100%|██████████| 500/500 [06:17<00:00,  1.32it/s]
100%|██████████| 63/63 [00:18<00:00,  3.48it/s]


Epoch:43 Training Loss:2808.26 Validation Loss:2809.50



100%|██████████| 500/500 [06:15<00:00,  1.33it/s]
100%|██████████| 63/63 [00:18<00:00,  3.48it/s]


Epoch:44 Training Loss:2808.31 Validation Loss:2808.85



100%|██████████| 500/500 [06:07<00:00,  1.36it/s]
100%|██████████| 63/63 [00:17<00:00,  3.68it/s]


Epoch:45 Training Loss:2808.09 Validation Loss:2809.27



100%|██████████| 500/500 [05:52<00:00,  1.42it/s]
100%|██████████| 63/63 [00:17<00:00,  3.64it/s]


Epoch:46 Training Loss:2807.89 Validation Loss:2809.05



100%|██████████| 500/500 [05:52<00:00,  1.42it/s]
100%|██████████| 63/63 [00:17<00:00,  3.60it/s]


Epoch:47 Training Loss:2807.79 Validation Loss:2809.03



100%|██████████| 500/500 [05:52<00:00,  1.42it/s]
100%|██████████| 63/63 [00:17<00:00,  3.63it/s]


Epoch:48 Training Loss:2807.71 Validation Loss:2809.04



100%|██████████| 500/500 [05:52<00:00,  1.42it/s]
100%|██████████| 63/63 [00:17<00:00,  3.67it/s]


Epoch:49 Training Loss:2807.72 Validation Loss:2808.47



100%|██████████| 500/500 [05:52<00:00,  1.42it/s]
100%|██████████| 63/63 [00:17<00:00,  3.68it/s]

Epoch:50 Training Loss:2807.18 Validation Loss:2809.23






Visualizing the model learnings

In [13]:
def collate_test(batch):

    # Last 10 frames are target
    target = np.array(batch)[:,10:]                     
    
    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)          
    batch = batch / 255.0                             
    batch = batch.to(device)                          
    return batch, target

# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, 
                         batch_size=3, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Loop over timesteps
for timestep in range(target.shape[1]):
  input = batch[:,:,timestep:timestep+10]   
  output[:,timestep]=(model(input).squeeze(1).cpu()>0.5)*255.0

In [20]:

for tgt, out in zip(target, output):       # Loop over samples
    
    # Write target video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, tgt, "GIF", fps = 60)    
        target_gif = gif.getvalue()

    # Write output video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, out, "GIF", fps = 60)    
        output_gif = gif.getvalue()

    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x04\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x05\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x04\…