In [1]:
# Standard library imports
from argparse import ArgumentParser
import os, sys
THIS_DIR = os.path.abspath('')
PARENT_DIR = os.path.dirname(os.path.abspath(''))
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np

# local application imports
from lag_caVAE.lag import Lag_Net
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset
from examples.pend_lag_cavae_trainer import Model as Model_lag_cavae
from ablations.ablation_pend_MLPdyna_cavae_trainer import Model as Model_MLPdyna_cavae
from ablations.ablation_pend_lag_vae_trainer import Model as Model_lag_vae
from ablations.ablation_pend_lag_caAE_trainer import Model as Model_lag_caAE
from ablations.HGN import Model as Model_HGN

seed_everything(0)
%matplotlib inline
DPI = 600

In [2]:
# Restore the full model, its three ablations and the HGN baseline from
# PARENT_DIR/checkpoints. The file names appear to encode the prediction horizon
# (T_p) and the epoch at which each checkpoint was saved.
CHECKPOINT_DIR = os.path.join(PARENT_DIR, 'checkpoints')


def load_model(model_cls, ckpt_name):
    """Restore a trained model from a checkpoint file in CHECKPOINT_DIR.

    Args:
        model_cls: class exposing a ``load_from_checkpoint(path)`` classmethod
            (the trainer ``Model`` classes imported above).
        ckpt_name: file name of the checkpoint inside CHECKPOINT_DIR.

    Returns:
        Whatever ``model_cls.load_from_checkpoint`` returns (the restored model).
    """
    return model_cls.load_from_checkpoint(os.path.join(CHECKPOINT_DIR, ckpt_name))


model_lag_cavae = load_model(Model_lag_cavae, 'pend-lag-cavae-T_p=4-epoch=701.ckpt')
model_MLPdyna_cavae = load_model(Model_MLPdyna_cavae, 'ablation-pend-MLPdyna-cavae-T_p=4-epoch=919.ckpt')
model_lag_vae = load_model(Model_lag_vae, 'ablation-pend-lag-vae-T_p=4-epoch=916.ckpt')
model_lag_caAE = load_model(Model_lag_caAE, 'ablation-pend-lag-caAE-T_p=4-epoch=778.ckpt')
model_HGN = load_model(Model_HGN, 'baseline-pend-HGN-T_p=4-epoch=1543.ckpt')

In [3]:
# Load the training split and configure every model to evaluate on its time grid.
data_path = os.path.join(PARENT_DIR, 'datasets', 'pendulum-gym-image-dataset-train.pkl')
train_dataset = HomoImageDataset(data_path, T_pred=4)

# Each model gets its own tensor wrapping the dataset's t_eval array
# (torch.from_numpy shares memory with the numpy array; it does not copy).
for model in (model_lag_cavae, model_MLPdyna_cavae, model_lag_vae,
              model_lag_caAE, model_HGN):
    model.t_eval = torch.from_numpy(train_dataset.t_eval)

# Switch off the Lag-caVAE `annealing` hyper-parameter for evaluation
# (presumably a loss-weight annealing schedule -- confirm in pend_lag_cavae_trainer).
model_lag_cavae.hparams.annealing = False

# HGN-specific settings; the meaning of `step` and `alpha` is defined in ablations.HGN.
model_HGN.step = 3
model_HGN.alpha = 1

In [4]:
# Reconstruction loss of the four non-HGN models on the training split, computing
# one batch per entry of train_dataset.x. training_step is called only to read its
# logged 'recon_loss'; no optimizer is stepped in this cell.
lag_cavae_train_loss, MLPdyna_cavae_train_loss = [], []
lag_vae_train_loss, lag_caAE_train_loss = [], []

# Models are evaluated in the same order for every batch.
train_loss_targets = (
    (model_lag_cavae, lag_cavae_train_loss),
    (model_MLPdyna_cavae, MLPdyna_cavae_train_loss),
    (model_lag_vae, lag_vae_train_loss),
    (model_lag_caAE, lag_caAE_train_loss),
)

