In [None]:
from bvae import *

In [None]:
#%% # prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.multinomial import Multinomial
from torchvision import datasets, transforms
from scipy.interpolate import BSpline
import numpy as np
from torch.distributions import MultivariateNormal, Normal, RelaxedOneHotCategorical
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random

import tqdm

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

DEVICE

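## Data preparation

Load the MNIST train and test splits into the working directory and wrap each split in a `DataLoader`.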

In [None]:
# Root directory passed to torchvision for the MNIST files (name kept as-is for later cells).
minst_dir = './'

# Both splits share the same transform: each image becomes a tensor via ToTensor
# (later cells treat one image as 784 = 1 x 28 x 28 values).
train_dataset = datasets.MNIST(root=minst_dir, train=True,
                               transform=transforms.ToTensor(), download=True)

# download=True here as well, so the test split does not silently depend on the
# training split's download having fetched its files first; torchvision skips
# the download when the files are already present.
test_dataset = datasets.MNIST(root=minst_dir, train=False,
                              transform=transforms.ToTensor(), download=True)

In [None]:
batch_size = 32  # samples per mini-batch, shared by both loaders

# Training batches are reshuffled each pass; test batches keep a fixed order
# so evaluation and the batch inspected below are repeatable.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

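## Quick test

Train a small BVAE (`z_dim=8`) on MNIST for 10 epochs, evaluating on the test set after each epoch. Then reconstruct one test batch and inspect a single digit.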

In [None]:
# Seed every RNG this run touches (PyTorch, Python, NumPy) so weight init and
# batch shuffling repeat across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Hyper-parameters for this quick run (values identical to the original call sites).
INPUT_DIM = 28 * 28          # flattened MNIST image size (784)
NUM_EPOCHS = 10
TRAIN_T = 5                  # forwarded as train_model(T=...); meaning defined in bvae — TODO document
COEF_SPLINE_PENALTY = 10     # forwarded as train_model(coef_spline_penalty=...)
TEMPERATURE = 0.1            # forwarded as train_model(temperature=...)

bvae = BVAE(INPUT_DIM, hidden_dim=[512, 256], z_dim=8, device=DEVICE)
optimizer = optim.Adam(bvae.parameters())

# NOTE: an optional `bvae.burn_in(train_loader)` step exists in bvae but is
# deliberately not run here.

# Epochs are numbered from 1; the test set is evaluated after every epoch.
for epoch in range(1, NUM_EPOCHS + 1):
    bvae.train_model(epoch=epoch, train_loader=train_loader, T=TRAIN_T,
                     coef_spline_penalty=COEF_SPLINE_PENALTY,
                     optimizer=optimizer, temperature=TEMPERATURE)
    bvae.test_model(test_loader)

In [None]:
# Take the first batch from the test loader (images and their labels) for inspection.
test_iter = iter(test_loader)
data, label = next(test_iter)

In [None]:
# Reconstruct the inspected test batch and compare one digit with its reconstruction.
SAMPLE_IDX = 9               # position within the batch of the digit to display
FORWARD_TEMPERATURE = 0.1    # 2nd positional arg of BVAE.forward; presumably the temperature
                             # (same value as training's temperature=0.1) — confirm in bvae

# Inference only: no autograd graph is needed, since the output is detached anyway.
with torch.no_grad():
    recon_mean, recon_var, coef_spl, weights, \
        z_sample_approx, pdf_approx, z_std = bvae.forward(
            data.to(DEVICE).view(-1, 784).unsqueeze(0), FORWARD_TEMPERATURE)
sample = recon_mean.detach()

# view(-1, ...) infers the batch size rather than assuming a full batch of `batch_size`.
recon_images = sample.cpu().view(-1, 28, 28)

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(6, 3))
ax_in.imshow(data[SAMPLE_IDX].view(28, 28), cmap='gray')
ax_in.set_title(f'input (label {label[SAMPLE_IDX].item()})')
ax_out.imshow(recon_images[SAMPLE_IDX], cmap='gray')
ax_out.set_title('reconstruction (recon_mean)')
for ax in (ax_in, ax_out):
    ax.axis('off')
plt.show()
