# Test CNN-VAE (reduction) with Fully-Connected Net (prediction)

In [None]:
import os
from os.path import join
import sys
from pathlib import Path

# Make the sibling "app" directory importable so that the `utils.*` imports
# below resolve to this project's package. os.path.abspath('') is the current
# working directory (the notebook's folder), so app_dir is "<cwd>/../app".
# The directory is inserted at the FRONT of sys.path: appending it would let
# any installed distribution that is also called `utils` (found earlier in
# site-packages) shadow the project package.
parent_dir = Path(os.path.abspath('')).parent
app_dir = join(parent_dir, "app")
if app_dir not in sys.path:
    sys.path.insert(0, app_dir)

import torch as pt
import matplotlib.pyplot as plt

import utils.config as config
from utils.CNN_VAE import make_VAE_model
from utils.FullyConnected import make_FC_model

plt.rcParams["figure.dpi"] = 180

# use GPU if possible
device = pt.device("cuda") if pt.cuda.is_available() else pt.device("cpu")
print(device)

# Location of the trained VAE selected via the project config.
VAE_PATH = join(parent_dir, "output", "VAE", "latent_study", config.VAE_model)
# Root output directory of the single-flow-condition experiments.
OUTPUT_PATH = join(parent_dir, "output", "single_flow_cond")
# Prediction horizon of the parameter-study run whose FC models are loaded.
# Assumes other horizons are stored as sibling "pred_horizon_<n>" folders.
PRED_HORIZON = 1
# File stem of the selected FC model checkpoint; presumably
# "<input_width>_<hidden_size>_<n_hidden_layers>" — TODO confirm.
FC_MODEL = "45_256_1"
# Built from OUTPUT_PATH so the shared prefix is not spelled out twice.
FC_PATH = join(OUTPUT_PATH, "parameter_study", f"pred_horizon_{PRED_HORIZON}")
LATENT_SIZE = config.VAE_latent_size

#### Load Convolutional VAE

In [None]:
# Build the convolutional VAE with the configured latent size on the selected
# device, restore its saved state from VAE_PATH (project `load` method) and
# switch it to evaluation mode. eval() stays last so the notebook echoes it.
autoencoder = make_VAE_model(n_latent=LATENT_SIZE, device=device)
autoencoder.load(VAE_PATH)
autoencoder.eval()

#### Load Fully-Connected Net

In [None]:
# results from parameter study: hyperparameters of the selected FC configuration
INPUT_WIDTH = 45
HIDDEN_SIZE = 256
N_HIDDEN_LAYERS = 1

# FC_MODEL (set in the configuration cell) names the checkpoint loaded below and
# the entry read from study_results; it presumably encodes these same values as
# "<INPUT_WIDTH>_<HIDDEN_SIZE>_<N_HIDDEN_LAYERS>". Fail early if the two drift
# apart instead of building one architecture and loading another's weights.
expected_fc_model = f"{INPUT_WIDTH}_{HIDDEN_SIZE}_{N_HIDDEN_LAYERS}"
if FC_MODEL != expected_fc_model:
    raise ValueError(
        f"FC_MODEL={FC_MODEL!r} does not match the hyperparameters "
        f"({expected_fc_model!r}); update one of them."
    )

FC_model = make_FC_model(
    latent_size=LATENT_SIZE,
    input_width=INPUT_WIDTH,
    hidden_size=HIDDEN_SIZE,
    n_hidden_layers=N_HIDDEN_LAYERS,
).to(device)  # same device as the autoencoder so their outputs can be combined
print(FC_PATH)
# map_location allows a checkpoint saved on a GPU to be loaded on a CPU-only machine.
FC_model.load_state_dict(pt.load(join(FC_PATH, FC_MODEL + ".pt"), map_location=device))
FC_model.eval()

In [None]:
# Parameter-study results, indexed by model name (FC_MODEL); element [0] of the
# selected entry is printed as its loss record (structure assumed — TODO confirm).
# map_location allows results saved during a GPU run to be opened on a CPU-only machine.
# NOTE(review): torch >= 2.6 defaults to weights_only=True; if this file holds
# pickled non-tensor objects (e.g. a pandas DataFrame) the load fails and needs
# weights_only=False — only do that for trusted, locally produced files.
study_results = pt.load(join(FC_PATH, "study_results.pt"), map_location=device)
loss_df = study_results[FC_MODEL][0]
print(loss_df)

#### Latent Loss vs. Full Space Loss

In [None]:
# TODO: compare the latent-space loss with the full-space MSE

#### Original vs. Predicted Snapshot

In [None]:
# TODO: compare an original and a predicted test-data snapshot with plt.subplots(1, 3): [original, predicted, MSE]

#### Loss vs. Prediction Horizon

In [None]:
# TODO: show how the loss of the selected model configuration changes as the prediction horizon increases

#### Test AR prediction

In [None]:
# TODO: decide on the architecture, then predict an arbitrary time step and compare it with the actual snapshot at that time step