In [None]:
import sys
import os
# Make the parent of the current working directory importable, so that the
# `utils` and `models` packages imported below can be resolved.
repo_dir = os.path.dirname(os.getcwd())
# Guard the append: re-running this cell must not pile up duplicate entries
# in sys.path.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from utils.bins_samplers import GaussianQMCSampler
from utils.reproducibility import seed_everything
from models.mixtures import BernoulliMixture
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import pytorch_lightning as pl
from models.vae import VAE
import numpy as np
import torch

# Pick the compute device once; `gpus` is 1 when CUDA is usable and None otherwise.
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

## Specify the datasets to evaluate

In [None]:
# Dataset names passed to `load_debd`, listed in the order they are evaluated.
BINARY_DATASETS = [
    'nltcs', 'msnbc', 'kdd', 'plants', 'baudio',
    'jester', 'bnetflix', 'accidents', 'tretail', 'pumsb_star',
    'dna', 'kosarek', 'msweb', 'book', 'tmovie',
    'cwebkb', 'cr52', 'c20ng', 'bbc', 'ad',
]
print(BINARY_DATASETS)

## Specify the integration points (bins)

In [None]:
# Powers of two from 2**7 (128) up to 2**13 (8192) integration points.
n_bins_list = [2 ** exponent for exponent in range(7, 14)]
# Result keys: the string 'elbo' first, then one integer key per bin count.
keys = ['elbo', *n_bins_list]
print(keys)

In [None]:
# Shared sampler for the integration points; it starts at the smallest bin
# count and its `n_bins` attribute is reassigned per key by the evaluation loop.
# NOTE(review): latent_dim=4 presumably has to match the latent size of the
# trained VAE checkpoints — confirm.
test_sampler = GaussianQMCSampler(
    n_bins=n_bins_list[0],
    latent_dim=4,
)

def evaluate_mixture(model, loader, z, log_w, device):
    """Run every batch of ``loader`` through a ``BernoulliMixture`` decoded from ``z``.

    ``model.vae.decoder`` is applied to ``z`` and the ``.logit()`` of its output
    is used as the mixture's ``logits_p``; ``log_w`` is used as ``logits_w``.
    The mixture is then called on each batch and the per-sample results are
    gathered into one flat list.

    The whole computation runs under ``torch.no_grad()`` with the model in
    eval mode (matching ``evaluate_elbo``), because this is inference only and
    tracking gradients would just retain activations for every decoded point.

    Args:
        model: Object exposing ``eval()`` and ``vae.decoder``. It is switched
            to eval mode as a side effect.
        loader: Iterable of batches that also exposes ``dataset`` (its length
            is used as a sanity check).
        z: Tensor fed to the decoder (presumably the integration points
            returned by the sampler — confirm against the caller).
        log_w: Value forwarded as ``logits_w`` of the mixture.
        device: Device on which ``z``, the mixture and each batch are placed.

    Returns:
        list: One entry per sample in ``loader.dataset`` (presumably
        log-likelihoods, judging by the name ``lls`` — confirm against
        ``BernoulliMixture``).

    Raises:
        AssertionError: If the number of collected values differs from
            ``len(loader.dataset)``.
    """
    model.eval()
    lls = []
    with torch.no_grad():
        logits_p = model.vae.decoder(z.to(device)).logit()
        mixture = BernoulliMixture(logits_p=logits_p, logits_w=log_w).to(device)
        for x in loader:
            # Extending with the array appends its elements one by one.
            lls.extend(mixture(x.to(device)).detach().cpu().numpy())
    assert len(lls) == len(loader.dataset)
    return lls

def evaluate_elbo(model, loader, n_mc_samples, device):
    """Collect ``model.log_prob`` for every sample served by ``loader``.

    The model is put into eval mode and each batch is scored with
    ``model.log_prob(batch, n_mc_samples)``. Scoring runs under
    ``torch.no_grad()``: no gradients are needed here, and without it an
    autograd graph would be kept alive for every batch.

    Args:
        model: Object exposing ``eval()`` and ``log_prob(x, n_mc_samples)``.
            It is switched to eval mode as a side effect.
        loader: Iterable of batches that also exposes ``dataset`` (its length
            is used as a sanity check).
        n_mc_samples: Forwarded unchanged as the second argument of
            ``model.log_prob`` (presumably the number of Monte Carlo samples
            of the estimate — confirm against the model).
        device: Device each batch is moved to before scoring.

    Returns:
        list: One value per sample in ``loader.dataset``, in loader order.

    Raises:
        AssertionError: If the number of collected values differs from
            ``len(loader.dataset)``.
    """
    lls = []
    model.eval()
    with torch.no_grad():
        for x in loader:
            batch_lls = model.log_prob(x.to(device), n_mc_samples)
            # Extending with the array appends its elements one by one.
            lls.extend(batch_lls.detach().cpu().numpy())
    assert len(lls) == len(loader.dataset)
    return lls

