## Imports

In [None]:
import os
from os.path import join
import sys
from pathlib import Path

# The notebook sits one directory below the project root; register <root>/app
# on the import path (only once) so the project packages imported below resolve.
parent_dir = Path(os.path.abspath('')).parent
app_dir = str(parent_dir / "app")
if app_dir not in sys.path:
    sys.path.append(app_dir)

import torch as pt
from torch.nn.functional import mse_loss
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import utils.config as config
from utils.helper_funcs import find_target_index_in_dataset
from FC.FullyConnected import make_FC_model
from utils.helper_funcs import shift_input_sequence, reduce_datasets_SVD, reduce_datasets_VAE
from utils.DataWindow import DataWindow

plt.rcParams["figure.dpi"] = 180

# use GPU if possible
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")
print(device)

# define prediction horizon and type of dimensionality reduction
PRED_HORIZON = 2
DIM_REDUCTION = "VAE"       # one of ("SVD" / "VAE")
# Fail fast on a typo: otherwise N_LATENT silently falls back to the VAE latent size
# and every path below points into a non-existent output directory.
if DIM_REDUCTION not in ("SVD", "VAE"):
    raise ValueError(f"Unknown DIM_REDUCTION {DIM_REDUCTION!r}; expected 'SVD' or 'VAE'")
N_LATENT = config.SVD_rank if DIM_REDUCTION == "SVD" else config.VAE_latent_size

# define paths
DATA_PATH = join(parent_dir, "data", "full_pipeline_data")
VAE_PATH = join(parent_dir, "output", "VAE", "latent_study", config.VAE_model)
SVD_PATH = join(parent_dir, "output", "SVD", "U.pt")
# FC_MODEL encodes "<input_width>_<hidden_size>_<n_hidden_layers>" (parsed in the model cell)
FC_MODEL = "32_128_1"
FC_PATH = join(parent_dir, "output", "FC", DIM_REDUCTION, "param_study", f"pred_horizon_{PRED_HORIZON}")
OUTPUT_PATH = join(parent_dir, "output", "FC", DIM_REDUCTION, "param_study")

#### Plot Parameter Study Results

In [None]:
# load study results; each key encodes "<input_width>_<hidden_size>_<n_hidden>"
study_results = pt.load(join(FC_PATH, "study_results.pt"))
param_combinations = list(study_results.keys())

# decode the three hyper-parameters of every combination and average the
# last 10 values of its recorded test loss
input_width, hidden_size, n_hidden, test_losses = [], [], [], []
for combination in param_combinations:
    fields = combination.split('_')
    input_width.append(int(fields[0]))
    hidden_size.append(int(fields[1]))
    n_hidden.append(int(fields[2]))
    test_losses.append(study_results[combination][0]["test_loss"].values[-10:].mean())

# rank (index, loss) pairs by loss, ascending and stable, and keep the five best
sorted_losses = sorted(enumerate(test_losses), key=lambda pair: pair[1])
lowest_loss_idx = [position for position, _ in sorted_losses[:5]]

print("The param combinations with the lowest loss: [input_width, hidden_size, n_hidden]")
print([param_combinations[position] for position in lowest_loss_idx])

In [None]:
# 3D scatter of the parameter study: one point per combination, coloured by test loss
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# draw the points once and reuse the returned collection as the colorbar mappable
scatter = ax.scatter(input_width, hidden_size, n_hidden, c=test_losses, cmap='viridis', s=100)
ax.set_xlabel("input width")
# tick only the input widths that actually occur in the study
ax.set_xticks(sorted(set(input_width)))
ax.set_ylabel("hidden layer neurons")
ax.set_zlabel("hidden layers")

cbar = fig.colorbar(scatter, ax=ax, pad=0.15)
cbar.set_label('Test Loss')
fig.tight_layout()
fig.savefig(join(OUTPUT_PATH, f"{DIM_REDUCTION}_FC_predhor{PRED_HORIZON}_param_study.png"), bbox_inches="tight")

## Pipeline Pre-Processing

In [None]:
# The test dataset concatenates two flow conditions of config.time_steps_per_cond
# timesteps each (500 per the original note), so the same relative timestep of the
# second condition sits exactly one condition-length further along.
TIMESTEP_1 = config.timestep_prediction
TIMESTEP_2 = config.time_steps_per_cond + TIMESTEP_1
print(f"Predicted timestep for each flow condition is {TIMESTEP_1} out of {config.time_steps_per_cond}")

