# Visualizations and Samples

In [32]:
import dill
from fashion_mnist_vae.utils import utils, constants

from torchvision.datasets import FashionMNIST
import numpy as np
import torch
import pyro

## Load Models

In [33]:
vae_path = constants.ASSETS_DIR.joinpath("vae", "model.pkl")
vae_con_path = constants.ASSETS_DIR.joinpath("vae_con", "model.pkl")

with vae_path.open("rb") as f:
    vae = dill.load(f)

## Sample from Latent Space

In [67]:
rng = np.random.default_rng()
noise = rng.uniform(0, 1, (25, 1, 28, 28))
sampler = pyro.infer.Predictive(model=vae.model, guide=vae.guide, num_samples=1)
samples = sampler.forward(torch.tensor(noise, dtype=torch.float32))["latent_space"].reshape(25, 256)
decoded_latent_samples = vae.decoder(samples)

In [94]:
ds = FashionMNIST(train=False, transform=utils.to_tensor, root=constants.ASSETS_DIR)
x = torch.stack([ds[i][0] for i in range(25)], dim=0)

In [95]:
z_loc, z_std = vae.encoder(torch.tensor(x, dtype=torch.float32).to("cuda"))
z = pyro.distributions.Normal(z_loc, z_std).sample()
decoded = vae.decoder(z).detach().cpu().numpy().reshape(25, 28, 28) * 255
decoded.shape

  z_loc, z_std = vae.encoder(torch.tensor(x, dtype=torch.float32).to("cuda"))


(25, 28, 28)

## Visualize Examples

In [96]:
actuals = utils.image_grid(x.detach().cpu().numpy().reshape(25, 28, 28) * 255, rows=5, cols=5)
actuals.show()

In [97]:
image_grid = utils.image_grid(decoded, rows=5, cols=5)
image_grid.show()

### Visualize Conditional Examples