# Variational Inference for Concept Embeddings (VICE)

## Install dependencies

In [None]:
!pip install -r './requirements.txt'

In [None]:
import argparse
import os
import random
import torch
import utils
import visualization

import numpy as np

from model.vice import VICE
from typing import Tuple

## I/O

In [None]:
# specify input and output directories (placeholders; replace with your own paths)
triplets_dir = './path/to/triplets/'
results_dir = './path/to/results/'
# trailing separator added so that all three directories follow the same convention
plots_dir = './path/to/plots/'

### Variables and hyperparameters

In [None]:
# --- task and data ---
task = 'odd_one_out'  # 3AFC
modality = 'behavioral'

# --- optimization ---
optim = 'adam'
eta = 1e-3  # learning rate used in optimizer
batch_size = 128  # use a power of 2 if you intend to perform model training on a GPU, else see what works best
epochs = 1000  # maximum number of epochs
burnin = 500  # minimum number of epochs
mc_samples = 10  # presumably the number of Monte Carlo samples drawn by VICE -- TODO confirm
steps = 50  # handed to both VICE and the performance plot; presumably an epoch interval -- TODO confirm

# --- model and prior ---
latent_dim = 50  # initial latent dimensionality of VICE
prior = 'gaussian'  # spike-and-slab Gaussian mixture prior
spike = 0.1  # sigma_spike
slab = 1.0  # sigma_slab
pi = 0.5  # presumably the mixture weight of the spike-and-slab prior -- TODO confirm

# --- pruning ---
k = 5  # minimum number of items that compose a latent dimension (according to importance scores)
ws = 200  # window size: for how many epochs the number of latent causes (after pruning) is not allowed to vary

# --- reproducibility and logging ---
seed = 42
verbose = True

In [None]:
# seed every random number generator in use (Python, NumPy, PyTorch)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# pick the compute device: CUDA if it is available, CPU otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # additionally seed the generators of all CUDA devices
    torch.cuda.manual_seed_all(seed)
else:
    # number of threads used for intraop parallelism on CPU; only set when the device is CPU
    num_threads = 8
    torch.set_num_threads(num_threads)
    os.environ['OMP_NUM_THREADS'] = str(num_threads)

### Create mini-batches of train and test triplets

In [None]:
# load the train and test triplets into memory
train_triplets, test_triplets = utils.load_data(
    device=device,
    triplets_dir=triplets_dir,
)

# number of trials (first axis of the train triplets) and number of unique items in the data
N = train_triplets.shape[0]
n_items = utils.get_nitems(train_triplets)

# build the mini-batches for training; the second return value is used as validation batches
train_batches, val_batches = utils.load_batches(
    train_triplets=train_triplets,
    test_triplets=test_triplets,
    n_items=n_items,
    batch_size=batch_size,
    inference=False,
)

### Create directories

In [None]:
# helper to create directories for storing results
def create_dirs(
    results_dir: str,
    plots_dir: str,
    modality: str,
    latent_dim: int,
    optim: str,
    prior: str,
    spike: float,
    slab: float,
    pi: float,
    rnd_seed: int,
) -> Tuple[str, str, str]:
    """Create the results and plots directories and derive the model directory path.

    If ``results_dir`` is exactly ``'./results/'`` (respectively ``plots_dir`` is exactly
    ``'./plots/'``), a run-specific hierarchy
    ``modality/<latent_dim>d/optim/prior/spike/slab/pi/seed<NN>`` is appended to it.
    Any other path is used unchanged.

    Args:
        results_dir: Directory for results (created if missing).
        plots_dir: Directory for plots (created if missing).
        modality, latent_dim, optim, prior, spike, slab, pi, rnd_seed: Run settings
            that name the sub-directories appended to the default paths.

    Returns:
        Tuple ``(results_dir, plots_dir, model_dir)``. ``model_dir`` is
        ``<results_dir>/model``; its path is only computed here, the directory
        itself is not created by this function.

    Raises:
        FileExistsError: If one of the two target paths exists but is not a directory.
    """
    print('\n...Creating directories.\n')

    def nest_if_default(base_dir: str, default_dir: str) -> str:
        # Only the default locations get the run-specific hierarchy; the
        # sub-directory names are built lazily, i.e. only when they are needed.
        if base_dir != default_dir:
            return base_dir
        return os.path.join(
            base_dir,
            modality,
            f'{latent_dim}d',
            optim,
            prior,
            str(spike),
            str(slab),
            str(pi),
            f'seed{rnd_seed:02d}',
        )

    results_dir = nest_if_default(results_dir, './results/')
    plots_dir = nest_if_default(plots_dir, './plots/')

    # exist_ok=True already covers the "directory is there" case, so no prior
    # os.path.exists check is needed (it would only add a race window).
    for directory in (results_dir, plots_dir):
        os.makedirs(directory, exist_ok=True)

    model_dir = os.path.join(results_dir, 'model')
    return results_dir, plots_dir, model_dir

