In [None]:
import sys
import os
# Make the repository root importable so that the `utils` and `models`
# packages imported in the following cells can be resolved.
# NOTE(review): assumes the notebook is launched from a direct sub-directory
# of the repository (the root is taken to be the parent of the CWD) - confirm.
repo_dir = os.path.dirname(os.getcwd())
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from utils.bins_samplers import GaussianQMCSampler
from models.cm import ContinuousMixture
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import pytorch_lightning as pl
import numpy as np
import torch

# Select the accelerator once: `device` is the string later passed to `.to(...)`,
# and `gpus` is 1 when CUDA is available and None on CPU-only machines.
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

## Load datasets and create dataloaders

In [None]:
# Only the second and third items returned by load_debd (validation and test)
# are used in this notebook; the first one is discarded.
# NOTE(review): the discarded item is presumably the training split - confirm.
_, valid, test = load_debd('binarized_mnist')

# Both loaders share one batch size - decrease it if you run out of memory (OOM).
batch_size = 64
valid_loader, test_loader = (
    DataLoader(split, batch_size=batch_size) for split in (valid, test)
)

## Load model (you should specify a path!)

In [None]:
# Warning: The model should be MLP based

# Path to the checkpoint of a trained ContinuousMixture - fill this in.
path = ''
if not path:
    # Fail early with an actionable message instead of handing the empty
    # placeholder to load_from_checkpoint.
    raise ValueError(
        "`path` is empty: set it to the checkpoint of a trained ContinuousMixture"
    )

# Restore the model and move it onto the selected device.
model = ContinuousMixture.load_from_checkpoint(path).to(device)
# NOTE(review): presumably switches off missing-value handling - confirm
# against the ContinuousMixture implementation.
model.missing = False
# Evaluation mode; the trailing ';' keeps the module repr out of the cell output.
model.eval();

## Evaluate log-likelihoods (LLs)

In [None]:
# Setting n_chunks may help if you run out of memory (OOM).
model.n_chunks = 32
# Bin counts to evaluate with: the powers of two from 2**7 up to 2**14.
n_bins_list = [2 ** exponent for exponent in range(7, 15)]

# (message template, loader) pairs, evaluated in this order for every bin count.
splits = (
    ('Computing validation LL using %d bins..', valid_loader),
    ('Computing test LL using %d bins..', test_loader),
)

for n_bins in n_bins_list:
    # Re-draw z and log_w from the sampler for this bin count, with a fixed seed.
    model.sampler.n_bins = n_bins
    z, log_w = model.sampler(seed=42)

    for banner, loader in splits:
        print(banner % n_bins)
        # Mean over everything eval_loader returns for this loader.
        print(model.eval_loader(loader, z, log_w, device=device).mean().item())

## Draw 16 samples from a CM of factorisations (only works for CMs of factorisations)

In [None]:
import matplotlib.pyplot as plt
import torchvision

In [None]:
# Decode 16 standard-normal latent vectors; the sigmoid maps the decoder
# outputs into (0, 1), and the result is detached and moved to the CPU.
latent_dim = model.sampler.latent_dim
noise = torch.randn(16, latent_dim).to(device)
samples = model.decoder.net(noise).sigmoid().detach().cpu()

# Lay the 16 single-channel 28x28 images out on a grid with 4 images per row
# and display it (channels-last for imshow) without tick labels.
fig, ax = plt.subplots(1, figsize=(10, 10))
grid_img = torchvision.utils.make_grid(samples.view(16, 1, 28, 28), nrow=4)
plt.imshow(grid_img.permute(1, 2, 0));
for clear_tick_labels in (ax.set_yticklabels, ax.set_xticklabels):
    clear_tick_labels([]);

## Compile a mixture and sample (Only works for CMs of factorisations)

In [None]:
from models.mixtures import BernoulliMixture

# Build a finite mixture with 1024 components from the continuous one:
# draw that many points z and log-weights from the sampler (fixed seed),
# decode every z into the logits of one component, and pass the log-weights
# as the mixture-weight logits.
n_components = 1024
model.sampler.n_bins = n_components
z, log_w = model.sampler(seed=42)
component_logits = model.decoder.net(z.to(device))
mixture = BernoulliMixture(logits_p=component_logits, logits_w=log_w).to(device)

In [None]:
# try both return_p=True and return_p=False
drawn = mixture.sample(16, return_p=False)
samples = drawn.detach().cpu()

# Same layout as the other figures: 16 single-channel 28x28 images, 4 per row,
# shown channels-last and without tick labels.
fig, ax = plt.subplots(1, figsize=(10, 10))
grid_img = torchvision.utils.make_grid(samples.view(16, 1, 28, 28), nrow=4)
plt.imshow(grid_img.permute(1, 2, 0));
for clear_tick_labels in (ax.set_yticklabels, ax.set_xticklabels):
    clear_tick_labels([]);

## Sample from CM of CLTs (Only works for CMs of CLTs)

In [None]:
from deeprob.spn.structure.cltree import BinaryCLT
# Build a BinaryCLT over the 784 (= 28 * 28) variables, using the tree stored
# in the decoder. The model was moved onto `device`, and Tensor.numpy() only
# accepts CPU tensors that do not require grad, so detach the tree and bring
# it back to host memory first (both calls are no-ops when already satisfied).
clt = BinaryCLT(list(range(784)), tree=model.decoder.tree.detach().cpu().numpy())
# Boolean row selector of length 2 * 784: True on even positions, False on odd.
mask = [True, False] * 784

In [None]:
# Read the latent size here instead of relying on the factorisation-sampling
# cell having been executed: that section does not apply to CMs of CLTs, so
# skipping it would otherwise leave `latent_dim` undefined.
latent_dim = model.sampler.latent_dim

# Decode a single standard-normal latent vector; the sigmoid maps the decoder
# outputs into (0, 1). [0] drops the batch dimension.
param = model.decoder.net(torch.randn(1, latent_dim).to(device)).sigmoid().detach().cpu()[0]
# Two values per variable; every row is duplicated, giving 2 rows per variable.
r = param.view(784, 2).repeat_interleave(2, 0)
# Complement the even rows (where `mask` is True): each row pair becomes (1 - p, p).
r[mask] = 1 - r[mask]
# For the root variable, overwrite column 1 with column 0 in both of its rows.
# NOTE(review): presumably because the root has no parent to condition on - confirm.
r[clt.root * 2][1] = r[clt.root * 2][0]
r[clt.root * 2 + 1][1] = r[clt.root * 2 + 1][0]
# Regroup as (784, 2, 2), swap the last two axes and store the element-wise
# logs as the CLT parameters.
clt.params = r.view(784, 2, 2).permute(0, 2, 1).log().numpy()

In [None]:
# Hand the CLT an all-NaN (16, 784) array and turn what it returns into a tensor.
# NOTE(review): NaN entries presumably mark the values to be sampled - confirm
# against the deeprob BinaryCLT.sample documentation.
unobserved = np.full((16, 784), np.nan)
samples = torch.Tensor(clt.sample(unobserved))

# 16 single-channel 28x28 images, 4 per row, shown channels-last and without
# tick labels.
fig, ax = plt.subplots(1, figsize=(10, 10))
grid_img = torchvision.utils.make_grid(samples.view(16, 1, 28, 28), nrow=4)
plt.imshow(grid_img.permute(1, 2, 0));
for clear_tick_labels in (ax.set_yticklabels, ax.set_xticklabels):
    clear_tick_labels([]);