In [12]:
from functools import partial
import os
from pathlib import Path
import pickle
from typing import Literal, NamedTuple

from gabenet.mcmc import sample_markov_chain
from gabenet.nets import MultinomialDirichletBelieve
from gabenet.utils import freeze_trainable_states, perplexity
import haiku as hk
import jax
from jax import random
import jax.numpy as jnp

from dataset import load_mutation_spectrum, COSMIC_WEIGHTS

In [13]:
# Parent of the notebook's working directory (`os`/`Path` come from the import cell;
# `Path.cwd()` is equivalent to `Path(os.path.abspath(""))`).
ARTEFACT_DIR = Path.cwd().parent
# Training hyperparameters.
# Seed for the pseudo-random number generator sequence.
RANDOM_SEED = 43
N_BURNIN = 200
LOG_EVERY = 200
N_SAMPLES = 2_000
CONTEXT = 96

# Print out training hyperparameters for logging.
print(f"CONTEXT = {CONTEXT}")
print(f"RANDOM_SEED = {RANDOM_SEED}")
print(f"N_BURNIN = {N_BURNIN}")
print(f"LOG_EVERY = {LOG_EVERY}")
print(f"N_SAMPLES = {N_SAMPLES}")

# Model hyperparameters.
MODEL: Literal[
    "multinomial_dirichlet_believe", "poisson_gamma_believe"
] = "multinomial_dirichlet_believe"
# The sampling code in this notebook always builds `MultinomialDirichletBelieve`
# and looks up the multinomial layer by name, so refuse other models instead of
# silently running the multinomial one.
if MODEL != "multinomial_dirichlet_believe":
    raise NotImplementedError(f"MODEL={MODEL!r} is not supported by this notebook.")
# Number of topics = number of rows of COSMIC_WEIGHTS (the reference signatures).
n_topics = len(COSMIC_WEIGHTS)
# Network config after pruning. The last entry is (presumably) the bottom layer's
# topic count, whose `phi` is replaced by COSMIC_WEIGHTS, so derive it from
# `n_topics` instead of hard-coding it (86 at time of writing; updated from 78 in 3.4).
HIDDEN_LAYER_SIZES = [41, n_topics]
GAMMA_0 = 10.0
# Haiku module path of the bottom (observation) layer for the selected MODEL.
_bottom_layer_name = (
    f"{MODEL}/~/multinomial_layer"
    if MODEL == "multinomial_dirichlet_believe"
    else f"{MODEL}/~/poisson_layer"
)
# Print out model hyperparameters for logging.
print(f"MODEL = {MODEL}")
print(f"n_topics = {n_topics}")
print(f"HIDDEN_LAYER_SIZES = {HIDDEN_LAYER_SIZES}")
print(f"GAMMA_0 = {GAMMA_0}")

X_train, X_test, n_features = load_mutation_spectrum(context=CONTEXT)

CONTEXT = 96
RANDOM_SEED = 43
N_BURNIN = 200
LOG_EVERY = 200
N_SAMPLES = 2000
MODEL = multinomial_dirichlet_believe
n_topics = 86
HIDDEN_LAYER_SIZES = [41, 86]
GAMMA_0 = 10.0


In [14]:
# Haiku PRNG key iterator, seeded from the shared RANDOM_SEED constant rather than
# a duplicated literal so the two seeds cannot drift apart.
key_seq = hk.PRNGSequence(RANDOM_SEED)

class TrainState(NamedTuple):
    """Immutable bundle carrying the sampler's state between steps.

    ``initialise`` builds the first instance as ``TrainState(params, state, key, 0)``.
    """

    # Haiku parameters (after ``initialise`` these hold the fixed ``phi``).
    params: hk.Params
    # Haiku state produced by ``kernel``.
    state: hk.State
    # JAX PRNG key for further random draws.
    key: jax.Array  # type: ignore
    # Step counter; 0 right after initialisation.
    step: int