def evaluate_lls_dict(lls_dict):
    """Print the mean and standard deviation across runs for every key.

    For each key, every run's values are averaged first; the mean and the
    (population) standard deviation of those per-run averages are then printed
    both in plain form and as a LaTeX ``mean$\\pm$std`` snippet.

    Keys without any run are reported and skipped: averaging an empty list
    would otherwise print NaN and emit NumPy RuntimeWarnings.

    Args:
        lls_dict: Mapping from a key (the string ``'elbo'`` or a bin count)
            to a list of runs, each run being a sequence of numeric values.

    Returns:
        None. Results are written to standard output only.
    """
    for key, runs in lls_dict.items():
        if key == 'elbo':
            print('Evaluating using ELBO..')
        else:
            print('Evaluating using ' + str(key) + ' bins..')
        if len(runs) == 0:
            print('No runs available, skipping..')
            continue
        avg_lls_per_run = [np.mean(ll) for ll in runs]
        avg_ll = np.mean(avg_lls_per_run)
        std_ll = np.std(avg_lls_per_run)
        print('AVG LL: %f ' % avg_ll + ' STD LL: %f ' % std_ll)
        print('Latex string: %.2f$\\pm$%.2f' % (avg_ll, std_ll))

In [None]:
# In case of out-of-memory errors, lower the batch size.
batch_size = 128
# When True, only the test split is scored and the validation split is ignored.
only_test = True

n_elbo_mc_samples = 1_000
seed_everything(42)

for dataset_name in BINARY_DATASETS:

    _, valid, test = load_debd(dataset_name)
    if not only_test:
        valid_loader = DataLoader(valid, batch_size=batch_size)
    test_loader = DataLoader(test, batch_size=batch_size)
    print('Evaluating ' + dataset_name + '..')

    # One list of per-run results for every key ('elbo' and each bin count).
    if not only_test:
        bmv_valid_lls_dict = {key: [] for key in keys}
    bmv_test_lls_dict = {key: [] for key in keys}

    # (loader, results) pairs in scoring order: validation (if enabled), then test.
    eval_targets = []
    if not only_test:
        eval_targets.append((valid_loader, bmv_valid_lls_dict))
    eval_targets.append((test_loader, bmv_test_lls_dict))

    exp_runs = 0
    ckpt_root = repo_dir + '/logs/debd/vae/' + dataset_name
    for dirpath, _subdirs, filenames in list(os.walk(ckpt_root)):

        # Every directory whose path mentions 'checkpoints' counts as one run.
        if 'checkpoints' not in dirpath:
            continue
        exp_runs += 1

        for ckpt in filenames:
            # Every file is loaded, but only 'best_model_valid' ones are scored.
            model = VAE.load_from_checkpoint(dirpath + '/' + ckpt).to(device)
            is_best_on_valid = 'best_model_valid' in ckpt

            for key in keys:
                # Integer keys are bin counts: redraw the points for that size.
                if isinstance(key, int):
                    test_sampler.n_bins = key
                    z, log_w = test_sampler(seed=42)

                if not is_best_on_valid:
                    continue

                for loader, results in eval_targets:
                    if key == 'elbo':
                        run_lls = evaluate_elbo(model, loader, n_elbo_mc_samples, device)
                    else:
                        run_lls = evaluate_mixture(model, loader, z, log_w, device)
                    results[key].append(run_lls)

    if not only_test:
        print('\n --- BMV on VALID ---')
        evaluate_lls_dict(bmv_valid_lls_dict)
    print('\n --- BMV on TEST ---')
    evaluate_lls_dict(bmv_test_lls_dict)

    print('\n' + str(exp_runs) + ' runs found and evaluated for ' + dataset_name + '\n\n')
    print('---------------------------------------------------------------------------\n')