### Download the dataset

In [1]:
# Mount Google Drive at /content/drive so the project files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!wget -q https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy

In [14]:
# Switch to the project code folder; presumably this is what lets `from Network import Net`
# below resolve Network.py from the working directory.
%cd "/content/drive/MyDrive/project/codes"

/content/drive/.shortcut-targets-by-id/19aYohncN8veP4Zl7krOEqywsrq2c-WMR/project/codes


In [15]:
# List the working directory to confirm the project modules are present.
%ls

ConvLSTMCell.py  FFC.py      Plots-DL.ipynb  SelfAttention.py
ConvLSTM.py      Network.py  __pycache__/    train.ipynb


### Import dependencies

In [16]:
import argparse
import os
import json
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import Adam
from Network import Net
from tqdm import tqdm
from torch.utils.data import DataLoader


import io
import imageio
from ipywidgets import widgets, HBox

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Load the dataset and create dataloaders

In [17]:
# Videos per mini-batch for the train/validation loaders; also embedded in the experiment folder name.
BATCH_SIZE = 8

In [18]:
# Seed for the train/val/test shuffle, so the split is identical on every run
SPLIT_SEED = 42

# Load Data as Numpy Array.
# The first two axes are swapped so that axis 0 indexes videos and axis 1 indexes frames
# (presumably the file is stored as (frames, videos, height, width) — TODO confirm).
MovingMNIST = np.load('/content/mnist_test_seq.npy').transpose(1, 0, 2, 3)

# Shuffle Data (in place, along the video axis) with a dedicated seeded generator.
# An unseeded shuffle would put different videos in each split on every kernel restart,
# which makes validation/test numbers incomparable between runs.
# NOTE(review): a checkpoint trained under a different shuffle may already have seen
# videos that land in val/test here — confirm how the loaded checkpoint was produced.
np.random.default_rng(SPLIT_SEED).shuffle(MovingMNIST)

# Train, Test, Validation splits
train_data = MovingMNIST[:8000]
val_data = MovingMNIST[8000:9000]
test_data = MovingMNIST[9000:10000]

def collate(batch):
    """Turn a list of videos into an (input clip, next frame) pair.

    Args:
        batch: list of equally-shaped arrays, one per video, with frames on
            the first axis. The sampling below needs at least 20 frames.

    Returns:
        Tuple ``(clip, target)`` of tensors on ``device`` scaled to [0, 1]:
        ``clip`` holds 10 consecutive frames with a channel axis added at
        position 1, and ``target`` is the frame that immediately follows them.
    """
    # Stack into a single ndarray first: building a tensor straight from a
    # list of ndarrays is very slow and makes PyTorch emit a UserWarning.
    videos = np.stack(batch)

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.from_numpy(videos).unsqueeze(1)
    batch = batch / 255.0
    batch = batch.to(device)

    # Randomly pick 10 frames as input, 11th frame is target
    # (rand is the target index, drawn from 10..19 inclusive)
    rand = np.random.randint(10, 20)
    return batch[:, :, rand-10:rand], batch[:, :, rand]


# Training and validation loaders: same batch size, both reshuffled on every pass,
# both using `collate` to cut a random 10-frame clip plus its next frame.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate,
)
val_loader = DataLoader(
    dataset=val_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate,
)

### Visualize the data

In [19]:
# Get a batch (named `clips` so the builtin `input` is not shadowed)
clips, _ = next(iter(val_loader))

# Reverse process before displaying: back to 0-255 bytes.
# Round before casting — astype(uint8) alone truncates, so a value that comes back
# as e.g. 199.99998 after the /255 * 255 round trip would be shown as 199.
clips = np.rint(clips.cpu().numpy() * 255.0).astype(np.uint8)

for video in clips.squeeze(1)[:3]:          # Loop over the first three videos
    with io.BytesIO() as gif:
        imageio.mimsave(gif, video, "GIF", fps=5)
        display(HBox([widgets.Image(value=gif.getvalue())]))

  from ipykernel import kernelapp as app


HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xf4\xf4\xf4\xf2\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…

### Instantiate the model, optimizer and loss

