In [6]:
from data import make_pinwheel_data
from vae import resVAE

import torch

from plot.gmm_plot import plot_reconstruction
from vae.models.vae import reparameterize

In [2]:
# Pinwheel toy-data settings; forwarded as keyword arguments to
# make_pinwheel_data via hyperparameters["pinwheel_data_parameters"].
data_parameters = dict(
    radial_std=0.3,
    tangential_std=0.05,
    num_classes=5,
    num_per_class=100,
    rate=0.25,
)

# Residual-VAE constructor settings; forwarded to resVAE via
# hyperparameters["VAE_parameters"].
vae_parameters = dict(
    latent_dim=2,
    input_size=2,
    hidden_size=[50],
    name="resvae",
    recon_loss="likelihood",
    weight_init_std=1e-2,
)

# Training settings. No cell shown here reads them; presumably kept so this
# record matches the run that produced the checkpoint in ../trained_vae (confirm).
vae_train_parameters = dict(epochs=500, batch_size=32, kld_weight=0.1)

# One nested record holding all configuration groups above.
hyperparameters = dict(
    VAE_parameters=vae_parameters,
    VAE_train_parameters=vae_train_parameters,
    pinwheel_data_parameters=data_parameters,
)

In [8]:
# Generate the pinwheel observations, then build the VAE and restore the
# pretrained weights saved under the "trained" affix.
TRAINED_VAE_DIR = "../trained_vae"  # relative to the notebook's working directory

obs = make_pinwheel_data(**hyperparameters["pinwheel_data_parameters"])

model = resVAE(**hyperparameters["VAE_parameters"])
model.load_model(path=TRAINED_VAE_DIR, affix="trained")

In [16]:
# Sanity check that the model loaded correctly: encode the observations,
# sample latents, decode, and plot the reconstruction next to the latents.
# This is pure inference, so autograd tracking is disabled; the tensors
# produced here carry no computation graph.
with torch.no_grad():
    # Same preparation as before: move to the model's device, cast to float64.
    data = torch.tensor(obs).to(model.device).double()

    mu_z, log_var_z = model.encode(data)
    # NOTE(review): reparameterize presumably draws a random sample from
    # (mu_z, log_var_z); no seed is set, so the plot varies between runs.
    z = reparameterize(mu_z, log_var_z)
    mu_x, log_var_x = model.decode(z)

fig = plot_reconstruction(
    obs=obs,
    mu=mu_x.detach().cpu().numpy(),
    latent=z.detach().cpu().numpy(),
)
fig.show()

save plot to plot.pdf


In [28]:
# Latent intervention: take a latent sample, translate every point lying
# strictly in the positive quadrant (z1 > 0 and z2 > 0) by a fixed offset,
# decode the edited latents, and plot how the reconstruction responds.
LATENT_SHIFT = [60, 60]  # offset added to (z1, z2) of the selected points

with torch.no_grad():
    # Reuses mu_z / log_var_z from the sanity-check cell but calls
    # reparameterize again, so these latents are not the ones plotted there.
    z = reparameterize(mu_z, log_var_z).detach().cpu().numpy()

    mask = (z[:, 0] > 0) & (z[:, 1] > 0)
    z[mask] += LATENT_SHIFT

    # The edited latents live on the CPU as a numpy array. Move them back to
    # the model's device (as done for `data` when encoding) so decoding also
    # works for a GPU-resident model, and keep float64 to match that input.
    mu_x, log_var_x = model.decode(torch.tensor(z).to(model.device).double())

fig = plot_reconstruction(
    obs=obs,
    mu=mu_x.detach().cpu().numpy(),
    latent=z,
)
fig.show()

save plot to plot.pdf
