In [None]:
import sys
import os
# Make the project-local packages (`utils`, `models`) importable by adding the
# parent of the current working directory to the module search path.
# Assumes the kernel's working directory is a direct subdirectory of the
# repository root — TODO confirm if the notebook is moved.
repo_dir = os.path.dirname(os.getcwd())
# Guard the append so that re-running this cell does not keep adding
# duplicate entries to sys.path.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
from utils.reproducibility import seed_everything
from torch.utils.data import DataLoader
from utils.datasets import load_debd
from models.vae import VAE, DebdVAE
import pytorch_lightning as pl
import torch

# Decide once where training runs. `gpus` is what the training loop passes to
# `pl.Trainer(gpus=...)`: None means CPU-only, 1 requests a single GPU.
has_cuda = torch.cuda.is_available()
if has_cuda:
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device, gpus)

### Specify the DEBD datasets to train on

In [None]:
# DEBD benchmark datasets, in the order they are trained; each name is passed
# to `load_debd` in the training loop.
DEBD_DATASETS = [
    'nltcs', 'msnbc', 'kdd', 'plants', 'baudio',
    'jester', 'bnetflix', 'accidents', 'tretail', 'pumsb_star',
    'dna', 'kosarek', 'msweb', 'book', 'tmovie',
    'cwebkb', 'cr52', 'c20ng', 'bbc', 'ad',
]
print(DEBD_DATASETS)

### Specify the random seeds below: each seed is a separate training run for every dataset

In [None]:
# Five independent runs per dataset, seeded 0 through 4.
seeds = list(range(5))

In [None]:
# Hyperparameters shared by every (dataset, seed) training run.
# batch_size: DataLoader batch size for both train and validation loaders.
# max_epochs: hard cap on epochs; early stopping may end a run sooner.
batch_size, max_epochs = 128, 200
# Passed to DebdVAE(n_layers=..., latent_dim=...).
n_layers, latent_dim = 6, 4

## Train

In [None]:
# Train one VAE per (dataset, seed) pair. TensorBoard logs and checkpoints for
# each dataset go under <repo_dir>/logs/debd/vae/<dataset_name>/.
for dataset_name in DEBD_DATASETS:

    # The third split returned by load_debd (presumably the test split) is not
    # used for training.
    train, valid, _ = load_debd(dataset_name)
    # drop_last=True discards the final partial training batch (presumably to
    # avoid very small batches with batch_norm=True — TODO confirm).
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False, drop_last=False)
    print(dataset_name, train.shape, valid.shape, len(train_loader), len(valid_loader))

    for seed in seeds:

        # Seed before the model is built so each run is reproducible
        # (assumes seed_everything seeds torch's global RNG — TODO confirm).
        seed_everything(seed)
        model = VAE(
            vae=DebdVAE(
                n_features=train.shape[1],
                latent_dim=latent_dim,
                batch_norm=True,
                n_layers=n_layers),
            # reduction='none' yields per-element losses; the reduction is
            # presumably performed inside VAE — TODO confirm.
            recon_loss=torch.nn.BCELoss(reduction='none')
        )

        # Keep only the checkpoint with the lowest validation loss
        # (assumes VAE logs a metric named 'valid_loss_epoch').
        cp_best_model_valid = pl.callbacks.ModelCheckpoint(
            save_top_k=1,
            monitor='valid_loss_epoch',
            mode='min',
            filename='best_model_valid-{epoch}'
        )
        # Stop once the validation loss has not improved for 15 validation checks.
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='valid_loss_epoch',
            min_delta=0.00,
            patience=15,
            verbose=False,
            mode='min'
        )
        callbacks = [cp_best_model_valid, early_stop_callback]

        logger = pl.loggers.TensorBoardLogger(
            os.path.join(repo_dir, 'logs', 'debd', 'vae'), name=dataset_name)
        # With no explicit `version`, TensorBoardLogger uses the next free
        # version_N directory, so N alone does not tell which seed produced a
        # run (e.g. after an interrupted re-run). Record the seed explicitly.
        logger.log_hyperparams({'seed': seed})
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            logger=logger,
            deterministic=True
        )
        trainer.fit(model, train_loader, valid_loader)