In [20]:
# Model hyper-parameters; each is forwarded to Net(...) and embedded in the experiment folder name.
FFC = False          # passed to Net as `ffc`
ATTENTION = False    # passed to Net as `attention`
NUM_KERNELS = 32     # passed to Net as `num_kernels`
NUM_LAYERS = 2       # passed to Net as `num_layers`

In [23]:
# The input video frames are grayscale, thus single channel
model = Net(num_channels=1, num_kernels=NUM_KERNELS, kernel_size=(3, 3), padding=(1, 1), activation="relu",
                frame_size=(64, 64), num_layers=NUM_LAYERS, ffc=FFC, attention=ATTENTION).to(device)

# Warm-start from a previously saved checkpoint.
# map_location remaps the stored tensors onto the active device; without it a checkpoint
# saved from a GPU run cannot be loaded when the notebook has fallen back to the CPU.
CHECKPOINT_PATH = '/content/drive/MyDrive/project/exps/batch[8]_kernels[32]_layers[2]_attn[False]_ffc[False]/checkpoints/checkpoint_last.pth'
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# strict=False tolerates key mismatches; report them instead of silently leaving
# some layers at their random initialisation.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Checkpoint/model mismatch - missing keys: {load_result.missing_keys}, "
          f"unexpected keys: {load_result.unexpected_keys}")

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy summed over all pixels of the batch. Targets are pixel
# intensities scaled to [0, 1] by the collate function (not strictly binary).
# NOTE(review): BCELoss needs predictions in [0, 1]; presumably Net ends in a sigmoid — confirm.
criterion = nn.BCELoss(reduction='sum')

### Train the model

In [None]:
# Number of passes over the training set.
EPOCHS = 1
# Root folder under which checkpoints and plots are written.
# NOTE(review): the checkpoint loaded above lives under '/content/drive/MyDrive/project/exps',
# not under this root — confirm both paths point at the intended Drive location.
SAVE_DIR = '/content/drive/MyDrive/Courses/Deep_Learning_Course/project/exps'
# Experiment sub-folder named after the hyper-parameters. The trailing '_2' makes it differ
# from the folder the checkpoint is loaded from, presumably so this run does not overwrite it.
EXP_DIR = f'batch[{BATCH_SIZE}]_kernels[{NUM_KERNELS}]_layers[{NUM_LAYERS}]_attn[{ATTENTION}]_ffc[{FFC}]_2'

In [None]:
# Mean training loss per epoch (summed BCE divided by the number of training videos)
train_loss_history = []

