In [None]:
import sys
import os

# Make the repository root importable so the project packages imported in the next
# cell (`models`, `utils`, ...) can be found.
# NOTE(review): assumes the notebook's working directory is a direct sub-directory of
# the repository root (repo root = parent of the cwd) — confirm if notebooks move.
repo_dir = os.path.dirname(os.getcwd())
# Guard the append so re-running this cell does not keep adding duplicate entries.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from models.cm import BernoulliDecoder, CLTBernoulliDecoder, ContinuousMixture
from utils.bins_samplers import GaussianQMCSampler
from deeprob.spn.structure.cltree import BinaryCLT
from utils.reproducibility import seed_everything
from models.nets import get_decoder_debd
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import pytorch_lightning as pl
import numpy as np
import torch


# Train on a single GPU when CUDA is available, otherwise fall back to the CPU.
# `gpus` is the value later handed to pl.Trainer (None on the CPU path).
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

## Hyper-parameters

In [None]:
# Optimisation schedule.
max_epochs, batch_size = 100, 128

# Decoder network: latent input size and number of layers passed to get_decoder_debd.
latent_dim, n_layers = 16, 6

# True -> CLTBernoulliDecoder structured by a fitted BinaryCLT;
# False -> factorized BernoulliDecoder.
use_clt = False

# One full training run per entry; each value is the number of bins handed to
# GaussianQMCSampler for that run.
n_bins_list = [2 ** 14]
print(n_bins_list)

## Load datasets

In [None]:
# Load binarized MNIST via load_debd; the third returned split (presumably test) is unused here.
train, valid, _ = load_debd('binarized_mnist')
n_features = train.shape[1]  # number of variables = number of columns
print('Shape training: ', train.shape, ' Shape valid', valid.shape)

# Training batches are shuffled and a trailing incomplete batch is dropped;
# validation walks every sample in order.
loader_kwargs = {'batch_size': batch_size}
train_loader = DataLoader(train, shuffle=True, drop_last=True, **loader_kwargs)
valid_loader = DataLoader(valid, **loader_kwargs)
print('Length training loader: ', len(train_loader), ' Length valid loader:', len(valid_loader))

## Train

In [None]:
# Forwarded to the model as `model.n_chunks`.
# NOTE(review): presumably ContinuousMixture evaluates its bins in this many chunks, so a
# larger value lowers peak memory — raise it if you run out of memory (OOM). Confirm in models.cm.
n_chunks = 32
# Seed re-applied before every run so each n_bins setting starts from the same RNG state.
seed = 42
# TensorBoard logs are written under <repo_dir>/logs/bmnist/, in a sub-directory named
# after the decoder type.
log_root = os.path.join(repo_dir, 'logs', 'bmnist')
log_name = 'cm_clt/' if use_clt else 'cm_fact/'


def build_decoder(train_data, n_features, latent_dim, n_layers, use_clt):
    """Build the Bernoulli decoder for a continuous mixture.

    Args:
        train_data: training set; only used to fit the Chow-Liu tree when ``use_clt`` is True.
        n_features: number of variables per sample.
        latent_dim: input size of the decoder network.
        n_layers: number of layers passed to ``get_decoder_debd``.
        use_clt: if True, return a ``CLTBernoulliDecoder`` whose network emits
            ``2 * n_features`` outputs and whose structure is a ``BinaryCLT`` fitted on
            ``train_data`` (random root, ``alpha=0.01``); otherwise return a factorized
            ``BernoulliDecoder`` with ``n_features`` outputs.

    Returns:
        The decoder to pass to ``ContinuousMixture``.
    """
    if not use_clt:
        return BernoulliDecoder(
            net=get_decoder_debd(
                latent_dim=latent_dim,
                out_features=n_features,
                n_layers=n_layers,
                batch_norm=True)
        )

    scope = list(range(n_features))
    domain = [[0, 1]] * n_features  # every variable takes values in {0, 1}
    # The root is drawn from torch's RNG, so it is reproducible after seed_everything().
    clt = BinaryCLT(scope, root=torch.randint(n_features, (1,)).item())
    clt.fit(train_data, domain, alpha=0.01)
    return CLTBernoulliDecoder(
        net=get_decoder_debd(
            latent_dim=latent_dim,
            out_features=n_features * 2,
            n_layers=n_layers,
            batch_norm=True),
        tree=list(clt.tree)
    )


for n_bins in n_bins_list:
    # Seed first: CLT root choice, network init and the sampler all come after this.
    seed_everything(seed)
    decoder = build_decoder(train, n_features, latent_dim, n_layers, use_clt)
    model = ContinuousMixture(
        decoder=decoder,
        sampler=GaussianQMCSampler(latent_dim, n_bins)
    )
    model.n_chunks = n_chunks
    # NOTE(review): flag read by ContinuousMixture; presumably disables missing-value handling.
    model.missing = False

    # Keep only the single checkpoint with the lowest validation loss.
    cp_best_model_valid = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        monitor='valid_loss_epoch',
        mode='min',
        filename='best_model_valid-{epoch}'
    )
    # Stop after 15 validation checks without any decrease in validation loss.
    early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor="valid_loss_epoch",
        min_delta=0.00,
        patience=15,
        verbose=False,
        mode='min'
    )
    callbacks = [cp_best_model_valid, early_stop_callback]

    logger = pl.loggers.TensorBoardLogger(log_root, log_name)
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        callbacks=callbacks,
        logger=logger,
        deterministic=True
    )
    trainer.fit(model, train_loader, valid_loader)
    # Report where the best checkpoint of this run was written.
    print(f'n_bins={n_bins} best checkpoint: {cp_best_model_valid.best_model_path}')