# PredRNN++ R&D

A notebook for testing the PredRNN++ model code (`CausalLSTMStack`) with a forward pass over one batch of Moving MNIST data.

In [1]:
# System
import sys
import os

In [2]:
# Externals
import torch

In [3]:
# Locals
sys.path.append('..')
from datasets import get_data_loaders
from models.layers import CausalLSTMStack

## Prepare some data

In [4]:
# Data config
data_name = 'moving_mnist'
# Dataset location; override with the NESAP_STL_DATA_DIR environment variable
# when running somewhere other than the original cluster filesystem.
data_dir = os.environ.get('NESAP_STL_DATA_DIR',
                          '/global/cfs/cdirs/m1759/sfarrell/nesap-stl/data')
batch_size = 4
# Passed straight through to get_data_loaders; presumably frames are cut into
# patch_size x patch_size patches folded into channels — confirm in `datasets`.
patch_size = 4

In [5]:
%%time

# Load a batch of moving-mnist data
train_loader, valid_loader = get_data_loaders(name=data_name, data_dir=data_dir,
                                              batch_size=batch_size, patch_size=patch_size)

CPU times: user 13.4 s, sys: 4.78 s, total: 18.2 s
Wall time: 17.1 s


## Set up the model

In [6]:
# Model config: convolution kernel size and the channel counts handed to the
# causal LSTM stack (presumably one entry per layer: 128, then three 64s, then 16).
filter_size = 3
num_hidden = [128] + [64] * 3 + [16]

In [7]:
# Build the causal-LSTM stack plus a 1x1 convolution that maps the top layer's
# hidden state back to frame channels. Channel counts come from the config
# instead of hard-coded 16s, so editing num_hidden / patch_size keeps the
# decoder consistent with the stack and the data.
torch.manual_seed(0)  # reproducible random weight initialization
clstm = CausalLSTMStack(filter_size=filter_size, num_dims=2, channels=num_hidden)
# Assumes single-channel (grayscale) frames folded into patch_size**2 channels,
# which matches the 16 input channels seen for moving_mnist with patch_size=4.
# TODO confirm for multi-channel datasets.
frame_channels = patch_size ** 2
decoder = torch.nn.Conv2d(in_channels=num_hidden[-1], out_channels=frame_channels,
                          kernel_size=1, stride=1)

## Apply the model

In [8]:
# Pull a single batch from the training loader for a quick forward-pass test.
x = next(iter(train_loader))

In [9]:
# Shape of the sampled batch (displayed as a torch.Size)
x.size()

torch.Size([4, 20, 16, 16, 16])

In [10]:
# Initialize hidden states
h, c, m, z = [None]*4
outputs = []

In [11]:
%%time

# Loop over the sequence
for t in range(x.shape[1]):
    h, c, m, z = clstm(x[:,t], h, c, m, z)
    outputs.append(decoder(h[-1])) #.permute(0, -1, 1, 2)

CPU times: user 5.68 s, sys: 1.03 s, total: 6.71 s
Wall time: 3.35 s


In [12]:
# Shape of the decoded output for the first time step
outputs[0].size()

torch.Size([4, 16, 16, 16])

In [13]:
# Stack the per-step decoder outputs along a new time axis (dim=1) and check
# that the predicted sequence has the same shape as the input batch.
predictions = torch.stack(outputs, dim=1)
assert predictions.shape == x.shape, (predictions.shape, x.shape)
predictions.shape

torch.Size([4, 20, 16, 16, 16])