# Test CNN-VAE (reduction) with Fully-Connected Net (prediction)

In [None]:
import os
from os.path import join
import sys
from pathlib import Path

# Make the project's "app" directory importable. `parent_dir` is the parent of
# the current working directory, and `app_dir` is its "app" sub-directory.
# The path is appended only once so that re-running this cell does not pile up
# duplicate entries in sys.path.
parent_dir = Path(os.path.abspath("")).parent
app_dir = join(parent_dir, "app")
if app_dir not in sys.path:
    sys.path.append(app_dir)

import torch as pt
from torch.nn.functional import mse_loss
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import utils.config as config
from utils.helper_funcs import find_target_index_in_dataset
from utils.CNN_VAE import make_VAE_model
from utils.FullyConnected import make_FC_model
from utils.Scaler import MinMaxScaler_1_1
from utils.DataWindow import DataWindow

# Figure resolution used by every plot in this notebook.
plt.rcParams["figure.dpi"] = 180

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")
print(device)

# Locations of the input data, the pre-trained models and the study output,
# all resolved relative to `parent_dir`.
DATA_PATH = join(parent_dir, "data", "single_flow_cond")
VAE_PATH = join(parent_dir, "output", "VAE", "latent_study", config.VAE_model)
OUTPUT_PATH = join(parent_dir, "output", "VAE_FC", "param_study")
FC_PATH = join(OUTPUT_PATH, "pred_horizon_1")
FC_MODEL = "45_256_1"
LATENT_SIZE = config.VAE_latent_size

# The configured timestep is given relative to the full time series. Its
# non-dimensional value is computed from that absolute value first ...
absolute_timestep = config.timestep_reconstruction_single
TIMESTEP_dimless = (absolute_timestep * config.U_inf) / (config.c_mean * config.timesteps_per_second)
# ... and only afterwards is it turned into an index into the test data by
# subtracting the number of samples that belong to the training split.
TIMESTEP = int(absolute_timestep - config.time_steps_per_cond * config.single_flow_cond_train_share)
print(TIMESTEP_dimless)

#### Load data (single flow condition at alpha = 4.00)

In [None]:
# Read the stored train and test tensors from DATA_PATH.
train_file = join(DATA_PATH, "VAE_train.pt")
test_file = join(DATA_PATH, "VAE_test.pt")
train = pt.load(train_file)
test = pt.load(test_file)

# Quick sanity check: value range of both splits first, then their shapes.
for split in (train, test):
    print(split.max(), split.min())
for split in (train, test):
    print(split.shape)

#### Load autoencoder and encode data

In [None]:
# Build the VAE with the configured latent size on the selected device, load
# the pre-trained model stored under VAE_PATH and call eval() on it.
autoencoder = make_VAE_model(n_latent=LATENT_SIZE, device=device)
autoencoder.load(VAE_PATH)
autoencoder.eval()

# Keep a direct handle on the decoder; later cells use it to map latent
# predictions back to the full space.
decoder = autoencoder._decoder

# Encode both splits (train first, then test) with the autoencoder.
train_enc, test_enc = (autoencoder.encode_dataset(split) for split in (train, test))
print(test_enc.shape)

#### Load Fully-Connected Net

In [None]:
# Architecture chosen in the parameter study (presumably the same three values
# that are encoded in the FC_MODEL name -- confirm if either one is changed).
INPUT_WIDTH, HIDDEN_SIZE, N_HIDDEN_LAYERS = 45, 256, 1

fc_architecture = {
    "latent_size": LATENT_SIZE,
    "input_width": INPUT_WIDTH,
    "hidden_size": HIDDEN_SIZE,
    "n_hidden_layers": N_HIDDEN_LAYERS,
}
FC_model = make_FC_model(**fc_architecture)

# Load the stored model file "<FC_MODEL>.pt" from FC_PATH and call eval().
# NOTE(review): the model is not explicitly moved to `device` here, while the
# evaluation cells move their inputs to `device` -- confirm that
# make_FC_model()/load() take care of device placement.
fc_checkpoint = join(FC_PATH, FC_MODEL + ".pt")
FC_model.load(fc_checkpoint)
FC_model.eval()

#### Scale data, create Data Window and load into Time-Series TensorDataset to create input-target pairs

In [None]:
# Load the scaler object that is stored next to the FC model.
latent_scaler = pt.load(join(FC_PATH, "scaler.pt"))

# Scale the encoded splits (train first, then test) and wrap them in a
# DataWindow with the model's input width and a prediction horizon of 1.
train_scaled = latent_scaler.scale(train_enc)
test_scaled = latent_scaler.scale(test_enc)
data_window = DataWindow(
    train=train_scaled,
    test=test_scaled,
    input_width=INPUT_WIDTH,
    pred_horizon=1,
)

# rolling_window() returns a pair; only its second element (the target
# indices) is used, converted to a nested Python list.
_first, target_idx_tensor = data_window.rolling_window(test_enc.shape[1])
target_idx = target_idx_tensor.tolist()

# From here on `test_enc` refers to the windowed test dataset (input-target
# pairs) instead of the raw encoded tensor.
test_enc = data_window.test_dataset

