# Testing (latent) predRNN

In [24]:
# Change working directory to be able to import predRNN packages without changing code
import os
from pathlib import Path

# Step into the predrnn checkout (unless already there) so its packages import as-is.
current_dir = Path.cwd().name
if not current_dir == 'predrnn':
    os.chdir('./predrnn')
    
# PyTorch (related) imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import trange
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check
# (present since PyTorch 1.12); the `torch.mps`-module variant is newer and may
# be missing on older installs, which would crash this cell on non-CUDA machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Torch device:", device) # Quick check to see if we're using GPU or CPU.

import random
import optuna
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

# PredRNN imports
from argparse import Namespace
from predrnn.core.models.model_factory import Model

# Custom imports
import dataset.download_and_preprocess as dl
from dataset.dataloader import KTHDataset
from autoencoder.autoencoder import AutoencoderModel

# For reproducibility, seed every RNG this notebook imports, not just NumPy's.
# torch.manual_seed matters most here: it fixes the initial weights of the
# nn modules built in later cells.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Torch device: mps


## Configurations (B)

In [15]:
# Candidate convolutional autoencoders for 1x128x128 frames, keyed by id.
# Each entry has an "encoder" and a mirrored "decoder" nn.Sequential; the
# encoder ends in a single channel, giving latent frames of:
#   1 -> 1x64x64,  2 -> 1x32x32,  3 -> 1x16x16
# (presumably the sources of dataset/encoded/<width> used below — confirm).
# NOTE: every module is instantiated once, when this dict is built, so anything
# reusing e.g. architectures[1]["encoder"] shares (and trains) the same weights.
architectures = {
    1: {"encoder" : nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 1x128x128 -> 32x128x128
            nn.LeakyReLU(),

            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), # 32x128x128 -> 16x128x128
            nn.LeakyReLU(),

            nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), # 16x128x128 -> 16x64x64
            nn.LeakyReLU(),

            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), # 16x64x64 -> 1x64x64
        ),
        "decoder" : nn.Sequential(
            nn.ConvTranspose2d(1, 16, kernel_size=3, stride=1, padding=1), # 1x64x64 -> 16x64x64
            nn.LeakyReLU(),

            # output_padding=1 makes the stride-2 transpose exactly double H and W
            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x64x64 -> 16x128x128
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=1, padding=1), # 16x128x128 -> 32x128x128
            nn.LeakyReLU(),

            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1), # 32x128x128 -> 1x128x128
            nn.Sigmoid()  # squashes reconstructions into [0, 1]
        )
        },

    2: {"encoder" : nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 1x128x128 -> 32x128x128
            nn.LeakyReLU(),

            nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1), # 32x128x128 -> 16x64x64
            nn.LeakyReLU(),

            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), # 16x64x64 -> 16x64x64
            nn.LeakyReLU(),

            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1), # 16x64x64 -> 8x32x32
            nn.LeakyReLU(),

            nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1), # 8x32x32 -> 1x32x32
        ),
        "decoder" : nn.Sequential(
            nn.ConvTranspose2d(1, 8, kernel_size=3, stride=1, padding=1), # 1x32x32 -> 8x32x32
            nn.LeakyReLU(),

            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 8x32x32 -> 16x64x64
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=1, padding=1), # 16x64x64 -> 16x64x64
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x64x64 -> 32x128x128
            nn.LeakyReLU(),

            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1), # 32x128x128 -> 1x128x128
            nn.Sigmoid()  # squashes reconstructions into [0, 1]
        )
        },
    3: {"encoder" : nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 1x128x128 -> 32x128x128
            nn.LeakyReLU(),

            nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1), # 32x128x128 -> 16x64x64
            nn.LeakyReLU(),

            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1), # 16x64x64 -> 16x64x64
            nn.LeakyReLU(),

            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1), # 16x64x64 -> 8x32x32
            nn.LeakyReLU(),

            nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1), # 8x32x32 -> 8x32x32
            nn.LeakyReLU(),

            nn.Conv2d(8, 1, kernel_size=3, stride=2, padding=1), # 8x32x32 -> 1x16x16
        ),
        "decoder" : nn.Sequential(
            nn.ConvTranspose2d(1, 8, kernel_size=3, stride=2, padding=1, output_padding=1), # 1x16x16 -> 8x32x32
            nn.LeakyReLU(),

            nn.ConvTranspose2d(8, 8, kernel_size=3, stride=1, padding=1), # 8x32x32 -> 8x32x32
            nn.LeakyReLU(),

            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 8x32x32 -> 16x64x64
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=1, padding=1), # 16x64x64 -> 16x64x64
            nn.LeakyReLU(),

            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x64x64 -> 32x128x128
            nn.LeakyReLU(),

            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1), # 32x128x128 -> 32x128x128
            nn.LeakyReLU(),

            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1), # 32x128x128 -> 1x128x128
            nn.Sigmoid()  # squashes reconstructions into [0, 1]
        )
        }

}

