### Resources

This notebook is heavily based on the ConvLSTM implementation and training code from [this repo](https://github.com/sladewinter/ConvLSTM/tree/master).

Note: training the model in this notebook takes more than an hour (about 4.5 minutes per epoch for 20 epochs in the recorded run below).

### Download Data

In [1]:
! wget -q https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import io
import imageio
from ipywidgets import widgets, HBox

from tqdm import tqdm

In [2]:
# Run on the current CUDA device when PyTorch can see one, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Import Data and Create Dataloaders

In [4]:
# Seed the RNGs this notebook draws from (the shuffle below, the random frame
# window in `collate`, DataLoader shuffling, weight init) so that
# Restart & Run All gives the same splits and starting point every time.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load Data as Numpy Array. The file is stored time-major, so swap the first
# two axes to get (num_sequences, seq_len, height, width).
MovingMNIST = np.load('mnist_test_seq.npy').transpose(1, 0, 2, 3)

# The fixed 8000 / 1000 / 1000 split below needs at least 10,000 sequences;
# fail loudly instead of silently producing short or empty splits.
assert len(MovingMNIST) >= 10000, (
    f"expected at least 10000 sequences, got {len(MovingMNIST)}")

# Shuffle sequences (axis 0) in place before splitting
np.random.shuffle(MovingMNIST)

# Train, Validation, Test splits
train_data = MovingMNIST[:8000]
val_data = MovingMNIST[8000:9000]
test_data = MovingMNIST[9000:10000]

def collate(batch):
    """Turn a list of frame sequences into one (inputs, target) training pair.

    All sequences in the batch share a single randomly chosen window: the 10
    consecutive frames ending just before index ``end`` are the inputs, and
    frame ``end`` itself is the target.

    Returns:
        inputs: tensor (batch, 1, 10, H, W), pixels divided by 255, on `device`.
        target: tensor (batch, 1, H, W), pixels divided by 255, on `device`.
    """
    # Stack, rescale to [0, 1] (assuming uint8 pixels), add the channel
    # axis and move to `device`
    frames = (torch.tensor(np.array(batch)) / 255.0).unsqueeze(1).to(device)

    # end is drawn from [10, 19]: 10 input frames precede it
    end = np.random.randint(10, 20)
    history = slice(end - 10, end)
    return frames[:, :, history], frames[:, :, end]


# Training and validation loaders share their settings: both reshuffle every
# epoch, use batches of 16 sequences and build (inputs, target) with `collate`.
loader_kwargs = dict(shuffle=True, batch_size=16, collate_fn=collate)

train_loader = DataLoader(train_data, **loader_kwargs)
val_loader = DataLoader(val_data, **loader_kwargs)

### Visualize Data

In [5]:
# Get a batch (named to avoid shadowing the builtin `input`)
sample_inputs, _ = next(iter(val_loader))

# Reverse the collate scaling before displaying. Round rather than truncate,
# so float error (e.g. 253.99998) does not knock a pixel value down by one.
sample_frames = np.rint(sample_inputs.cpu().numpy() * 255.0).astype(np.uint8)

# (batch, 1, 10, H, W) -> (batch, 10, H, W); show the first three videos
for video in sample_frames.squeeze(1)[:3]:
    with io.BytesIO() as gif:
        imageio.mimsave(gif, video, "GIF", fps=5)
        display(HBox([widgets.Image(value=gif.getvalue())]))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x04\x04\x04\x08\x08\x08\n\n\n\x0b\x0b\x0…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x04\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x02\x02\x02\x04\x04\x04\x07\x07\x07\t\t\…

### Defining the Model

In [6]:
# Original ConvLSTM cell as proposed by Shi et al.
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell with peephole connections.

    One 2-D convolution over the channel-wise concatenation of the input frame
    and the previous hidden state produces all four gate pre-activations. The
    peephole terms are element-wise (Hadamard) products of the cell state with
    learned weights of shape (out_channels, *frame_size).

    Args:
        in_channels: channels of the input frame X.
        out_channels: channels of the hidden and cell state (H, C).
        kernel_size, padding: forwarded to ``nn.Conv2d``. The padding must keep
            the spatial size unchanged so the conv output lines up with C_prev.
        activation: "tanh" or "relu"; applied to the candidate cell state and
            to C before the output gate.
        frame_size: (height, width) of the frames, used to size the peephole
            weights.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """

    _ACTIVATIONS = {"tanh": torch.tanh, "relu": torch.relu}

    def __init__(self, in_channels, out_channels,
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()

        # Reject unknown names here; previously self.activation was simply
        # never set and the error only surfaced on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}")
        self.activation = self._ACTIVATIONS[activation]

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # A single convolution yields the i, f, candidate-C and o pre-activations.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Weights for the Hadamard (peephole) products. torch.Tensor(...) would
        # leave these as uninitialized memory (possibly NaN/inf). Start them at
        # zero so the peepholes begin switched off and are learned from there.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: input frame, (batch, in_channels, height, width).
            H_prev, C_prev: previous hidden and cell state,
                (batch, out_channels, height, width).

        Returns:
            (H, C): new hidden and cell state, same shape as H_prev.
        """
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split the 4 * out_channels conv output into the gate pre-activations
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        # Input and forget gates peek at the previous cell state...
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current Cell output
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        # ...while the output gate peeks at the freshly updated one
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C


class ConvLSTM(nn.Module):
    """Unrolls one ConvLSTMCell over the time axis of a frame sequence.

    The hidden and cell state start at zero on every call. The hidden state
    produced at each time step is collected into the returned sequence.
    """

    def __init__(self, in_channels, out_channels,
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # A single cell whose weights are shared across all time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels,
        kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every frame of X.

        Args:
            X: frame sequence (batch_size, num_channels, seq_len, height, width).

        Returns:
            Hidden states (batch_size, out_channels, seq_len, height, width),
            on X's device and with X's dtype.
        """
        batch_size, _, seq_len, height, width = X.size()

        # Allocate state and output where the input lives, not on the
        # notebook-global `device`, so the module keeps working after being
        # moved (e.g. model.cpu()) or when run in another floating dtype.
        like_input = dict(device=X.device, dtype=X.dtype)

        # Initialize output
        output = torch.zeros(batch_size, self.out_channels, seq_len,
                             height, width, **like_input)

        # Initialize Hidden State and Cell Input
        H = torch.zeros(batch_size, self.out_channels, height, width, **like_input)
        C = torch.zeros(batch_size, self.out_channels, height, width, **like_input)

        # Unroll over time steps
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H

        return output


class Seq2Seq(nn.Module):
    """A stack of ConvLSTM layers, each followed by BatchNorm3d, plus a conv head.

    Maps an input sequence (batch, num_channels, seq_len, H, W) to a single
    frame (batch, num_channels, H, W) with values in (0, 1), assuming the
    padding preserves the spatial size.
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
    activation, frame_size, num_layers):

        super(Seq2Seq, self).__init__()

        self.sequential = nn.Sequential()

        # Only the first layer sees the raw frames (num_channels in); every
        # later one consumes the previous layer's num_kernels feature maps.
        # The first layer is always built, even when num_layers < 1.
        layer_in_channels = [num_channels] + [num_kernels] * max(num_layers - 1, 0)

        for depth, in_ch in enumerate(layer_in_channels, start=1):
            self.sequential.add_module(
                f"convlstm{depth}",
                ConvLSTM(in_channels=in_ch, out_channels=num_kernels,
                         kernel_size=kernel_size, padding=padding,
                         activation=activation, frame_size=frame_size))
            self.sequential.add_module(
                f"batchnorm{depth}", nn.BatchNorm3d(num_features=num_kernels))

        # Head: project the last layer's feature maps back to a frame
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        """Predict the next frame from sequence X (shapes in the class docstring)."""
        features = self.sequential(X)
        # Only the representation at the final time step drives the prediction
        last_step = features[:, :, -1]
        return torch.sigmoid(self.conv(last_step))

### Instantiate Model, Optimizer and Loss

In [7]:
# The input video frames are grayscale, thus single channel.
# A 3x3 kernel with padding 1 keeps the 64x64 frame size through every layer,
# which the peephole weights (sized by frame_size) rely on.
model = Seq2Seq(num_channels=1, num_kernels=64,
kernel_size=(3, 3), padding=(1, 1), activation="relu",
frame_size=(64, 64), num_layers=3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary cross-entropy between the sigmoid output and the target frame's pixel
# intensities scaled to [0, 1]. These are soft targets, not strictly 0 or 1.
# reduction='sum' adds up the per-pixel losses over the whole batch.
criterion = nn.BCELoss(reduction='sum')

### Train for 20 epochs

In [8]:
num_epochs: int = 20  # full passes over train_loader (~4.5 min each in the run below)

In [9]:
for epoch in range(1, num_epochs+1):

    # ---- training pass ----
    train_loss = 0
    model.train()
    for inputs, target in tqdm(train_loader):
        # Clear gradients before the forward pass, so gradients left over from
        # an interrupted earlier run of this cell cannot leak into the update.
        optim.zero_grad()
        output = model(inputs)
        loss = criterion(output.flatten(), target.flatten())
        loss.backward()
        optim.step()
        train_loss += loss.item()
    # criterion sums over pixels and batch, so this is the mean loss per sequence
    train_loss /= len(train_loader.dataset)

    # ---- validation pass: no gradients, BatchNorm in eval mode ----
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for inputs, target in val_loader:
            output = model(inputs)
            loss = criterion(output.flatten(), target.flatten())
            val_loss += loss.item()
    val_loss /= len(val_loader.dataset)

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))

