In [1]:
# Standard library imports
from argparse import ArgumentParser
import os, sys
# Absolute path of the current working directory (abspath('') resolves
# against the process cwd).
THIS_DIR = os.path.abspath('')
# Directory one level above the working directory; the local packages
# imported below are looked up from here.
PARENT_DIR = os.path.dirname(THIS_DIR)
# Guarded append so that re-running this cell in a live kernel does not
# keep stacking duplicate entries onto sys.path.
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint
from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np

# local application imports
from lag_caVAE.lag import Lag_Net
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset
from examples.cart_lag_cavae_trainer import Model as Model_lag_cavae
from ablations.ablation_cart_MLPdyna_cavae_trainer import Model as Model_MLPdyna_cavae
from ablations.ablation_cart_lag_vae_trainer import Model as Model_lag_vae
from ablations.ablation_cart_lag_MLPEnc_caDec_trainer import Model as Model_lag_MLPEnc_caDec
from ablations.ablation_cart_lag_caEnc_MLPDec_trainer import Model as Model_lag_caEnc_MLPDec
from ablations.ablation_cart_lag_caAE_trainer import Model as Model_lag_caAE
from ablations.HGN import Model as Model_HGN

# Fix the global random seed via pytorch_lightning's seed_everything so that
# repeated runs of the notebook start from the same RNG state.
seed_everything(0)
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# NOTE(review): DPI is not referenced in the cells visible here; presumably
# the resolution used when saving figures -- confirm or remove.
DPI = 600

In [2]:
def _checkpoint_file(filename):
    """Return the path of `filename` inside `<PARENT_DIR>/checkpoints`."""
    return os.path.join(PARENT_DIR, 'checkpoints', filename)


# Each model is restored from its own checkpoint; `checkpoint_path` is rebound
# before every load and ends up pointing at the HGN checkpoint.
checkpoint_path = _checkpoint_file('cart-lag-cavae-T_p=4-epoch=206.ckpt')
model_lag_cavae = Model_lag_cavae.load_from_checkpoint(checkpoint_path)

checkpoint_path = _checkpoint_file('ablation-cart-MLPdyna-cavae-T_p=4-epoch=807.ckpt')
model_MLPdyna_cavae = Model_MLPdyna_cavae.load_from_checkpoint(checkpoint_path)

checkpoint_path = _checkpoint_file('ablation-cart-lag-vae-T_p=4-epoch=987.ckpt')
model_lag_vae = Model_lag_vae.load_from_checkpoint(checkpoint_path)

checkpoint_path = _checkpoint_file('ablation-cart-lag-MLPEnc-caDec-T_p=4-epoch=524.ckpt')
model_lag_MLPEnc_caDec = Model_lag_MLPEnc_caDec.load_from_checkpoint(checkpoint_path)

# this checkpoint is trained with learning rate 1e-4
checkpoint_path = _checkpoint_file('ablation-cart-lag-caEnc-MLPDec-T_p=4-epoch=954.ckpt')
model_lag_caEnc_MLPDec = Model_lag_caEnc_MLPDec.load_from_checkpoint(checkpoint_path)

# this checkpoint is trained with learning rate 1e-4
checkpoint_path = _checkpoint_file('ablation-cart-lag-caAE-T_p=4-epoch=909.ckpt')
model_lag_caAE = Model_lag_caAE.load_from_checkpoint(checkpoint_path)

checkpoint_path = _checkpoint_file('baseline-cart-HGN-T_p=4-epoch=1777.ckpt')
model_HGN = Model_HGN.load_from_checkpoint(checkpoint_path)

In [3]:
# Load the training split and hand its t_eval array to every model.
# WARNING: this may require ~18G of memory at peak
data_path = os.path.join(
    PARENT_DIR, 'datasets', 'cartpole-gym-image-dataset-rgb-u9-train.pkl'
)
train_dataset = HomoImageDataset(data_path, T_pred=4)

# Each model receives its own tensor built from the dataset's t_eval array.
for _model in (
    model_lag_cavae,
    model_MLPdyna_cavae,
    model_lag_vae,
    model_lag_MLPEnc_caDec,
    model_lag_caEnc_MLPDec,
    model_lag_caAE,
    model_HGN,
):
    _model.t_eval = torch.from_numpy(train_dataset.t_eval)

# Model-specific settings applied on top of the loaded checkpoints.
model_lag_cavae.hparams.annealing = False
model_HGN.step = 3
model_HGN.alpha = 1

In [4]:
# Per-batch reconstruction losses on the training split, one list per model.
lag_cavae_train_loss = []
MLPdyna_cavae_train_loss = []
lag_MLPEnc_caDec_train_loss = []
lag_caEnc_MLPDec_train_loss = []
lag_vae_train_loss = []
lag_caAE_train_loss = []

# (model, destination list) pairs, listed in the order the models are run on
# each batch.
# NOTE(review): keep this order fixed -- if any training_step draws from the
# global RNG, reordering the calls would change the reported numbers.
_train_eval_plan = (
    (model_lag_cavae, lag_cavae_train_loss),
    (model_MLPdyna_cavae, MLPdyna_cavae_train_loss),
    (model_lag_vae, lag_vae_train_loss),
    (model_lag_MLPEnc_caDec, lag_MLPEnc_caDec_train_loss),
    (model_lag_caEnc_MLPDec, lag_caEnc_MLPDec_train_loss),
    (model_lag_caAE, lag_caAE_train_loss),
)

