In [None]:
import sys
import os
# Make the repository root importable so the `models.*` / `utils.*` imports below
# resolve. The root is taken to be the parent of the current working directory
# (assumes the kernel runs from a subdirectory one level below the repo root).
repo_dir = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not pile up duplicate entries.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from models.cm import ContinuousMixture, GaussianDecoder
from torchvision.datasets import MNIST, FashionMNIST
from utils.reproducibility import seed_everything
from utils.datasets import UnsupervisedDataset
import torchvision.transforms as transforms
from models.nets import mnist_conv_decoder
from torch.utils.data import DataLoader
import numpy as np
import torch

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
# `gpus` mirrors that choice: None on CPU, a single device count otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
gpus = 1 if device != 'cpu' else None
print(device)

## Choose the dataset (MNIST or Fashion-MNIST)

In [None]:
# Dataset class and its short name. To evaluate on Fashion-MNIST instead, use:
#     dataset, dataset_name = FashionMNIST, 'fashion_mnist'
dataset = MNIST
dataset_name = 'mnist'

In [None]:
# Test split of the chosen dataset, converted to tensors and wrapped so only the
# inputs are yielded (UnsupervisedDataset, from utils.datasets — presumably drops
# the labels; confirm there). Downloaded into <repo>/data/ on first use.
transf = transforms.Compose([transforms.ToTensor()])
batch_size = 128

data_root = os.path.join(repo_dir, 'data', '')  # portable form of repo_dir + '/data/'
test = UnsupervisedDataset(dataset(root=data_root, train=False, download=True, transform=transf))
# Fixed order so repeated evaluations iterate the test set identically.
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

## Load a trained model (set `path` to a checkpoint file first!)

In [None]:
# Warning: The model should be Conv based

# Set `path` to a ContinuousMixture checkpoint file before running this cell.
path = ''
if not path:
    raise ValueError('Set `path` to a ContinuousMixture checkpoint file before running this cell.')

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model = ContinuousMixture.load_from_checkpoint(path, map_location=device).to(device)
model.missing = False  # NOTE(review): flag defined in models.cm; semantics not visible here
model.eval();

## Compute test log-likelihoods (LLs)

In [None]:
# if you run OOM, you may want to use n_chunks
model.n_chunks = 32
n_bins_list = [2**7, 2**8, 2**9, 2**10, 2**11, 2**12, 2**13, 2**14]
            
for n_bins in n_bins_list:
    model.sampler.n_bins = n_bins
    z, log_w = model.sampler(seed=42)

    print('Computing test LL using %d bins..' % n_bins)
    print(model.eval_loader(test_loader, z, log_w, device=device).mean().item())

## Sample from the continuous mixture (CM)

In [None]:
import matplotlib.pyplot as plt
import torchvision

In [None]:
# Draw latent codes from a standard normal, decode them with the model's decoder
# network and show the results as an image grid. A dedicated, seeded CPU
# generator makes the figure reproducible across runs and devices.
n_samples = 16
nrow = 4  # images per grid row (make_grid's `nrow`)
latent_dim = model.sampler.latent_dim

generator = torch.Generator().manual_seed(0)
z_samples = torch.randn(n_samples, latent_dim, generator=generator).to(device)
with torch.no_grad():
    samples = model.decoder.net(z_samples).cpu()

# Assumes each decoded sample holds 28x28 single-channel pixels (MNIST-sized).
grid_img = torchvision.utils.make_grid(samples.view(n_samples, 1, 28, 28), nrow=nrow)
fig, ax = plt.subplots(1, figsize=(10, 10))
ax.imshow(grid_img.permute(1, 2, 0))
ax.set_title('Samples from the continuous mixture')
ax.axis('off');