# Resolve the checkpoint location once and create the folder before training starts,
# so an unwritable SAVE_DIR fails immediately rather than after a full epoch.
# exist_ok=True replaces the exists()/makedirs() pair and its check-then-create race.
checkpoint_dir = os.path.join(SAVE_DIR, EXP_DIR, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_last.pth")

for epoch in range(EPOCHS):
    with tqdm(total=len(train_data), desc=f"Epoch {epoch + 1}/{EPOCHS}", unit='img') as pbar:
        train_loss = 0
        model.train()
        for inputs, target in train_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            # Clear gradients before the backward pass: anything left over from an
            # interrupted earlier run of this cell must not leak into the first step.
            optim.zero_grad()
            output = model(inputs)
            loss = criterion(output.flatten(), target.flatten())
            loss.backward()
            optim.step()
            train_loss += loss.item()
            pbar.update(inputs.shape[0])
        # criterion sums over the batch, so dividing by the dataset size gives a per-video mean
        train_loss /= len(train_loader.dataset)

        # Record and print the epoch loss
        epoch_msg = f"Epoch: {epoch}"
        loss_msg = f"\nTrain loss: {train_loss}"
        train_loss_history.append(train_loss)
        print(f"{epoch_msg}\t{loss_msg}\n")

        # Overwrite the "last" checkpoint after every epoch
        torch.save(model.state_dict(), checkpoint_path)

Epoch 1/1: 100%|██████████| 8000/8000 [03:17<00:00, 40.41img/s]

Epoch: 0	
Train loss: 439.1394096984863






In [None]:
# Mean validation loss per pass (summed BCE divided by the number of validation videos).
# NOTE(review): nothing updates the model between passes, so with EPOCHS > 1 every pass
# scores the same weights; entries differ only through the random clip sampling/shuffle.
val_loss_history = []
for epoch in range(EPOCHS):
    progress = tqdm(total=len(val_data), desc=f"Epoch {epoch + 1}/{EPOCHS}", unit='img')
    with progress as pbar:
        model.eval()
        summed_loss = 0
        # Gradients are not needed for evaluation
        with torch.no_grad():
            for clip, target in val_loader:
                clip, target = clip.to(device), target.to(device)
                prediction = model(clip)
                summed_loss += criterion(prediction.flatten(), target.flatten()).item()
                pbar.update(clip.shape[0])
        val_loss = summed_loss / len(val_loader.dataset)
        val_loss_history.append(val_loss)

        epoch_msg = f"Epoch: {epoch}"
        loss_msg = f"\nVal loss:{val_loss}"
        print(f"{epoch_msg}\t{loss_msg}\n")


### Plot losses

In [None]:
# Diagram: per-epoch training and validation loss, saved as a PNG plus the raw CSV
import pandas as pd
loss_df = pd.DataFrame(data={"train_losses": train_loss_history, "val_losses": val_loss_history})

val_acc      = '#2E0249'
train_acc    = '#2F8F9D'
# Dedicated names for the loss colours: calling them `val_loss` / `train_loss` would
# overwrite the numeric losses left behind by the training and validation cells.
VAL_LOSS_COLOR   = '#F24C4C'
TRAIN_LOSS_COLOR = '#EC9B3B'

fig, ax = plt.subplots()
# Each curve uses its own colour (they were previously swapped); the marker keeps
# a single-epoch run visible, since a one-point line draws nothing.
loss_df['train_losses'].plot(ax=ax, color=TRAIN_LOSS_COLOR, marker='o', legend=True)
loss_df['val_losses'].plot(ax=ax, color=VAL_LOSS_COLOR, marker='o', legend=True)
ax.set_xlabel('Epoch')
ax.set_ylabel('Summed BCE per video')

plot_dir = os.path.join(SAVE_DIR, EXP_DIR, 'plots')
os.makedirs(plot_dir, exist_ok=True)
fig.savefig(os.path.join(plot_dir, 'loss.png'))
loss_df.to_csv(os.path.join(plot_dir, 'loss_df.csv'))

### Visualize the predictions

In [26]:
def collate_test(batch):
    """Collate whole videos for qualitative evaluation.

    Args:
        batch: list of equally-shaped arrays, one per video, with frames on
            the first axis.

    Returns:
        Tuple ``(batch, target)``: ``batch`` is a tensor on ``device`` holding
        every frame scaled to [0, 1] with a channel axis added at position 1;
        ``target`` is a NumPy array with the frames from index 10 onward, in
        the original dtype.
    """
    # Stack once and reuse: avoids converting the list twice and the slow
    # (warning-emitting) tensor-from-list-of-ndarrays path.
    videos = np.stack(batch)

    # Last 10 frames are target
    target = videos[:, 10:]

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.from_numpy(videos).unsqueeze(1)
    batch = batch / 255.0
    batch = batch.to(device)
    return batch, target

# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, batch_size=3, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Inference only: put the model in eval mode and skip autograd bookkeeping
# (a freshly built model is in training mode, and each forward pass would build a graph).
model.eval()
with torch.no_grad():
    # Loop over timesteps: every prediction is conditioned on the 10 ground-truth
    # frames before it, not on the model's own earlier predictions.
    for timestep in range(target.shape[1]):
        context = batch[:, :, timestep:timestep + 10]
        prediction = model(context).squeeze(1).cpu()
        # Binarise at 0.5 and map to 0/255 for display
        output[:, timestep] = np.where((prediction > 0.5).numpy(), 255, 0)

In [27]:
def _gif_bytes(frames):
    """Encode a sequence of frames as an in-memory GIF and return the raw bytes."""
    with io.BytesIO() as buffer:
        imageio.mimsave(buffer, frames, "GIF", fps = 5)
        return buffer.getvalue()


# One row per sample: ground-truth frames on the left, model output on the right
for tgt, out in zip(target, output):
    panels = [widgets.Image(value=_gif_bytes(frames)) for frames in (tgt, out)]
    display(HBox(panels))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfa\xfa\xfa\xf8\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfa\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…