# Sweep every index of train_dataset.x by selecting it through u_idx, then
# push all batches of that selection through each model.
for u_index in range(len(train_dataset.x)):
    train_dataset.u_idx = u_index
    dataLoader = DataLoader(
        train_dataset, batch_size=512, shuffle=False, collate_fn=my_collate
    )
    for batch in dataLoader:
        for _model, _losses in _train_eval_plan:
            _step_output = _model.training_step(batch, 0)
            _losses.append(_step_output['log']['recon_loss'].item())



In [5]:
# Per-batch reconstruction loss of the HGN baseline on the training split.
# Only the u_idx = 0 selection of the dataset is evaluated here.
HGN_train_loss = []
train_dataset.u_idx = 0
dataLoader = DataLoader(
    train_dataset, batch_size=256, shuffle=False, collate_fn=my_collate
)
HGN_train_loss.extend(
    model_HGN.training_step(batch, 0)['log']['recon_loss'].item()
    for batch in dataLoader
)

In [6]:
# Drop the loader and the training dataset references before the test split
# is loaded.
del dataLoader, train_dataset

In [7]:
# Load the test split (the '-test.pkl' file) for evaluating the test loss.
# WARNING: this might require ~18G memory at peak
# NOTE(review): the memory figure is the same one quoted for the training
# split -- confirm it also applies here.
# NOTE(review): the models keep the t_eval that was assigned from the training
# dataset; presumably the test split shares the same time grid -- confirm.
data_path=os.path.join(PARENT_DIR, 'datasets', 'cartpole-gym-image-dataset-rgb-u9-test.pkl')
test_dataset = HomoImageDataset(data_path, T_pred=4)

In [8]:
# Per-batch reconstruction losses on the test split, one list per model.
lag_cavae_test_loss = []
MLPdyna_cavae_test_loss = []
lag_MLPEnc_caDec_test_loss = []
lag_caEnc_MLPDec_test_loss = []
lag_vae_test_loss = []
lag_caAE_test_loss = []

# (model, destination list) pairs, listed in the order the models are run on
# each batch.
# NOTE(review): keep this order fixed -- if any training_step draws from the
# global RNG, reordering the calls would change the reported numbers.
_test_eval_plan = (
    (model_lag_cavae, lag_cavae_test_loss),
    (model_MLPdyna_cavae, MLPdyna_cavae_test_loss),
    (model_lag_vae, lag_vae_test_loss),
    (model_lag_MLPEnc_caDec, lag_MLPEnc_caDec_test_loss),
    (model_lag_caEnc_MLPDec, lag_caEnc_MLPDec_test_loss),
    (model_lag_caAE, lag_caAE_test_loss),
)

# Sweep every index of test_dataset.x by selecting it through u_idx, then
# push all batches of that selection through each model.
for u_index in range(len(test_dataset.x)):
    test_dataset.u_idx = u_index
    dataLoader = DataLoader(
        test_dataset, batch_size=512, shuffle=False, collate_fn=my_collate
    )
    for batch in dataLoader:
        for _model, _losses in _test_eval_plan:
            _step_output = _model.training_step(batch, 0)
            _losses.append(_step_output['log']['recon_loss'].item())

In [9]:
# Per-batch reconstruction loss of the HGN baseline on the test split.
# Only the u_idx = 0 selection of the dataset is evaluated here.
HGN_test_loss = []
test_dataset.u_idx = 0
dataLoader = DataLoader(
    test_dataset, batch_size=256, shuffle=False, collate_fn=my_collate
)
HGN_test_loss.extend(
    model_HGN.training_step(batch, 0)['log']['recon_loss'].item()
    for batch in dataLoader
)

In [10]:
# Normalisation constant applied to every mean loss before printing.
# NOTE(review): presumably 64x64 pixels times 5 frames (T_pred=4 plus the
# initial frame) -- confirm, including how colour channels are accounted for.
scale = 64*64*5

# One row per model: (printed label, train losses, test losses).
_loss_report = (
    ('lag_cavae', lag_cavae_train_loss, lag_cavae_test_loss),
    ('MLPdyna_cavae', MLPdyna_cavae_train_loss, MLPdyna_cavae_test_loss),
    ('lag_vae', lag_vae_train_loss, lag_vae_test_loss),
    ('lag_MLPEnc_caDec', lag_MLPEnc_caDec_train_loss, lag_MLPEnc_caDec_test_loss),
    ('lag_caEnc_MLPDec', lag_caEnc_MLPDec_train_loss, lag_caEnc_MLPDec_test_loss),
    ('lag_caAE', lag_caAE_train_loss, lag_caAE_test_loss),
    ('HGN', HGN_train_loss, HGN_test_loss),
)

# Print the scaled mean train/test reconstruction loss for each model.
for _label, _train_losses, _test_losses in _loss_report:
    _train_mean = np.mean(_train_losses)/scale
    _test_mean = np.mean(_test_losses)/scale
    print(f'{_label}: train: {_train_mean}, test: {_test_mean}')

lag_cavae: train: 0.0067648056842800645, test: 0.007090796773425406
MLPdyna_cavae: train: 0.013915377068850729, test: 0.015017079499860605
lag_vae: train: 0.011116656474769115, test: 0.01259765431491865
lag_MLPEnc_caDec: train: 0.015359965484175417, test: 0.01625758684757683
lag_caEnc_MLPDec: train: 0.011642291510684622, test: 0.013912107857565085
lag_caAE: train: 0.003679761003392438, test: 0.0037456373140836753
HGN: train: 0.0028116617468185723, test: 0.002912916196510196