In [39]:

# Hyper-parameters shared by every latent-size run of PredRNN-v2. A per-size
# override dict is merged on top of this one (right-hand keys win) before it is
# turned into the argparse-style Namespace that Model(...) expects.
base_configs_predRNN = dict(
    is_training=1,
    device=device,  # torch.device picked in the imports cell
    model_name='predrnn_v2',
    visual=0,
    reverse_input=1,
    img_channel=1,
    input_length=10,
    total_length=20,
    filter_size=5,
    stride=1,
    patch_size=4,
    layer_norm=0,
    decouple_beta=0.01,
    # NOTE(review): r_sampling_step_1 (5000) is far above max_iterations (500);
    # if the reverse-scheduled-sampling schedule is keyed on the iteration count,
    # training never leaves its first phase — confirm this is intended.
    reverse_scheduled_sampling=1,
    r_sampling_step_1=5000,
    r_sampling_step_2=50000,
    r_exp_alpha=2000,
    lr=0.0001,
    batch_size=4,
    max_iterations=500,
    display_interval=100,
    # test_interval=500,
    snapshot_interval=100,
    visual_path='./decoupling_visual',
    # pretrained_model='./checkpoints/kth_predrnn_v2/kth_model.ckpt',  # Uncomment if needed
)

In [None]:
# Repository root that contains dataset/encoded/<width>. Set the
# PREDRNN_AML_ROOT environment variable to run on another machine; the fallback
# is the original author's local checkout.
root_path = os.environ.get('PREDRNN_AML_ROOT', '/Users/maxneerken/Documents/aml/predrnn-pytorch-AML')

In [38]:
def _latent_predrnn_configs(latent_width, num_layers=4):
    """Build the PredRNN overrides for one latent frame size.

    Args:
        latent_width: spatial width of the encoded frames (16, 32 or 64); it
            selects the data directory, checkpoint/result folders and is also
            used as the hidden-channel count of every recurrent layer.
        num_layers: number of recurrent layers listed in ``num_hidden``.

    Returns:
        A dict meant to be merged over ``base_configs_predRNN``.
    """
    data_path = f'{root_path}/dataset/encoded/{latent_width}'
    return {
        'dataset_name': 'latent',
        # NOTE(review): training and validation read the same directory, so any
        # "validation" score is measured on training data — confirm intent.
        'train_data_paths': data_path,
        'valid_data_paths': data_path,
        'save_dir': f'checkpoints/latent_{latent_width}/kth_predrnn_v2',
        'gen_frm_dir': f'results/latent_{latent_width}_kth_predrnn_v2',
        'img_width': latent_width,
        # A comma-separated *string* (not a tuple), one entry per layer,
        # e.g. '16, 16, 16, 16'.
        'num_hidden': ', '.join([str(latent_width)] * num_layers),
    }


latent_16_predrnn_configs = _latent_predrnn_configs(16)
latent_32_predrnn_configs = _latent_predrnn_configs(32)
latent_64_predrnn_configs = _latent_predrnn_configs(64)

## Evaluate the trained latent PredRNN

In [42]:
# Shared settings first, 16x16-latent overrides second (later keys win).
args = Namespace(**{**base_configs_predRNN, **latent_16_predrnn_configs})

model = Model(args)
# NOTE(review): this path repeats args.save_dir by hand; keep the two in sync
# if the save directory changes.
model.load('./checkpoints/latent_16/kth_predrnn_v2/model.ckpt-4')


load model: ./checkpoints/latent_16/kth_predrnn_v2/model.ckpt-4


In [45]:
# Display the loaded network's module tree (recurrent cell list and conv layer shapes).
model.network

RNN(
  (MSE_criterion): MSELoss()
  (cell_list): ModuleList(
    (0-3): 4 x SpatioTemporalLSTMCell(
      (conv_x): Sequential(
        (0): Conv2d(16, 112, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      )
      (conv_h): Sequential(
        (0): Conv2d(16, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      )
      (conv_m): Sequential(
        (0): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      )
      (conv_o): Sequential(
        (0): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      )
      (conv_last): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
  )
  (conv_last): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (adapter): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)