100%|██████████| 500/500 [04:32<00:00,  1.83it/s]


Epoch:1 Training Loss:362.47 Validation Loss:305.33



100%|██████████| 500/500 [04:35<00:00,  1.82it/s]


Epoch:2 Training Loss:299.84 Validation Loss:291.49



100%|██████████| 500/500 [07:35<00:00,  1.10it/s] 


Epoch:3 Training Loss:288.87 Validation Loss:280.83



100%|██████████| 500/500 [04:32<00:00,  1.83it/s]


Epoch:4 Training Loss:281.69 Validation Loss:280.11



100%|██████████| 500/500 [04:32<00:00,  1.83it/s]


Epoch:5 Training Loss:275.99 Validation Loss:269.85



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:6 Training Loss:269.76 Validation Loss:268.68



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:7 Training Loss:264.76 Validation Loss:258.19



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:8 Training Loss:260.27 Validation Loss:256.48



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:9 Training Loss:256.00 Validation Loss:254.01



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:10 Training Loss:253.67 Validation Loss:251.39



100%|██████████| 500/500 [04:35<00:00,  1.81it/s]


Epoch:11 Training Loss:250.87 Validation Loss:247.97



100%|██████████| 500/500 [04:32<00:00,  1.83it/s]


Epoch:12 Training Loss:248.19 Validation Loss:244.99



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:13 Training Loss:245.73 Validation Loss:242.14



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:14 Training Loss:243.97 Validation Loss:243.45



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:15 Training Loss:241.11 Validation Loss:236.56



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:16 Training Loss:239.81 Validation Loss:235.33



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:17 Training Loss:237.71 Validation Loss:235.32