In [None]:
# create the output directories and rebind the paths to what create_dirs returns
results_dir, plots_dir, model_dir = create_dirs(
    results_dir=results_dir,
    plots_dir=plots_dir,
    modality=modality,
    latent_dim=latent_dim,
    optim=optim,
    prior=prior,
    spike=spike,
    slab=slab,
    pi=pi,
    rnd_seed=seed,
)

## VICE optimization

In [None]:
# set up the VICE model with the variables and hyperparameters defined above
vice = VICE(
    task=task,
    n_train=N,
    n_items=n_items,
    latent_dim=latent_dim,
    optim=optim,
    eta=eta,
    batch_size=batch_size,
    epochs=epochs,
    burnin=burnin,
    mc_samples=mc_samples,
    prior=prior,
    spike=spike,
    slab=slab,
    pi=pi,
    k=k,
    ws=ws,
    steps=steps,
    model_dir=model_dir,
    results_dir=results_dir,
    device=device,
    verbose=verbose,
    init_weights=True,
)

# place the model on the selected device
vice.to(device)

# run the VICE optimization on the train batches, with the validation batches alongside
vice.fit(train_batches=train_batches, val_batches=val_batches)

In [None]:
# read the training history from the model: train/validation prediction accuracies,
# log-likelihoods, complexity losses (KLDs), and the latent dimensions over time
train_accs, val_accs = vice.train_accs, vice.val_accs
loglikelihoods, complexity_losses = vice.loglikelihoods, vice.complexity_losses
latent_dimensions = vice.latent_dimensions

### Performance plots

In [None]:
# compare train and validation (prediction) accuracies to see whether the model overfits the training data
visualization.plot_single_performance(
    plots_dir=plots_dir,
    val_accs=val_accs,
    train_accs=train_accs,
    steps=steps,
    show_plot=True,
)

In [None]:
# show how the complexity losses (KLDs) and the log-likelihoods evolve over time
visualization.plot_complexities_and_loglikelihoods(
    plots_dir=plots_dir,
    loglikelihoods=loglikelihoods,
    complexity_losses=complexity_losses,
    show_plot=True,
)

In [None]:
# show the number of latent causes (dimensions selected after pruning) over time
visualization.plot_latent_causes(
    plots_dir=plots_dir,
    latent_causes=latent_dimensions,
    show_plot=True,
)

### Save unpruned locations and scales after convergence

In [None]:
# fetch the detached parameters: locations (means) and scales (standard deviations)
params = vice.detached_params
W_loc, W_scale = params['loc'], params['scale']

In [None]:
# store both parameter arrays in one compressed binary archive (parameters.npz) inside results_dir
with open(os.path.join(results_dir, 'parameters.npz'), mode='wb') as f:
    np.savez_compressed(f, W_loc=W_loc, W_scale=W_scale)

### Load pruned locations and scales of converged VICE model for downstream applications

In [None]:
# load pruned VICE params
# NOTE(review): pruned_params.npz must already exist in results_dir; presumably it is
# written during vice.fit -- confirm.
# np.load returns an NpzFile that keeps the archive open and reads arrays on access;
# the context manager closes it once both arrays have been read into memory.
# After this cell, pruned_params refers to the closed archive -- use the two arrays below.
with np.load(os.path.join(results_dir, 'pruned_params.npz')) as pruned_params:
    pruned_loc = pruned_params['pruned_loc']
    pruned_scale = pruned_params['pruned_scale']