# Convolutional Variational Autoencoder (CNN-VAE)

In [None]:
import sys
import os
from os.path import join
# Make the sibling ``app`` directory (one level above the notebook's working
# directory) importable so the project imports below resolve.
_cwd = os.getcwd()
parent_dir = os.path.abspath(join(_cwd, os.pardir))
app_dir = join(parent_dir, "app")
# Guard against appending a duplicate entry when the cell is re-run.
if not any(entry == app_dir for entry in sys.path):
    sys.path.append(app_dir)

from pathlib import Path
import torch as pt
from CNN_VAE import ConvDecoder, ConvEncoder, Autoencoder
import utils.config as config
import matplotlib.pyplot as plt

# Render all figures in this notebook at 180 dpi.
plt.rcParams.update({"figure.dpi": 180})

# Prefer the first CUDA device and fall back to the CPU when no GPU is available.
device = pt.device("cuda:0" if pt.cuda.is_available() else "cpu")
print(device)

# Data and results live in directories next to the notebook's working directory.
_project_root = Path.cwd().parent
DATA_PATH = _project_root / "data"
OUTPUT_PATH = _project_root.joinpath("output", "VAE", "parameter_study")

# Tag used in the file names of the figures saved by this notebook.
# NOTE(review): the tag mentions "batchnorm" and "nosquash", but make_VAE_model
# builds the model with layernorm=True and squash_output=True -- confirm the tag
# matches the checkpoint that is actually loaded.
test_case_name = "256_batchnorm_lr1e-4_Plateau_f0.8nosquash"

latent_size = 256

#### Define a factory for the VAE model

In [None]:
# function to create VAE model
def make_VAE_model(n_latent: int = 256) -> pt.nn.Module:
    encoder = ConvEncoder(
        in_size=config.target_resolution,
        n_channels=config.input_channels,
        n_latent=n_latent,
        variational=True,
        layernorm=True
    )

    decoder = ConvDecoder(
        in_size=config.target_resolution,
        n_channels=config.output_channels,
        n_latent=n_latent,
        layernorm=True,
        squash_output=True
    )

    autoencoder = Autoencoder(encoder, decoder)
    autoencoder.to(device)
    return autoencoder

#### Load the test dataset, the trained model and the training results

In [None]:
# Serialized test split stored alongside the other data files.
test_dataset = pt.load(DATA_PATH / "test_dataset.pt")

In [None]:
# Rebuild the architecture, then restore the trained weights saved under
# OUTPUT_PATH/"test" (the on-disk format is defined by Autoencoder.load).
checkpoint_path = join(OUTPUT_PATH, "test")
autoencoder = make_VAE_model(latent_size)
autoencoder.load(checkpoint_path)
autoencoder.eval()  # evaluation mode for the predictions below

# Per-epoch metrics recorded during training ("epoch", "train_loss", "val_loss", ...).
test_result = pt.load(OUTPUT_PATH / "test_results.pt")

#### Plot loss over epochs

In [None]:
# Training and validation loss per epoch on a logarithmic scale.
loss_epochs = test_result["epoch"]
for loss_key, curve_label in (("train_loss", "training"), ("val_loss", "validation")):
    plt.plot(loss_epochs, test_result[loss_key], lw=1, label=curve_label)
# A "test_loss" curve can be added the same way if it was recorded.
plt.yscale("log")
plt.xlim(0, config.epochs)
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.legend()
plt.tight_layout()
plt.savefig(join(OUTPUT_PATH, f"LOSS_{test_case_name}.png"))

#### Make test predictions

In [None]:
# Coordinate grids (presumably the interpolation grid) used as the x/y
# arguments of the contour plots below.
coords = pt.load(DATA_PATH / "coords_interp.pt")
xx, yy = coords

In [None]:
def make_prediction(model, image):
    """Run ``model`` on a single sample and return its output on the CPU.

    Args:
        model: Module whose forward pass accepts a batched tensor with the batch
            dimension first.
        image: One sample without a batch dimension. A batch dimension of size 1
            is added before the forward pass.

    Returns:
        The model output with the added batch dimension and the next leading
        singleton dimension squeezed away. The result is computed without
        autograd and moved to the CPU, so it can be passed straight to
        Matplotlib.
    """
    # Feed the input to whatever device the model lives on. Without this, a
    # CPU sample would crash against a model on "cuda:0". Parameter-free
    # modules fall back to the input's own device.
    first_param = next(model.parameters(), None)
    target = image.device if first_param is None else first_param.device
    with pt.no_grad():
        output = model(image.unsqueeze(0).to(target))
    # squeeze(0) is a no-op when that dimension is not of size 1.
    return output.squeeze(0).squeeze(0).cpu()

In [None]:
# Side-by-side contour plots: original test samples (left) and their VAE
# reconstructions (right), saved under OUTPUT_PATH.
autoencoder.eval()
n_rows = 4
sample_stride = 5  # plot every 5th test sample
fig, axes = plt.subplots(n_rows, 2, figsize=(4, 10))
vmin, vmax = -1, 1
levels = pt.linspace(vmin, vmax, 120)
contour_kwargs = dict(vmin=vmin, vmax=vmax, levels=levels, extend="both")

for i, row in enumerate(axes):
    if i == 0:
        row[0].set_title("Original")
        row[1].set_title("Encoded-Decoded")

    # Use the SAME sample for both columns. Previously the reconstruction used
    # test_dataset[i] while the original column showed test_dataset[i*5].
    image = test_dataset[i * sample_stride]
    # Run inference on the model's device, then bring the result back to the
    # CPU for Matplotlib.
    prediction = make_prediction(autoencoder, image.to(device)).cpu()
    row[0].contourf(xx, yy, image.squeeze(0).cpu(), **contour_kwargs)
    row[1].contourf(xx, yy, prediction, **contour_kwargs)
    row[0].set_ylabel("Test Image {}".format(i))

    for ax in row:
        ax.set_aspect("equal")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
fig.tight_layout()
plt.savefig(join(OUTPUT_PATH, "RECONSTR_" + test_case_name + ".png"))