@hk.transform_with_state
def kernel(X=X_train, freeze_phi=True):
    """Run one Gibbs sampling step of the belief network on ``X``.

    Args:
        X: Data the network is applied to; defaults to ``X_train`` as bound when
            this cell was executed.
        freeze_phi: If true, ``set_training(False)`` is called on the last entry of
            the network's ``layers.layers`` (presumably freezing its ``phi``).

    The function itself returns ``None``; presumably the sampled values are only
    observable through the Haiku state returned by ``kernel.apply`` — confirm.
    """
    network = MultinomialDirichletBelieve(
        HIDDEN_LAYER_SIZES, n_features, gamma_0=GAMMA_0
    )
    if freeze_phi:
        last_layer = network.layers.layers[-1]
        last_layer.set_training(False)
    # A single forward call performs the Gibbs step.
    network(X)
    
def probability(
    params,
    state,
    layer_name="multinomial_dirichlet_believe/~/multinomial_layer",
):
    """Reconstruct probabilities as ``theta @ phi`` for the bottom layer.

    Args:
        params: Haiku params mapping; ``phi`` is read from here when present
            (e.g. after it has been frozen into params).
        state: Haiku state mapping; must contain ``theta`` for ``layer_name`` and,
            if ``phi`` is not in params, ``phi`` as well.
        layer_name: Haiku module path of the bottom layer. Defaults to the
            multinomial layer of ``multinomial_dirichlet_believe``.

    Returns:
        ``theta @ phi`` (matrix product, broadcasting over any leading axes).

    Raises:
        KeyError: if ``layer_name`` is missing from ``state``, ``theta`` is missing,
            or ``phi`` is found in neither ``params`` nor ``state``.
    """
    bottom_params = params.get(layer_name, {})
    bottom_state = state[layer_name]
    # Prefer the (frozen) parameter; fall back to the sampled state value.
    if "phi" in bottom_params:
        phi = bottom_params["phi"]
    elif "phi" in bottom_state:
        phi = bottom_state["phi"]
    else:
        # Previously phi silently became None and `theta @ None` failed obscurely.
        raise KeyError(
            f"'phi' not found in params or state of layer {layer_name!r}."
        )
    theta = bottom_state["theta"]
    return theta @ phi

def initialise(key) -> TrainState:
    """Build the starting ``TrainState`` for the sampler.

    One chain per JAX device is initialised on ``X_train`` by vmapping
    ``kernel.init`` (with ``phi`` trainable) over per-chain keys. Then
    ``freeze_trainable_states`` (presumably moving ``phi`` from state into params)
    is applied, and ``phi`` is overwritten with ``COSMIC_WEIGHTS``.

    Args:
        key: JAX PRNG key; it is split and one half is kept in the result.

    Returns:
        ``TrainState`` with step counter 0.
    """
    layer = "multinomial_dirichlet_believe/~/multinomial_layer"
    carry_key, init_key = random.split(key)
    chain_keys = random.split(init_key, jax.device_count())
    init_chains = jax.vmap(
        partial(kernel.init, freeze_phi=False), in_axes=[0, None]
    )
    # The params returned by init are discarded; params are rebuilt from state.
    _, state = init_chains(chain_keys, X_train)
    params, state = freeze_trainable_states(state, variable_names=["phi"])
    # NOTE(review): unlike the vmapped state, these weights carry no leading chain
    # axis — presumably phi is shared by all chains; confirm against the apply call.
    params[layer]["phi"] = jnp.array(COSMIC_WEIGHTS)
    return TrainState(params=params, state=state, key=carry_key, step=0)

def evaluate(params, states, X, axis=(0, 1)):
    """Perplexity of ``X`` under probabilities averaged over ``axis``.

    Args:
        params: Haiku params passed to ``probability``.
        states: Haiku state passed to ``probability``.
        X: Data to score with ``perplexity``.
        axis: Axes of ``probability(params, states)`` to average over before scoring.
            Defaults to ``(0, 1)``, meant as the chain and sample axes. A list is
            still accepted for backward compatibility.

    Returns:
        ``perplexity(X, probs)`` where ``probs`` is the averaged probability array.
    """
    # A tuple is an immutable (safe) default and is accepted by both JAX and NumPy
    # reductions; NumPy's `.mean` rejects a list `axis`.
    if isinstance(axis, list):
        axis = tuple(axis)
    probs = probability(params, states).mean(axis)
    return perplexity(X, probs)

In [None]:
# Root PRNG key from the configured seed (`random` is the imported `jax.random`),
# then the initial sampler state built from it.
key = random.PRNGKey(RANDOM_SEED)
train_state = initialise(key)