# 1. Handling imports

In [6]:
import sys
import os

# Put the 'predrnn' directory itself on the import path so that modules inside it
# which use top-level imports (e.g. `from core...`) can be resolved.
# The path is resolved against the current working directory.
project_root = os.path.abspath(os.path.join(os.curdir, "predrnn"))
if project_root not in sys.path:
    # Prepend (not append) so this copy wins over any same-named package elsewhere.
    sys.path[:0] = [project_root]



# PyTorch (related) imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import trange
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check;
# `torch.mps.is_available()` does not exist in older PyTorch releases and would raise
# AttributeError there on any machine without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Torch device:", device)  # Quick check to see which device we're using.


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split
from pathlib import Path
from IPython.display import clear_output

# Custom imports
import dataset.download_and_preprocess as dl
from dataset.dataloader import KTHDataset
from autoencoder.autoencoder import AutoencoderModel, architectures

# predrnn imports
from predrnn.core.models.predrnn_v2 import RNN as predrnn
from predrnn.configparser import build_predrnn_args

# For reproducibility: seed every RNG the notebook may touch, not just NumPy's.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices (CPU and accelerators)


Torch device: cuda


# 2. Handling file paths

In [7]:
# Root directories for trained artefacts and for the encoded PredRNN datasets.
# All are relative to the notebook's working directory.
autoencoder_root_path = Path("autoencoder", "models")
predrnn_root_path = Path("predrnn", "checkpoints")
predrnn_base_data_dir = Path("dataset", "encoded")

# 3. Loading the trained models
## A. Autoencoders

In [8]:
def load_autoencoder(arch_index: int, checkpoint_name: str):
    """Build an AutoencoderModel for ``architectures[arch_index]``, load a checkpoint, and set eval mode.

    Args:
        arch_index: key into ``architectures`` selecting the encoder/decoder pair.
        checkpoint_name: file name of the checkpoint inside ``autoencoder_root_path``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    arch = architectures[arch_index]
    # epochs=200 is kept exactly as before; presumably it only matters for training (TODO confirm).
    model = AutoencoderModel(epochs=200, encoder=arch["encoder"], decoder=arch["decoder"]).to(device)
    model.load(autoencoder_root_path / checkpoint_name, device=device)
    model.eval()
    return model


# 64x64 autoencoder. The misspelled name `autencoder_64` is kept because later cells may use it;
# `autoencoder_64` is the correctly spelled alias for the same object.
autencoder_64 = load_autoencoder(1, "model_trial_2_1.pt")
autoencoder_64 = autencoder_64

# 32x32 autoencoder
autoencoder_32 = load_autoencoder(2, "model_trial_1_2.pt")

# 16x16 autoencoder
autoencoder_16 = load_autoencoder(3, "model_trial_0_3.pt")

## B. PredRNN

In [11]:
def load_predrnn(checkpoint_name: str, img_width: int, data_subdir: str):
    """Rebuild a PredRNN network from its config and load the trained weights.

    Args:
        checkpoint_name: sub-directory of ``predrnn_root_path`` that holds ``model.ckpt-400``.
            It is also passed as ``result_checkpoint_dir`` to ``build_predrnn_args``.
        img_width: frame width passed to ``build_predrnn_args``.
        data_subdir: sub-directory of ``predrnn_base_data_dir``, passed (resolved) as ``data_dir``.

    Returns:
        ``(model, config)``: the network in eval mode and the args object it was built from.
    """
    # The network below is constructed on the CPU. map_location="cpu" lets a checkpoint
    # saved on a GPU load on machines without one.
    checkpoint = torch.load(predrnn_root_path / checkpoint_name / "model.ckpt-400", map_location="cpu")
    cfg = build_predrnn_args(img_width=img_width,
                             data_dir=str((predrnn_base_data_dir / data_subdir).resolve()),
                             result_checkpoint_dir=checkpoint_name)
    # cfg.num_hidden is a comma-separated string with one hidden size per layer.
    hidden_sizes = [int(x) for x in cfg.num_hidden.split(',')]
    model = predrnn(len(hidden_sizes), hidden_sizes, cfg)
    model.load_state_dict(checkpoint['net_param'])
    model.eval()  # inference only
    # NOTE(review): the model stays on the CPU (it is not moved to `device`); move it before
    # feeding it tensors that live on an accelerator.
    return model, cfg


# 128x128 (native) PredRNN model.
# NOTE(review): the native model is given the encoded '64' data dir, exactly as before.
# This looks like a copy-paste from the 64x64 entry; confirm the intended directory.
predrnn_native, config_native = load_predrnn("kth_predrnn_vanilla", img_width=128, data_subdir="64")

# 64x64 PredRNN
predrnn_64, config_64 = load_predrnn("kth_predrnn_64", img_width=64, data_subdir="64")

# 32x32 PredRNN
predrnn_32, config_32 = load_predrnn("kth_predrnn_32", img_width=32, data_subdir="32")

# 16x16 PredRNN
predrnn_16, config_16 = load_predrnn("kth_predrnn_16", img_width=16, data_subdir="16")

# The previous version of this cell left `config` bound to the last-loaded (16x16) args.
# Keep that binding so later cells that read `config` still work.
config = config_16