for group_idx in range(len(train_dataset.x)):
    batch = (torch.from_numpy(train_dataset.x[group_idx]),
             torch.from_numpy(train_dataset.u[group_idx]))
    for model, losses in train_loss_targets:
        step_out = model.training_step(batch, 0)
        losses.append(step_out['log']['recon_loss'].item())




In [5]:
# HGN reconstruction loss on the training split, batched through a DataLoader
# (one value per batch of up to 256 samples, so a smaller final batch counts
# equally in any later mean).
# u_idx presumably selects which control-input group the dataset serves -- confirm in utils.
train_dataset.u_idx = 0
dataLoader = DataLoader(train_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)
HGN_train_loss = [
    model_HGN.training_step(batch, 0)['log']['recon_loss'].item()
    for batch in dataLoader
]

In [6]:
# Load the held-out split with the same prediction horizon as the training split
# (passed by keyword, matching the train cell, so the horizon is explicit).
data_path = os.path.join(PARENT_DIR, 'datasets', 'pendulum-gym-image-dataset-test.pkl')
test_dataset = HomoImageDataset(data_path, T_pred=4)

In [7]:
# Reconstruction loss of the four non-HGN models on the held-out split, computing
# one batch per entry of test_dataset.x. training_step is called only to read its
# logged 'recon_loss'; no optimizer is stepped in this cell.
lag_cavae_test_loss = []
MLPdyna_cavae_test_loss = []
lag_vae_test_loss = []
lag_caAE_test_loss = []

# Iterate over the test split's own length (not train_dataset's), so every test
# entry is evaluated and no IndexError occurs when the split sizes differ.
for i in range(len(test_dataset.x)):
    batch = (torch.from_numpy(test_dataset.x[i]), torch.from_numpy(test_dataset.u[i]))
    lag_cavae_test_loss.append(model_lag_cavae.training_step(batch, 0)['log']['recon_loss'].item())
    MLPdyna_cavae_test_loss.append(model_MLPdyna_cavae.training_step(batch, 0)['log']['recon_loss'].item())
    lag_vae_test_loss.append(model_lag_vae.training_step(batch, 0)['log']['recon_loss'].item())
    lag_caAE_test_loss.append(model_lag_caAE.training_step(batch, 0)['log']['recon_loss'].item())

In [8]:
# HGN reconstruction loss on the held-out split, batched through a DataLoader
# (one value per batch of up to 256 samples).
HGN_test_loss = []
# Configure the dataset this cell actually evaluates (test_dataset, not train_dataset).
# u_idx presumably selects which control-input group the dataset serves -- confirm in utils.
test_dataset.u_idx = 0
dataLoader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=my_collate)
for batch in dataLoader:
    HGN_test_loss.append(model_HGN.training_step(batch, 0)['log']['recon_loss'].item())

In [9]:
# Normalise the mean reconstruction loss for each model. The divisor is presumably
# 32x32 pixels x 5 frames (initial frame plus T_pred=4 predictions) -- confirm
# against how recon_loss is reduced in the trainers.
scale = 32*32*5

loss_table = {
    'lag_cavae': (lag_cavae_train_loss, lag_cavae_test_loss),
    'MLPdyna_cavae': (MLPdyna_cavae_train_loss, MLPdyna_cavae_test_loss),
    'lag_vae': (lag_vae_train_loss, lag_vae_test_loss),
    'lag_caAE': (lag_caAE_train_loss, lag_caAE_test_loss),
    'HGN': (HGN_train_loss, HGN_test_loss),
}
for name, (train_losses, test_losses) in loss_table.items():
    train_score = np.mean(train_losses)/scale
    test_score = np.mean(test_losses)/scale
    print(f'{name}: train: {train_score}, test: {test_score}')

lag_cavae: train: 0.0018338330462574957, test: 0.0018673422187566757
MLPdyna_cavae: train: 0.0018255940079689025, test: 0.001863210164010525
lag_vae: train: 0.002403554953634739, test: 0.002519616447389126
lag_caAE: train: 0.001860453300178051, test: 0.00189580075442791
HGN: train: 0.0005488340655574575, test: 0.000710727070691064