In [None]:
# Ground-truth test data in full space. The SVD variant is stored with its spatial
# dimensions flattened into dim 0, so it is reshaped to config.target_resolution.
test_data_orig = pt.load(join(DATA_PATH, f"{DIM_REDUCTION}_test.pt"))
if DIM_REDUCTION == "SVD":
    test_data_orig = test_data_orig.unflatten(dim=0, sizes=config.target_resolution)

# coordinate grids used for the contour plots
coords = pt.load(join(Path(DATA_PATH).parent, "coords_interp.pt"))
xx, yy = coords

# Project the datasets into the reduced space. The VAE route also returns the decoder
# and the SVD route the basis U; the prediction cell needs them to map back to full space.
if DIM_REDUCTION == "SVD":
    (train_red, val_red, test_red), U = reduce_datasets_SVD(DATA_PATH, SVD_PATH, OUTPUT_PATH)
elif DIM_REDUCTION == "VAE":
    (train_red, val_red, test_red), decoder = reduce_datasets_VAE(DATA_PATH, VAE_PATH, OUTPUT_PATH, device)
else:
    raise ValueError("Unknown DIM_REDUCTION")

In [None]:
# TODO add this to config
# define model parameters: FC_MODEL encodes "<input_width>_<hidden_size>_<n_hidden_layers>"
INPUT_WIDTH, HIDDEN_SIZE, N_HIDDEN_LAYERS = [int(param) for param in FC_MODEL.split("_")]

# create FC model and load model state dict
FC_model = make_FC_model(
    latent_size=N_LATENT,
    input_width=INPUT_WIDTH,
    hidden_size=HIDDEN_SIZE,
    n_hidden_layers=N_HIDDEN_LAYERS,
)
FC_model.load(join(FC_PATH, FC_MODEL + ".pt"))
# The prediction loop moves its inputs to `device`, so the weights must live there too
# (a no-op if make_FC_model/load already placed them there).
FC_model.to(device)
FC_model.eval()

In [None]:
# reuse the scaler fitted during training so inputs match the training distribution
latent_scaler = pt.load(join(OUTPUT_PATH, "scaler.pt"))

# scale the reduced train/test data and window it into input/target datasets
scaled_train = latent_scaler.scale(train_red)
scaled_test = latent_scaler.scale(test_red)
data_window = DataWindow(
    train=scaled_train,
    test=scaled_test,
    input_width=INPUT_WIDTH,
    pred_horizon=PRED_HORIZON,
)

# only the target indices are kept; they are later used to index timesteps of test_data_orig
_, target_idx = data_window.rolling_window(test_red.shape[1])
target_idx = target_idx.tolist()

test_windows = data_window.test_dataset

## Autoregressive Prediction

In [None]:
# per-step losses, stored as plain floats (one entry per autoregressive step)
latent_loss = []
orig_loss = []

# find index of input-target pair in dataset that predicts TIMESTEP
timestep_1_id = find_target_index_in_dataset(nested_list=target_idx, target_id=TIMESTEP_1)
timestep_2_id = find_target_index_in_dataset(nested_list=target_idx, target_id=TIMESTEP_2)
# FIXME fix the indexing problem
print(timestep_1_id)
print(timestep_2_id)
print(target_idx)

# timesteps predicted by the selected window, indexed by autoregressive step
window_targets = target_idx[timestep_1_id]
# The final-step prediction is compared against TIMESTEP_1 below (and in the plots);
# surface the indexing problem instead of silently comparing mismatched frames.
if window_targets[PRED_HORIZON - 1] != TIMESTEP_1:
    print(f"WARNING: final prediction targets timestep {window_targets[PRED_HORIZON - 1]}, "
          f"not TIMESTEP_1={TIMESTEP_1}")

# The SVD back-projection adds the temporal mean of the test data; it is the same for
# every step, so compute it once on the device where U lives.
if DIM_REDUCTION == "SVD":
    temporal_mean = test_data_orig.flatten(0, 1).mean(dim=1).unsqueeze(-1).to(U.device)

