In [26]:
import torch
import numpy as np

from vae import resVAE
from data import make_pinwheel_data
from svae.gmm import SVAE

from svae.gmm.local_optimization import local_optimization

from distributions import NormalInverseWishart, Dirichlet

from distributions import categorical, gaussian

from plot.gmm_plot import plot_latent_space, plot_observed_space

from matrix_ops import unpack_dense
from distributions import Gaussian

from scipy.stats import multivariate_normal

In [27]:
# Notebook configuration, grouped by the callable each group is unpacked into:
#   "VAE_parameters"           -> resVAE(**...)
#   "pinwheel_data_parameters" -> make_pinwheel_data(**...)
hyperparameters = dict(
    VAE_parameters=dict(
        latent_dim=2,
        input_size=2,
        hidden_size=[40],
        recon_loss="likelihood",
        name="resvae",
        weight_init_std=1e-2,
    ),
    pinwheel_data_parameters=dict(
        radial_std=0.3,
        tangential_std=0.05,
        num_classes=5,
        num_per_class=100,
        rate=0.25,
    ),
)

In [28]:
# Build the synthetic pinwheel dataset from the configured keyword arguments.
pinwheel_kwargs = hyperparameters["pinwheel_data_parameters"]
data = make_pinwheel_data(**pinwheel_kwargs)

In [29]:
# Instantiate the recognition network from its configured keyword arguments.
vae_kwargs = hyperparameters["VAE_parameters"]
network = resVAE(**vae_kwargs)

# Wrap the network in the SVAE and load the stored state tagged "trained".
model = SVAE(network)
model.load_model(path="../../trained_gmm", epoch="trained")

In [30]:
# Sanity check that the model was loaded and trained correctly:
# encode the data, run the local optimisation, decode again and plot both spaces.
observations = torch.tensor(data).to(model.vae.device).double()
potentials = model.encode(observations)

x, eta_x, label_stats, _, _ = local_optimization(potentials, model.eta_theta)

# encoded means, taken from the expected statistics of the Gaussian factor
gaussian_stats = Gaussian(eta_x).expected_stats()
_, Ex, _, _ = unpack_dense(gaussian_stats)

# reconstructions of the latent points
mu_y, log_var_y = model.decode(x)

# bring everything the plots consume back to NumPy
latent_means = Ex.cpu().detach().numpy()
recon_means = mu_y.squeeze().cpu().detach().numpy()
predicted_classes = torch.argmax(label_stats, dim=-1).cpu().detach().numpy()

latent_fig = plot_latent_space(
    latent=latent_means,
    eta_theta=model.eta_theta,
    title="plot_latent",
    x_axes=[-100, 100],
    y_axes=[-100, 100],
)
latent_fig.show()

observation_fig = plot_observed_space(
    obs=data,
    mu=recon_means,
    classes=predicted_classes,
    title="plot_obs",
    x_axes=[-20, 20],
    y_axes=[-20, 20],
)
observation_fig.show()

save plot to plot_latent.pdf


save plot to plot_obs.pdf


In [31]:
def sample(loc, Sigma, random_state=None):
    """Draw one multivariate-normal sample per (mean, covariance) pair.

    Args:
        loc: iterable of mean vectors, one per point.
        Sigma: iterable of covariance matrices, paired element-wise with ``loc``.
        random_state: optional seed, ``numpy.random.RandomState`` or
            ``numpy.random.Generator`` for reproducible draws. ``None``
            (the default) keeps the previous behaviour of using SciPy's
            default random state.

    Returns:
        A list with one sample per pair, in input order. Pairs beyond the
        shorter of the two inputs are ignored (``zip`` semantics).
    """
    # Turn a plain seed into ONE shared generator. Forwarding the same integer
    # to every rvs call would restart the stream for each point and make all
    # draws share the same underlying noise.
    if random_state is not None and not isinstance(
        random_state, (np.random.RandomState, np.random.Generator)
    ):
        random_state = np.random.RandomState(random_state)

    return [
        multivariate_normal.rvs(mean, cov, size=1, random_state=random_state)
        for mean, cov in zip(loc, Sigma)
    ]

In [32]:
# GENERATE NEW SAMPLES
dir_natparam, niw_natparam = model.eta_theta

# sample weights for the labels from the global Dirichlet
labels_weights = Dirichlet(dir_natparam).sample(500)

# get class assignments z using the weights
labels_one_hot = categorical.sample(labels_weights)
labels = np.argmax(labels_one_hot, axis=-1)

# for every data point, sample its Gaussian parameters according to the class it belongs to
gaussian_parameter_samples = NormalInverseWishart(niw_natparam).sample(labels)

mu_samples, Sigma_samples = zip(*gaussian_parameter_samples)

# sample latent data points x and stack them into a single ndarray:
# torch.tensor() on a list of ndarrays is the slow path PyTorch warns about,
# and the sanity-check cell also hands the plot an ndarray rather than a list
latent_samples = np.array(sample(np.array(mu_samples), np.array(Sigma_samples)))

