In [None]:
import sys
import os
# Make the repository root (the parent of the notebook's working directory)
# importable, so the project packages `utils` and `models` resolve in the next cell.
# The membership check keeps this cell idempotent: re-running it does not keep
# appending duplicate entries to sys.path.
repo_dir = os.path.dirname(os.getcwd())
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from utils.reproducibility import seed_everything
from models.mixtures import BernoulliMixture
from torch.utils.data import DataLoader
from utils.datasets import load_debd
from tqdm import tqdm
import numpy as np
import torch
import copy

# Choose the compute device once; later cells move the model and batches with `.to(device)`.
# `gpus` is kept for compatibility (not used in the cells shown here).
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

In [None]:
# Dataset name (passed to `load_debd` below) and the mini-batch size
# used by the training and validation loaders.
dataset_name, batch_size = 'nltcs', 128

In [None]:
train, valid, test = load_debd(dataset_name)
# Training: shuffle and drop the ragged last batch so every optimisation step uses
# a full batch. Validation: keep every sample (drop_last=False). Dropping the last
# batch here would silently exclude up to batch_size - 1 validation examples from
# the early-stopping criterion.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid, batch_size=batch_size, drop_last=False)
print(dataset_name, train.shape, valid.shape, test.shape)

## Instantiate mixture

In [None]:
# Seed the RNGs (via the project's seed_everything) before drawing the random
# initial logits below, so the initialisation is reproducible.
seed_everything(42)

n_components = 1024
model = BernoulliMixture(
    # one row of N(0, 1) initial logits per component, one column per variable
    logits_p=torch.randn(n_components, train.shape[1]),
    # the same logit for every component; presumably normalised inside
    # BernoulliMixture to uniform weights, and learn_w=False presumably keeps
    # them fixed during training — confirm against models.mixtures
    logits_w=torch.full((n_components,), 1 / n_components),
    learn_w=False,
)
model = model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

## Train

In [None]:
max_num_epochs = 150
early_stopping_epochs = 30  # patience: stop once `e` exceeds this many non-improving epochs
warmup = 30                 # epochs during which the patience counter is held at 0

best_model = model
best_loss = np.inf
best_model_epoch = None  # stays None if validation never improves (e.g. NaN losses)
e = 0                    # consecutive non-improving epochs since warm-up ended

for epoch in range(max_num_epochs):
    # --- one pass over the training data ---
    model.train()
    train_loss_avg = []
    for x in train_loader:
        opt.zero_grad()
        loss = -model(x.to(device)).mean()  # mean negative log-likelihood of the batch
        loss.backward()
        opt.step()
        train_loss_avg.append(loss.item())

    # --- validation: per-sample mean NLL, weighted by batch size so that a
    # smaller final batch is not over-counted (identical to a mean of batch
    # means when all batches have the same size) ---
    model.eval()
    valid_nll_sum = 0.0
    n_valid = 0
    with torch.no_grad():
        for x in valid_loader:
            valid_nll_sum += -model(x.to(device)).sum().item()
            n_valid += len(x)
    if n_valid == 0:
        raise ValueError('valid_loader yielded no samples; cannot compute validation loss')
    val_loss_epoch = valid_nll_sum / n_valid

    # --- early stopping ---
    if val_loss_epoch < best_loss:
        e = 0
        best_loss = val_loss_epoch
        best_model = copy.deepcopy(model)  # snapshot; `model` keeps training
        best_model_epoch = epoch
    else:
        # The patience counter only advances once warm-up is over.
        e = 0 if epoch < warmup else e + 1
        if e > early_stopping_epochs:
            break

    print('Epoch [%d / %d] Training loss: %f Validation Loss: %f e: %d' %
          (epoch + 1, max_num_epochs, np.mean(train_loss_avg), val_loss_epoch, e))

# Note: best_model_epoch is a 0-based index, while the progress log above is 1-based.
print('Best model epoch: ', best_model_epoch)

## Test

In [None]:
# Evaluate the checkpoint selected by early stopping (lowest validation loss),
# not the weights left over from the last training epoch.
# Evaluation uses small batches; decrease batch_size further if the forward pass
# runs out of memory.
test_loader = DataLoader(test, batch_size=16, drop_last=False)

test_ll = []
best_model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    for x in tqdm(test_loader):
        # one log-likelihood value per test sample in the batch
        test_ll.extend(best_model(x.to(device)).cpu().numpy())
assert len(test_ll) == test.shape[0]
print('Test LL: %.2f' % np.mean(test_ll))