In [8]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import io
import imageio
from ipywidgets import widgets, HBox

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [9]:
MovingMNIST = np.load('mnist_test_seq.npy').transpose(1, 0, 2, 3)#convert to numpy array
np.random.shuffle(MovingMNIST)# Shuffle Data

In [10]:
# Train, Test, Validation splits
train_data = MovingMNIST[:800]         
val_data = MovingMNIST[800:900]       
test_data = MovingMNIST[900:1000]    

def collate(batch):

    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)     
    batch = batch / 255.0                        
    batch = batch.to(device)                     

    # Randomly pick 10 frames as input, 11th frame is target
    rand = np.random.randint(10,20)                     
    return batch[:,:,rand-10:rand], batch[:,:,rand]     

In [31]:
# Training Data Loader
BATCH_SIZE = 8
train_loader = DataLoader(train_data, shuffle=True, 
                        batch_size=BATCH_SIZE, collate_fn=collate)

# Validation Data Loader
val_loader = DataLoader(val_data, shuffle=True, 
                        batch_size=BATCH_SIZE, collate_fn=collate)

# Train_loader es un arreglo de 100 batches donde cada batch tiene tamaño de 8(batch size)
# Cada batch de un loader es una tupla, donde un elemento tiene una dimension de 8,1,10,64,64 y otro de 8,1,64,64
for i in train_loader:
    print(i[0].shape, ",", i[1].shape)
#for i in val_loader:
#    print(i[0].shape, ",", i[1].shape)


torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.Size([8, 1, 64, 64])
torch.Size([8, 1, 10, 64, 64]) , torch.S

In [32]:
import time
import cv2
input, _ = next(iter(val_loader))
input = input.cpu().numpy() * 255.0#reconvert to numpy array 
print(input.shape)
for (i, video) in enumerate(input.squeeze(1)[:4]):# Shape of (16,1,10,64,64)->after squeeze (16,10,64,64)->after selection (4,10,64,64)
    # for frame in video:
    #     cv2.namedWindow("Video"+str(i), cv2.WINDOW_NORMAL)
    #     cv2.imshow("Video"+str(i), frame)
    #     cv2.waitKey(200)
    # cv2.destroyAllWindows() # close the window
        
    with io.BytesIO() as gif:
       imageio.mimsave(gif,video.astype(np.uint8),"GIF",fps=5)
       display(HBox([widgets.Image(value=gif.getvalue())]))

(8, 1, 10, 64, 64)


HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfa\xfa\xfa\xf9\…

**MODEL**

In [33]:
import torch.nn as nn
import torch
from ConvLSTM import ConvLSTM


class Seq2Seq(nn.Module):

    def __init__(self, num_channels, num_kernels, kernel_size, padding, 
                activation, frame_size, num_layers):

        super(Seq2Seq, self).__init__()

        self.sequential = nn.Sequential()

        # Add First layer (Different in_channels than the rest)
        self.sequential.add_module(
            "convlstm1", ConvLSTM(
                in_channels=num_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding, 
                activation=activation, frame_size=frame_size)
        )

        self.sequential.add_module(
            "batchnorm1", nn.BatchNorm3d(num_features=num_kernels)
        ) 

        # Add rest of the layers
        for l in range(2, num_layers+1):

            self.sequential.add_module(
                f"convlstm{l}", ConvLSTM(
                    in_channels=num_kernels, out_channels=num_kernels,
                    kernel_size=kernel_size, padding=padding, 
                    activation=activation, frame_size=frame_size)
                )
                
            self.sequential.add_module(
                f"batchnorm{l}", nn.BatchNorm3d(num_features=num_kernels)
                ) 

        # Add Convolutional Layer to predict output frame
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):

        # Forward propagation through all the layers
        output = self.sequential(X)

        # Return only the last output frame
        output = self.conv(output[:,:,-1])
        
        return nn.Sigmoid()(output)

**OPTIMIZERS**

In [34]:
# The input video frames are grayscale, thus single channel
model = Seq2Seq(num_channels=1, num_kernels=64, 
kernel_size=(3, 3), padding=(1, 1), activation="relu", 
frame_size=(64, 64), num_layers=3).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# Binary Cross Entropy, target pixel values either 0 or 1
criterion = nn.BCELoss(reduction='sum')

**Training**

In [35]:
num_epochs = 20

for epoch in range(1, num_epochs+1):
    
    train_loss = 0                                                 
    model.train()                                                  
    for batch_num, (input, target) in enumerate(train_loader, 1):  
        output = model(input)                                     
        loss = criterion(output.flatten(), target.flatten())       
        loss.backward()                                            
        optim.step()                                               
        optim.zero_grad()                                           
        train_loss += loss.item()                                 
    train_loss /= len(train_loader.dataset)                       

    val_loss = 0                                                 
    model.eval()                                                   
    with torch.no_grad():                                          
        for input, target in val_loader:                          
            output = model(input)                                   
            loss = criterion(output.flatten(), target.flatten())   
            val_loss += loss.item()                                
    val_loss /= len(val_loader.dataset)                            

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))

Epoch:1 Training Loss:552.13 Validation Loss:439.25

Epoch:2 Training Loss:346.84 Validation Loss:320.41

Epoch:3 Training Loss:324.03 Validation Loss:309.77

Epoch:4 Training Loss:315.89 Validation Loss:309.21

Epoch:5 Training Loss:309.74 Validation Loss:300.58

Epoch:6 Training Loss:307.77 Validation Loss:298.36

Epoch:7 Training Loss:305.17 Validation Loss:309.01

Epoch:8 Training Loss:303.62 Validation Loss:304.20

Epoch:9 Training Loss:305.08 Validation Loss:295.47

Epoch:10 Training Loss:295.23 Validation Loss:305.70

Epoch:11 Training Loss:295.99 Validation Loss:289.10

Epoch:12 Training Loss:292.61 Validation Loss:310.06

Epoch:13 Training Loss:293.83 Validation Loss:292.67

Epoch:14 Training Loss:289.98 Validation Loss:290.01

Epoch:15 Training Loss:294.45 Validation Loss:278.42

Epoch:16 Training Loss:283.89 Validation Loss:281.65

Epoch:17 Training Loss:285.84 Validation Loss:274.57

Epoch:18 Training Loss:284.70 Validation Loss:290.15

Epoch:19 Training Loss:283.20 Validat

In [None]:
def collate_test(batch):

    # Last 10 frames are target
    target = np.array(batch)[:,10:]                     
    
    # Add channel dim, scale pixels between 0 and 1, send to GPU
    batch = torch.tensor(batch).unsqueeze(1)          
    batch = batch / 255.0                             
    batch = batch.to(device)                          
    return batch, target

# Test Data Loader
test_loader = DataLoader(test_data,shuffle=True, 
                         batch_size=3, collate_fn=collate_test)

# Get a batch
batch, target = next(iter(test_loader))

# Initialize output sequence
output = np.zeros(target.shape, dtype=np.uint8)

# Loop over timesteps
for timestep in range(target.shape[1]):
  input = batch[:,:,timestep:timestep+10]   
  output[:,timestep]=(model(input).squeeze(1).cpu()>0.5)*255.0

In [None]:
for tgt, out in zip(target, output):       # Loop over samples
    
    # Write target video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, tgt, "GIF", fps = 5)    
        target_gif = gif.getvalue()

    # Write output video as gif
    with io.BytesIO() as gif:
        imageio.mimsave(gif, out, "GIF", fps = 5)    
        output_gif = gif.getvalue()

    display(HBox([widgets.Image(value=target_gif), 
                  widgets.Image(value=output_gif)]))

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf9\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…