# Convolutional Variational Autoencoder (CNN-VAE)

In [None]:
import sys
import os
from os.path import join
parent_dir = os.path.abspath(join(os.getcwd(), os.pardir))
app_dir = join(parent_dir, "app")
if app_dir not in sys.path:
      sys.path.append(app_dir)

from pathlib import Path
import torch as pt
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from CNN_VAE import ConvDecoder, ConvEncoder, Autoencoder
from utils.training_loop import train_cnn_vae
import utils.config as config
from sklearn import metrics
import matplotlib.pyplot as plt

plt.rcParams["figure.dpi"] = 180

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = pt.device("cuda:0" if pt.cuda.is_available() else "cpu")

# Data and outputs are resolved relative to the parent of the current
# working directory.
_parent_of_cwd = Path(os.path.abspath('')).parent
DATA_PATH = _parent_of_cwd / "data"
OUTPUT_PATH = _parent_of_cwd / "output" / "VAE"

#### Initialize Autoencoder and additional parameters

In [None]:
# Keyword arguments common to both halves of the CNN-VAE.
shared_kwargs = dict(
    in_size=config.target_resolution,
    n_latent=config.latent_size,
    batchnorm=True,
)

# Encoder: config.input_channels input channels, variational latent of size
# config.latent_size.
encoder = ConvEncoder(
    n_channels=config.input_channels,
    variational=True,
    **shared_kwargs,
)

# Decoder: latent of size config.latent_size back to config.output_channels
# channels; squash_output=True (presumably bounds the output range — confirm
# in CNN_VAE).
decoder = ConvDecoder(
    n_channels=config.output_channels,
    squash_output=True,
    **shared_kwargs,
)

# Combine both halves and move the parameters to the selected device.
autoencoder = Autoencoder(encoder, decoder)
autoencoder.to(device)

#### Load datasets and initialize dataloaders

In [None]:
# Pre-split datasets previously written with torch.save.
train_dataset = pt.load(join(DATA_PATH, "train_dataset.pt"))
val_dataset = pt.load(join(DATA_PATH, "val_dataset.pt"))
test_dataset = pt.load(join(DATA_PATH, "test_dataset.pt"))

# Only the training data needs reshuffling every epoch. Evaluation results do
# not depend on sample order, so validation/test keep a fixed order: batch
# composition stays reproducible, and the shuffling sampler no longer draws
# from the global torch RNG on every evaluation pass (which would perturb the
# random stream used during training).
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

#### Initialize training objects and train the model

In [None]:
# Adam over every autoencoder parameter at the configured learning rate.
optimizer = pt.optim.Adam(autoencoder.parameters(), lr=config.learning_rate)

# Checkpoint location under OUTPUT_PATH.
# NOTE(review): checkpoint_file is not passed to train_cnn_vae below — confirm
# whether checkpoints are actually written anywhere.
checkpoint_file = join(OUTPUT_PATH, "checkpoints")

# Multiply the learning rate by 0.2 whenever the monitored value stops
# decreasing (mode="min"). Which value is monitored is decided inside
# train_cnn_vae when it steps the scheduler — presumably a loss; confirm.
scheduler = pt.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.2,
)

# Run the full training loop. The returned record is indexed below by
# "epoch", "train_loss", "val_loss" and "test_loss".
mse_loss = nn.MSELoss()
test_result = train_cnn_vae(
    model=autoencoder,
    loss_func=mse_loss,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    epochs=config.epochs,
    optimizer=optimizer,
    lr_schedule=scheduler,
)

#### Plot loss over epochs

In [None]:
# One curve per data split against the epoch index, on a log-scaled y axis.
epoch_axis = test_result["epoch"]
loss_curves = (
    ("train_loss", "training"),
    ("val_loss", "validation"),
    ("test_loss", "testing"),
)
for result_key, curve_label in loss_curves:
    plt.plot(epoch_axis, test_result[result_key], lw=1, label=curve_label)

plt.yscale("log")
plt.xlim(0, config.epochs)
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.legend()
plt.show()

#### Make test predictions

In [None]:
# Grid coordinates used as the x/y arguments of the contour plots; the file
# holds exactly two entries, unpacked into xx and yy.
coords_file = DATA_PATH / "coords_interp.pt"
coords = pt.load(coords_file)
xx, yy = coords

In [None]:
def make_prediction(model, image):
    """Run ``model`` on a single un-batched sample in inference mode.

    A batch dimension of size 1 is added before the forward pass. Leading
    size-1 dimensions (batch, then e.g. a single channel) are squeezed from
    the output.

    Args:
        model: ``torch.nn.Module`` whose forward returns a tensor.
        image: one sample without the batch dimension.

    Returns:
        The squeezed model output, detached from autograd and on the CPU so
        it can be handed straight to matplotlib.
    """
    was_training = model.training
    # eval() makes BatchNorm use its running statistics instead of the
    # statistics of this one-sample batch, and stops it updating them.
    # NOTE(review): if Autoencoder samples the latent only in training mode,
    # this also makes the reconstruction deterministic — confirm in CNN_VAE.
    model.eval()
    # Evaluate on the model's device; parameter-free models use the input's.
    first_param = next(model.parameters(), None)
    device = image.device if first_param is None else first_param.device
    try:
        with pt.no_grad():
            output = model(image.unsqueeze(0).to(device))
    finally:
        # Restore the caller's mode so later training is unaffected.
        model.train(was_training)
    return output.squeeze(0).squeeze(0).cpu()

In [None]:
fig, axes = plt.subplots(2, 2)
vmin, vmax = -1, 1
# 120 contour levels across the fixed colour range; values outside it are
# drawn with the extreme colours (extend="both").
levels = pt.linspace(vmin, vmax, 120)

# One row per test image: original on the left, reconstruction on the right.
for i, row in enumerate(axes):
    if i == 0:
        row[0].set_title("Original")
        row[1].set_title("Encoded-Decoded")

    # .cpu() so plotting also works when the data or the model output lives
    # on the GPU (matplotlib converts tensors through NumPy, which needs host
    # memory); it is a no-op for CPU tensors.
    original = test_dataset[i].squeeze(0).cpu()
    row[0].contourf(xx, yy, original, vmin=vmin, vmax=vmax, levels=levels, extend="both")

    reconstruction = make_prediction(autoencoder, test_dataset[i]).cpu()
    row[1].contourf(xx, yy, reconstruction, vmin=vmin, vmax=vmax, levels=levels, extend="both")

    row[0].set_ylabel("Test Image {}".format(i))

    for ax in row:
        ax.set_aspect("equal")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
fig.tight_layout()