In [None]:
# Load the stored results of the parameter study and print the first entry
# recorded for the selected model configuration.
results_file = join(FC_PATH, "study_results.pt")
study_results = pt.load(results_file)
model_results = study_results[FC_MODEL]
loss_df = model_results[0]
print(loss_df)

#### Latent Loss vs. Full Space Loss

In [None]:
# Compare the loss in latent space with the loss in the full space: predict
# every window of the test dataset and record both MSE values per window.
# The losses are stored as plain Python floats (via .item()) so that they can
# be plotted regardless of the device the tensors live on.
latent_loss = []
orig_loss = []

with pt.no_grad():
    # iterate over the windows of data
    for i, (inputs_latent, targets_latent) in enumerate(test_enc):
        inputs_latent = inputs_latent.flatten().to(device)
        targets_latent = targets_latent.flatten().to(device)

        # time evolution in latent space
        pred_latent = FC_model(inputs_latent)
        latent_loss.append(mse_loss(pred_latent, targets_latent).item())

        # re-scaling
        pred_latent_rescaled = latent_scaler.rescale(pred_latent)

        # decoding to full space
        pred_orig = decoder(pred_latent_rescaled.unsqueeze(0)).squeeze()

        # The ground-truth snapshot is the last target index of this window.
        # Move it to the prediction's device so the loss can be computed when
        # running on a GPU as well.
        target_orig = test[:, :, target_idx[i][-1]].to(pred_orig.device)
        orig_loss.append(mse_loss(pred_orig, target_orig).item())

In [None]:
# Build the time axis for the predicted snapshots: every test-set index from
# the first to the last predicted target is shifted by the number of training
# samples and then scaled with U_inf / (c_mean * timesteps_per_second).
first_target, last_target = target_idx[0][-1], target_idx[-1][-1]
train_offset = config.time_steps_per_cond * config.single_flow_cond_train_share
time_scale = config.c_mean * config.timesteps_per_second
timesteps = [
    ((t + train_offset) * config.U_inf) / time_scale
    for t in range(first_target, last_target + 1)
]

# Plot both loss curves over that time axis on a logarithmic y-axis.
for series, series_label in ((latent_loss, "Latent Loss"), (orig_loss, "Orig Loss")):
    plt.plot(timesteps, series, label=series_label)
plt.ylabel("MSE")
plt.xlabel(rf"$\tau$")
plt.yscale("log")
plt.legend()

#### Original vs predicted Snapshot

In [None]:
# Compare the ground-truth snapshot at TIMESTEP with the CNN-VAE-FC prediction
# in a 1x3 figure: [ground truth, prediction, pointwise squared error].
coords = pt.load(join(Path(DATA_PATH).parent, "coords_interp.pt"))
xx, yy = coords
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
vmin_cp, vmax_cp = config.plot_lims_cp
vmin_MSE, vmax_MSE = config.plot_lims_MSE_reconstruction
levels_cp = pt.linspace(vmin_cp, vmax_cp, 120)
levels_MSE = pt.linspace(vmin_MSE, vmax_MSE, 120)

# find index of input-target pair in test_enc that predicts TIMESTEP
dataset_id = find_target_index_in_dataset(nested_list=target_idx, target_id=TIMESTEP)

with pt.no_grad():
    # Use the window selected above. Unpacking directly into `inputs_latent`
    # matters: this name otherwise still holds the last window of an earlier
    # evaluation loop, and the prediction would be made for the wrong timestep.
    # The targets of the pair are not needed for the plot.
    inputs_latent, targets_latent = test_enc[dataset_id]
    inputs_latent = inputs_latent.flatten().to(device)
    # time evolution in latent space
    pred_latent = FC_model(inputs_latent)
    # re-scaling
    pred_latent_rescaled = latent_scaler.rescale(pred_latent)
    # decoding to full space; move the result to the CPU for plotting
    pred_orig = decoder(pred_latent_rescaled.unsqueeze(0)).squeeze().cpu()

# Pointwise squared error between ground truth and prediction (both on CPU).
ground_truth = test[:, :, TIMESTEP].cpu()
MSE = (ground_truth - pred_orig) ** 2

ax1.contourf(xx, yy, ground_truth, vmin=vmin_cp, vmax=vmax_cp, levels=levels_cp)
ax2.contourf(xx, yy, pred_orig, vmin=vmin_cp, vmax=vmax_cp, levels=levels_cp)
cont = ax3.contourf(xx, yy, MSE, vmin=vmin_MSE, vmax=vmax_MSE, levels=levels_MSE)

ax1.set_title("Ground Truth")
ax2.set_title("CNN-VAE-FC")

# Place the colorbar for the error plot in its own axes to the right of the
# subplots; tick labels are formatted with two decimals.
fig.subplots_adjust(right=0.95)
cax = fig.add_axes([0.99, 0.283, 0.03, 0.424])
cbar = fig.colorbar(
    cont,
    cax=cax,
    label="Squared Error",
    format=ticker.FormatStrFormatter("%.2f"),
)

# Equal aspect ratio and no tick labels on all three panels.
for ax in [ax1, ax2, ax3]:
    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])


#### Loss vs. Prediction Horizon

In [None]:
# show how the loss of the selected model configuration changes when the prediction horizon increases

#### Test AR prediction

In [None]:
# decide on architecture and predict an arbitrary timestep -> compare to actual timestep