In [None]:
import sys
import os
# The notebook is expected to run from a direct sub-folder of the repository:
# the parent of the working directory is taken as the repo root and appended
# to sys.path so the project packages (`models`, `utils`) can be imported.
repo_dir = os.path.split(os.getcwd())[0]
sys.path.append(repo_dir)

In [None]:
from models.cm import BernoulliDecoder, CLTBernoulliDecoder, ContinuousMixture
from utils.bins_samplers import GaussianQMCSampler
from deeprob.spn.structure.cltree import BinaryCLT
from utils.reproducibility import seed_everything
from models.nets import get_decoder_debd
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import pytorch_lightning as pl
import torch


# Use a single GPU when CUDA is available, otherwise fall back to the CPU.
# `gpus` is forwarded to `pl.Trainer(gpus=...)`; None means no GPU.
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

## Specify the datasets to train on

In [None]:
# Binary benchmark datasets, each loaded by name via `load_debd`.
# Remove entries from the list to train on a subset only.
DEBD_DATASETS = [
    'nltcs', 'msnbc', 'kdd', 'plants', 'baudio',
    'jester', 'bnetflix', 'accidents', 'tretail', 'pumsb_star',
    'dna', 'kosarek', 'msweb', 'book', 'tmovie',
    'cwebkb', 'cr52', 'c20ng', 'bbc', 'ad',
]
print(DEBD_DATASETS)

## Specify the random seeds below: each seed is a separate training run

In [None]:
# One independent training run per seed, repeated for every dataset.
seeds = list(range(5))
print(seeds)

## Hyper-parameters (shared by every dataset and seed)

In [None]:
batch_size = 128   # mini-batch size of both the training and validation DataLoaders
n_layers = 6       # depth argument of the decoder network (get_decoder_debd)
latent_dim = 4     # latent dimension fed to the decoder and to the QMC sampler
n_bins = 2 ** 10   # GaussianQMCSampler's second argument (presumably the number of latent points)
max_epochs = 200   # upper bound only: early stopping may end training sooner
use_clt = True     # True -> CLTBernoulliDecoder on a Chow-Liu tree; False -> plain BernoulliDecoder

## Train one model per (dataset, seed) pair

In [None]:
# If you run out of memory (OOM), set n_chunks (e.g. n_chunks = 64).
# NOTE(review): both values are only stored as attributes on the model below;
# their exact effect is defined in models.cm.ContinuousMixture.
n_chunks = None
missing = False

# Loop-invariant: every run logs under <repo>/logs/debd/{cm_clt|cm_fact}/<dataset_name>.
log_dir = os.path.join(repo_dir, 'logs', 'debd', 'cm_clt' if use_clt else 'cm_fact')


def build_decoder(train_data, n_features, use_clt, latent_dim, n_layers):
    """Build the decoder of the continuous mixture for a single run.

    With ``use_clt`` a BinaryCLT is first fitted on ``train_data`` (rooted at a
    feature drawn with ``torch.randint``) and wrapped in a CLTBernoulliDecoder
    whose network has ``2 * n_features`` outputs; otherwise a BernoulliDecoder
    with ``n_features`` outputs is returned.

    Call it right after seeding: it draws from torch's global RNG (the CLT
    root, and presumably the network's weight initialisation), so moving the
    call changes the run.
    """
    if use_clt:
        clt = BinaryCLT(list(range(n_features)), root=torch.randint(n_features, (1,)).item())
        # Every feature is binary: domain {0, 1} for each of them.
        clt.fit(train_data, [[0, 1]] * n_features, alpha=0.01)
        net = get_decoder_debd(
            latent_dim=latent_dim,
            out_features=n_features * 2,
            n_layers=n_layers,
            batch_norm=True)
        return CLTBernoulliDecoder(net=net, tree=list(clt.tree))
    net = get_decoder_debd(
        latent_dim=latent_dim,
        out_features=n_features,
        n_layers=n_layers,
        batch_norm=True)
    return BernoulliDecoder(net=net)


def build_callbacks():
    """Return fresh [best-checkpoint, early-stopping] callbacks for one run.

    Both monitor 'valid_loss_epoch' (lower is better): the checkpoint keeps only
    the single best model, and early stopping fires after 15 validation checks
    without improvement.
    """
    best_checkpoint = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        monitor='valid_loss_epoch',
        mode='min',
        filename='best_model_valid-{epoch}')
    early_stopping = pl.callbacks.early_stopping.EarlyStopping(
        monitor='valid_loss_epoch',
        min_delta=0.00,
        patience=15,
        verbose=False,
        mode='min')
    return [best_checkpoint, early_stopping]


for dataset_name in DEBD_DATASETS:
    # The third (test) split is not used for training.
    train, valid, _ = load_debd(dataset_name)
    n_features = train.shape[1]
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False, drop_last=False)
    print(dataset_name, train.shape, valid.shape, len(train_loader), len(valid_loader))

    for seed in seeds:
        # Seed first: everything random in this run (CLT root, decoder, sampler,
        # shuffling) is drawn after this point.
        seed_everything(seed)

        decoder = build_decoder(train, n_features, use_clt, latent_dim, n_layers)
        model = ContinuousMixture(
            decoder=decoder,
            sampler=GaussianQMCSampler(latent_dim, n_bins))
        model.n_chunks = n_chunks
        model.missing = missing

        callbacks = build_callbacks()
        # NOTE(review): no explicit logger version is given, so runs for different
        # seeds are distinguished only by Lightning's automatic version folders;
        # consider recording the seed. `gpus=` is the legacy Trainer argument
        # (newer Lightning releases use accelerator/devices).
        logger = pl.loggers.TensorBoardLogger(log_dir, name=dataset_name)
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            logger=logger,
            deterministic=True)
        trainer.fit(model, train_loader, valid_loader)