# decode the newly sampled latents x into reconstructions y; the input is moved
# to the model's device first, as the sanity-check cell does before encoding
recon, _ = model.decode(torch.tensor(latent_samples).to(model.vae.device).double())

latent_fig = plot_latent_space(
    latent=latent_samples,
    eta_theta=model.eta_theta,
    title="baseline_latent",
    y_axes=[-100, 100],
    x_axes=[-100, 100]
)
latent_fig.show()

observation_fig = plot_observed_space(
    # obs=data,
    mu=recon.cpu().detach().numpy(),
    classes=labels,
    title="baseline_obs",
    y_axes=[-20, 20],
    x_axes=[-20, 20]
)
observation_fig.show()

save plot to baseline_latent.pdf


save plot to baseline_obs.pdf


In [33]:
# # GENERATE NEW SAMPLES
# dir_natparam, niw_natparam = model.eta_theta
#
# dir_natparam.requires_grad = False
#
# dir_natparam[12] = 0
# dir_natparam[2] = 0
# dir_natparam[8] = 0
#
# # sample weights for the labels from global dirichlet
# labels_weights = Dirichlet(dir_natparam).sample(1000)
# # labels_weights[:, 12] = 0
#
# # get class assignments z using the weights
# labels_one_hot = categorical.sample(labels_weights)
# labels = np.argmax(labels_one_hot, axis=-1)
#
# # for every data point, sample its Gaussian parameters according to the class it belongs to
# gaussian_parameter_samples = NormalInverseWishart(niw_natparam).sample(labels)
#
# mu_samples, Sigma_samples = zip(*gaussian_parameter_samples)
#
# # samples latent data points x
# latent_samples = sample(np.array(mu_samples), np.array(Sigma_samples))
#
# # decode the newly sampled latents x into reconstructions y
# recon, _ = model.decode(torch.tensor(latent_samples).double())
#
# latent_fig = plot_latent_space(
#     latent=latent_samples,
#     eta_theta=model.eta_theta,
#     title="remove_latent",
#     y_axes=[-100, 100],
#     x_axes=[-100, 100]
# )
# latent_fig.show()
#
# observation_fig = plot_observed_space(
#     # obs=data,
#     mu=recon.squeeze().cpu().detach().numpy(),
#     classes=labels,
#     title="remove_obs",
#     y_axes=[-20, 20],
#     x_axes=[-20, 20]
# )
# observation_fig.show()

In [34]:
# DO TRANSFORMATION

# unpack the parameters
dir_natparam, niw_natparam = model.eta_theta
kappa, mu_0, Phi, nu = NormalInverseWishart(niw_natparam).natural_to_standard()

# move one of the clusters: translate the mean of component 3 by (+10, 0).
# The previous `mu_0[3] =+ tensor` parsed as `mu_0[3] = (+tensor)`, which
# overwrote the mean with (10, 0) instead of shifting it.
# Work on a clone so the shift cannot write through to the model's parameters
# in case natural_to_standard returns views (not verifiable from this file);
# this also keeps the cell idempotent when re-run.
mu_0 = mu_0.clone()
# match dtype/device explicitly: unlike plain assignment, `+=` does not copy
# the right-hand side across devices
shift = torch.tensor([10.0, 0.0], dtype=mu_0.dtype, device=mu_0.device)
mu_0[3] += shift

# pack the parameters
niw_natparam = NormalInverseWishart(niw_natparam).standard_to_natural(kappa, mu_0, Phi, nu)
eta_theta = (dir_natparam, niw_natparam)

# update the svae global parameters with the transformation
# model.eta_theta = eta_theta

# GENERATE NEW SAMPLES

# sample weights for the labels from the global Dirichlet
labels_weights = Dirichlet(dir_natparam).sample(500)

# get class assignments z using the weights
labels_one_hot = categorical.sample(labels_weights)
labels = np.argmax(labels_one_hot, axis=-1)

# for every data point, sample its Gaussian parameters according to the class it belongs to
gaussian_parameter_samples = NormalInverseWishart(niw_natparam).sample(labels)

mu_samples, Sigma_samples = zip(*gaussian_parameter_samples)
# sample latent data points x and stack them into a single ndarray:
# torch.tensor() on a list of ndarrays is the slow path PyTorch warns about,
# and the sanity-check cell also hands the plot an ndarray rather than a list
latent_samples = np.array(sample(np.array(mu_samples), np.array(Sigma_samples)))

# decode the newly sampled latents x into reconstructions y; the input is moved
# to the model's device first, as the sanity-check cell does before encoding
recon, _ = model.decode(torch.tensor(latent_samples).to(model.vae.device).double())

latent_fig = plot_latent_space(
    latent=latent_samples,
    eta_theta=eta_theta,
    title="move_latent",
    y_axes=[-100, 100],
    x_axes=[-100, 100]
)
latent_fig.show()

observation_fig = plot_observed_space(
    # obs=data,
    mu=recon.squeeze().cpu().detach().numpy(),
    classes=labels,
    title="move_obs",
    y_axes=[-20, 20],
    x_axes=[-20, 20]
)
observation_fig.show()

save plot to move_latent.pdf


save plot to move_obs.pdf