100%|██████████| 500/500 [04:32<00:00,  1.83it/s]


Epoch:18 Training Loss:235.94 Validation Loss:231.50



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:19 Training Loss:235.72 Validation Loss:230.30



100%|██████████| 500/500 [04:33<00:00,  1.83it/s]


Epoch:20 Training Loss:233.44 Validation Loss:232.82



### Visualize what the model has learned (ground truth on the left, prediction on the right)

In [10]:
def collate_test(batch):
    """Collate whole test sequences for sliding-window, frame-by-frame prediction.

    Returns:
        batch: float tensor (batch, 1, seq_len, H, W), pixels divided by 255,
            on `device`. It holds every frame, from which input windows are cut.
        target: NumPy array (batch, seq_len - 10, H, W) of the raw frames from
            index 10 onward (the frames to be predicted), in the data's dtype.
    """
    # Stack once and reuse it. torch.tensor() on a list of ndarrays is very
    # slow (PyTorch warns about it), and the conversion was done twice.
    frames = np.array(batch)

    # Frames from index 10 on (the last 10 of a 20-frame sequence) are target
    target = frames[:, 10:]

    # Add channel dim, scale pixels between 0 and 1, send to `device`
    tensor = torch.tensor(frames).unsqueeze(1) / 255.0
    return tensor.to(device), target

# Test Data Loader
test_loader = DataLoader(test_data, shuffle=True,
                         batch_size=3, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Predict each target frame from the 10 ground-truth frames just before it.
# This is a sliding window over real frames; predictions are not fed back in.
# `batch` is left untouched: re-wrapping it inside the loop (as before) added a
# dimension on every step and broke the window slicing after the first frame.
model.eval()
with torch.no_grad():
    for timestep in range(target.shape[1]):
        window = batch[:, :, timestep:timestep+10]
        prediction = model(window).squeeze(1).cpu().numpy()
        # Binarize at 0.5 and map to black/white uint8 pixels
        output[:, timestep] = (prediction > 0.5) * 255


In [11]:
def _gif_bytes(frames):
    """Encode a (T, H, W) stack of frames as an in-memory 5 fps GIF."""
    with io.BytesIO() as buffer:
        imageio.mimsave(buffer, frames, "GIF", fps = 5)
        return buffer.getvalue()


# One row per test sample: ground truth on the left, model prediction on the right
for truth_frames, predicted_frames in zip(target, output):
    display(HBox([widgets.Image(value=_gif_bytes(truth_frames)),
                  widgets.Image(value=_gif_bytes(predicted_frames))]))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x01\x01\x01\x02\x02\x02\x03\x03\x03\x04\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x01\x01\x01\x03\x03\x03\x06\x06\x06\n\n\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\x00\x00\x00\x02\x02\x02\x04\x04\x04\x05\x05\x05\x07\…