In [1]:
import warnings

warnings.filterwarnings("ignore")

import numpy as np

np.set_printoptions(precision=2, suppress=True)

import jax
from jax import numpy as jnp, random as jr

from inference import make_gradfun, sgd
from hmm.hmm import init_pgm_param, run_inference, rollout, gumbel_softmax
from hmm.network import init_mlp, make_loglike, relu, logits, gaussian_mean, identity

In [2]:
def generate_hmm_data(key, num_samples, time_steps, data_dim=2, num_states=2):
    A = jnp.array([[0.7, 0.3], [0.3, 0.7]]) 
    pi = jnp.full((num_states,), 1.0 / num_states) 

    emission_means = jnp.eye(num_states, data_dim) 
    log_A = jax.nn.log_softmax(A, axis=-1)  

    def sample_sequence(key, log_A, emission_means, time_steps):
        states, observations = [], []
        observations = []

        key, subkey = jr.split(key)
        state = jr.categorical(subkey, jnp.log(pi))
        states.append(state)
        observations.append(emission_means[state])

        for _ in range(1, time_steps):
            key, subkey = jr.split(key)
            state = jr.categorical(subkey, log_A[state])
            states.append(state)
            observations.append(emission_means[state])

        return jnp.array(states), jnp.stack(observations)

    keys = jr.split(key, num_samples)
    sample_sequence_vmap = jax.vmap(sample_sequence, in_axes=(0, None, None, None), out_axes=(0, 0))
    states, observations = sample_sequence_vmap(keys, log_A, emission_means, time_steps)
    return observations, states

In [3]:
def test(i, elbo, params, grad, print_every=10):
    if i % print_every == 0:
        print(f"epoch {i}: {elbo}")
        print(params[0])
        pgm_params, decoder_params, encoder_params = params

        idx = np.random.choice(range(batch_size))
        test_data = data[idx: idx + 1, :, :]

        rollout_key, local_key = jr.split(key)
        node_potentials = encoder(encoder_params, test_data)
        samples, _, _, _ = run_inference(rollout_key, pgm_prior_params, pgm_params, node_potentials, 20)
        decode_mean, _ = decoder(decoder_params, samples)

        print(samples[0])
        print(decode_mean[0])
        print(test_data)


key = jr.PRNGKey(0)

batch_size = 64
T, N, D = 4, 2, 2

data_key, key = jr.split(key)
data, _ = generate_hmm_data(data_key, batch_size * 100, T, D, N)

prior_key, key = jr.split(key)
pgm_prior_params = init_pgm_param(prior_key, N, alpha=0.0)

encoder_key, key = jr.split(key)
encoder, encoder_params = init_mlp(encoder_key, D, [(N, logits)])

decoder_key, key = jr.split(key)
decoder, decoder_params = init_mlp(decoder_key, N, [(2 * D, gaussian_mean)])
loglike = make_loglike(decoder)

pgm_key, key = jr.split(key)
pgm_params = init_pgm_param(pgm_key, N, alpha=0.0)
params = pgm_params, decoder_params, encoder_params

grad_key, key = jr.split(key)
gradfun = make_gradfun(grad_key, run_inference, encoder, loglike, pgm_prior_params, data, batch_size, 50, 1, test)

sgd_key, key = jr.split(key)
params = sgd(sgd_key, gradfun, params, 1000, 1e-3)

epoch 0: 435.7969665527344
[[0.66 0.01]
 [0.02 0.57]]
[[[0.3  0.7 ]
  [0.69 0.31]
  [0.44 0.56]
  [0.68 0.32]]]
[[[-0.01  0.02]
  [-0.01  0.02]
  [-0.01  0.02]
  [-0.01  0.02]]]
[[[1. 0.]
  [0. 1.]
  [0. 1.]
  [0. 1.]]]
epoch 10: 316.8983154296875
[[0.67 0.01]
 [0.02 0.58]]
[[[0.3  0.7 ]
  [0.68 0.32]
  [0.44 0.56]
  [0.68 0.32]]]
[[[0.52 0.48]
  [0.52 0.48]
  [0.52 0.48]
  [0.52 0.48]]]
[[[1. 0.]
  [0. 1.]
  [0. 1.]
  [1. 0.]]]
epoch 20: 295.2229309082031
[[0.68 0.02]
 [0.03 0.59]]
[[[0.3  0.7 ]
  [0.68 0.32]
  [0.43 0.57]
  [0.68 0.32]]]
[[[0.54 0.46]
  [0.54 0.46]
  [0.54 0.46]
  [0.54 0.46]]]
[[[0. 1.]
  [0. 1.]
  [0. 1.]
  [0. 1.]]]
epoch 30: 291.65740966796875
[[0.69 0.02]
 [0.03 0.6 ]]
[[[0.3  0.7 ]
  [0.68 0.32]
  [0.44 0.56]
  [0.69 0.31]]]
[[[0.47 0.53]
  [0.48 0.53]
  [0.47 0.53]
  [0.48 0.53]]]
[[[0. 1.]
  [0. 1.]
  [1. 0.]
  [1. 0.]]]
epoch 40: 290.22271728515625
[[0.7  0.02]
 [0.04 0.61]]
[[[0.29 0.71]
  [0.69 0.31]
  [0.43 0.57]
  [0.69 0.31]]]
[[[0.5  0.5 ]
  [0.51 0.49