with pt.no_grad():
    inputs, targets = test_windows[timestep_1_id]
    # add batch dimension with unsqueeze(0)
    inputs = inputs.flatten().unsqueeze(0).to(device)
    targets = targets.unsqueeze(0).to(device)

    for step in range(PRED_HORIZON):
        # shift input sequence by one: add last prediction while discarding first input
        if step != 0:
            inputs = shift_input_sequence(orig_seq=inputs, new_pred=pred)

        # time-evolution (autoregressive)
        pred = FC_model(inputs)
        latent_loss.append(mse_loss(targets[:, :, step], pred).item())

        # re-scaling
        pred_rescaled = latent_scaler.rescale(pred)

        # expand to full space either by VAE or SVD
        if DIM_REDUCTION == "VAE":
            # forward pass through decoder
            pred_orig = decoder(pred_rescaled.unsqueeze(0)).squeeze()
        else:
            # matrix multiplication with U, followed by adding back the temporal mean
            pred_orig = (
                (U @ pred_rescaled.to(U.device).permute(1, 0) + temporal_mean)
                .squeeze()
                .unflatten(dim=0, sizes=config.target_resolution)
            )
        # compare on the device that holds the ground truth (decoder/U may live elsewhere)
        pred_orig = pred_orig.to(test_data_orig.device)

        print(pred_orig.shape)

        orig_loss.append(mse_loss(test_data_orig[:, :, window_targets[step]], pred_orig).item())

# squared error field of the final-step prediction against the TIMESTEP_1 ground truth
MSE = (test_data_orig[:, :, TIMESTEP_1] - pred_orig)**2

#### Plot Latent vs. Full Space Loss

In [None]:
# reduced- vs full-space MSE as a function of the number of autoregressive steps
fig, ax = plt.subplots(1, 1, figsize=config.standard_figsize_1)
steps = range(1, PRED_HORIZON + 1)
# float() accepts python floats as well as 0-d tensors on any device
ax.plot(steps, [float(loss) for loss in latent_loss], label="reduced space loss")
ax.plot(steps, [float(loss) for loss in orig_loss], label="full space loss")
ax.set_ylabel("MSE")
ax.set_xlabel("number of autoregressive predictions")
ax.set_yscale("log")
ax.legend()
fig.tight_layout()
fig.savefig(join(OUTPUT_PATH, f"{DIM_REDUCTION}_FC_predhor{PRED_HORIZON}_origvslatentloss.png"), bbox_inches="tight")

#### Plot Original vs. Predicted

In [None]:
# ground truth | prediction | squared error, sharing cp limits for the first two panels
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
vmin_cp, vmax_cp = config.plot_lims_cp
vmin_MSE, vmax_MSE = config.plot_lims_MSE_reconstruction
levels_cp = pt.linspace(vmin_cp, vmax_cp, 120)
levels_MSE = pt.linspace(vmin_MSE, vmax_MSE, 120)

# matplotlib needs host memory; .cpu() is a no-op for tensors already on the CPU
ax1.contourf(xx, yy, test_data_orig[:, :, TIMESTEP_1].cpu(), vmin=vmin_cp, vmax=vmax_cp, levels=levels_cp)
ax2.contourf(xx, yy, pred_orig.cpu(), vmin=vmin_cp, vmax=vmax_cp, levels=levels_cp)
cont = ax3.contourf(xx, yy, MSE.cpu(), vmin=vmin_MSE, vmax=vmax_MSE, levels=levels_MSE)

ax1.set_title("Ground Truth")
# name the pipeline after the reduction method actually in use
ax2.set_title("CNN-VAE-FC" if DIM_REDUCTION == "VAE" else "SVD-FC")

fig.subplots_adjust(right=0.95)
cax = fig.add_axes([0.99, 0.283, 0.03, 0.424])
cbar = fig.colorbar(cont, cax=cax, label="Squared Error", format=ticker.FormatStrFormatter("%.2f"))

for ax in [ax1, ax2, ax3]:
    ax.set_aspect("equal")
    ax.set_xticklabels([])
    ax.set_yticklabels([])

fig.savefig(join(OUTPUT_PATH, f"{DIM_REDUCTION}_FC_predhor{PRED_HORIZON}_timestep_reconstr.png"), bbox_inches="tight")