In [None]:
import sys
import os
# Assumes the notebook is launched from a direct sub-folder of the repository:
# the repository root is then the parent of the working directory, and adding it
# to sys.path makes the project packages (utils, models) importable.
repo_dir = os.path.dirname(os.getcwd())
# Guard against duplicate sys.path entries when this cell is re-run in the same kernel.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from utils.reproducibility import seed_everything
from models.lo import fast_bins_lo, bins_lo
from models.cm import ContinuousMixture
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import numpy as np
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# `gpus` mirrors that choice: one GPU on CUDA, None on CPU.
if torch.cuda.is_available():
    device = 'cuda'
    gpus = 1
else:
    device = 'cpu'
    gpus = None
print(device)

### Specify the datasets to evaluate

In [None]:
# DEBD datasets to evaluate; the evaluation loop below visits them in this order.
DEBD_DATASETS = [
    'nltcs', 'msnbc', 'kdd', 'plants', 'baudio',
    'jester', 'bnetflix', 'accidents', 'tretail', 'pumsb_star',
    'dna', 'kosarek', 'msweb', 'book', 'tmovie',
    'cwebkb', 'cr52', 'c20ng', 'bbc', 'ad',
]
print(DEBD_DATASETS)

### Specify the numbers of bins to sweep and the number of epochs

In [None]:
# Bin counts to sweep: powers of two from 2**7 = 128 up to 2**10 = 1024.
n_bins_list = [2 ** exponent for exponent in range(7, 11)]
# Passed to bins_lo as max_epochs for every run.
n_epochs = 150
print(n_bins_list, n_epochs)

## Train and Evaluate

In [None]:
def evaluate_lls_dict(lls_dict):
    """Print and return summary statistics of test log-likelihoods per bin count.

    Parameters
    ----------
    lls_dict : dict
        Maps a number of bins to a list with one entry per run. Each entry is an
        array-like of log-likelihoods (presumably one value per test sample —
        confirm against the caller).

    Returns
    -------
    dict
        Maps every bin count that has at least one run to ``(avg_ll, std_ll)``.
        These are the mean and population standard deviation (``np.std`` default,
        ``ddof=0``) across runs of each run's average log-likelihood. Bin counts
        with no runs are reported and left out, instead of printing ``nan``.
    """
    summary = {}
    for n_bins, run_lls in lls_dict.items():
        print('Evaluating using ' + str(n_bins) + ' bins..')
        if len(run_lls) == 0:
            # np.mean([]) would emit a RuntimeWarning and yield nan; say why instead.
            print('No runs found, skipping.')
            continue
        # Average within each run first, so every run weighs equally in the spread.
        avg_lls_per_run = [np.mean(ll) for ll in run_lls]
        avg_ll = np.mean(avg_lls_per_run)
        std_ll = np.std(avg_lls_per_run)
        print('AVG LL: %f ' % avg_ll + ' STD LL: %f ' % std_ll)
        print('Latex string: %.2f$\\pm$%.2f' % (avg_ll, std_ll))
        summary[n_bins] = (float(avg_ll), float(std_ll))
    return summary

In [None]:
# if you run OOM you can tweak n_chunks and batch_size
n_chunks = None
batch_size = 128
cm_clt = True
# Seed reset before every bins_lo run so each run is repeatable.
lo_seed = 42
# Learning rate handed to bins_lo.
lo_lr = 1e-3

# Loop-invariant: the log root depends only on which model family is evaluated.
log_dir = os.path.join(repo_dir, 'logs', 'debd', 'cm_clt' if cm_clt else 'cm_fact')

for dataset_name in DEBD_DATASETS:

    train, valid, test = load_debd(dataset_name)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid, batch_size=batch_size)
    test_loader = DataLoader(test, batch_size=batch_size)
    print('Evaluating ' + dataset_name + '..\n')

    # Collect the best-on-validation checkpoint paths once per dataset; the tree
    # does not change between bin settings. Sorting makes the run order
    # deterministic (os.walk order is arbitrary). Filtering on the file name
    # *before* loading means no other checkpoint is ever loaded onto the device.
    dataset_log_dir = os.path.join(log_dir, dataset_name)
    best_ckpt_paths = sorted(
        os.path.join(root, file_name)
        for root, _, file_names in os.walk(dataset_log_dir)
        if 'checkpoints' in root
        for file_name in file_names
        if 'best_model_valid' in file_name
    )
    if not best_ckpt_paths:
        # os.walk silently yields nothing for a missing directory; make that visible.
        print('No best_model_valid checkpoints found under ' + dataset_log_dir)

    test_lls_dict = {n_bins: [] for n_bins in n_bins_list}
    for n_bins in n_bins_list:
        for ckpt_path in best_ckpt_paths:
            # Reload for every bin setting: bins_lo receives the model and may alter
            # it. NOTE(review): confirm whether reusing one loaded model is safe.
            model = ContinuousMixture.load_from_checkpoint(ckpt_path).to(device)
            model.n_chunks = n_chunks
            model.missing = False

            seed_everything(lo_seed)
            z, log_w = bins_lo(model, n_bins, train_loader, valid_loader, max_epochs=n_epochs, lr=lo_lr, device=device)
            test_lls_dict[n_bins].append(
                model.eval_loader(test_loader, z, log_w, device=device).cpu().numpy())

    evaluate_lls_dict(test_lls_dict)
    print('\nLO ended on ' + dataset_name + '\n')
    print('---------------------------------------------------------------------------\n')