In [1]:
from scipy import io
from pathlib import Path

import numpy as np
# rng = np.random.default_rng()

import jax.numpy as jnp
import jax.random as jr

import optax

from collections import defaultdict
import pickle as pkl

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from dynamax.hidden_markov_model import ET_HMM, E_HMM, T_HMM
from dynamax.utils.plotting import gradient_cmap

I0000 00:00:1725645768.220247  286664 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Read the cached, calibrated experiment arrays from disk.
# The last axis of the design matrix holds one regressor per column:
#   - column 0: Stimulus (calibrated)
#   - column 1: Coherence
#   - column 2: Attention
#   - column 3: Expectation (calibrated)
design_matrix, observation = (
    np.load(cache_file)
    for cache_file in ("../analysis/cache/exp_design.npy", "../analysis/cache/exp_obs.npy")
)

# Show both shapes as the cell output
design_matrix.shape, observation.shape

((260, 120, 4), (260, 120, 1))

In [3]:
# Convert the angular quantities from degrees to radians and append a constant
# (all-ones) bias regressor as the last column of the design matrix.
#
# The whole step is guarded on the column count: the raw design matrix has 4
# regressors and the processed one has 5. Re-running this cell on an already
# processed matrix is therefore a no-op, instead of converting to radians a
# second time and stacking another bias column.
N_RAW_REGRESSORS = 4

if design_matrix.shape[-1] == N_RAW_REGRESSORS:
    # Convert to radians: column 0 (Stimulus) and column 3 (Expectation)
    design_matrix[:, :, 0] = np.deg2rad(design_matrix[:, :, 0])
    design_matrix[:, :, 3] = np.deg2rad(design_matrix[:, :, 3])

    observation = np.deg2rad(observation)

    # Add flat biases column
    flat_biases = np.ones_like(design_matrix[:, :, :1])
    design_matrix = np.concatenate([design_matrix, flat_biases], axis=-1)
elif design_matrix.shape[-1] != N_RAW_REGRESSORS + 1:
    # Neither raw nor already processed: refuse to guess what the columns mean
    raise ValueError(
        f"Unexpected number of design-matrix columns: {design_matrix.shape[-1]} "
        f"(expected {N_RAW_REGRESSORS} raw or {N_RAW_REGRESSORS + 1} processed)"
    )

In [4]:
# Shuffle data along the first axis.
# A seeded generator is used so the permutation (and therefore the train/test
# split and every fitted model that depends on it) is identical on each
# Restart & Run All.
# NOTE: run this cell once per kernel session -- re-running it permutes the
# already shuffled arrays again.
SHUFFLE_SEED = 0
shuff_idx = np.random.default_rng(SHUFFLE_SEED).permutation(len(design_matrix))

# Apply the same permutation to both arrays so inputs and observations stay aligned
design_matrix = design_matrix[shuff_idx]
observation = observation[shuff_idx]

In [5]:
# Hold out the trailing 20% of the first axis for testing; the leading 80% is the training set
train_idx = int(0.8 * len(design_matrix))

# Slice inputs and observations at the same boundary so they stay paired
train_design, test_design = design_matrix[:train_idx], design_matrix[train_idx:]
train_obs, test_obs = observation[:train_idx], observation[train_idx:]

In [6]:
# Wrap the numpy training split as jax arrays for model fitting
train_inputs, train_emissions = jnp.array(train_design), jnp.array(train_obs)

In [7]:
# A training cache to store all the models and their parameters.
# Missing first-level keys are created on demand as empty dicts.
training_cache = defaultdict(dict)

# DS to store the models organized
class Model_Store:
    """Container for one trained model and the results of fitting it.

    Every field defaults to ``None`` so that a freshly created store can be
    inspected (or partially filled) without raising ``AttributeError``.
    Fields may be passed to the constructor or assigned afterwards.

    Attributes:
        n_states: number of hidden states the model was built with.
        fit_model: the model object that was trained.
        fit_params: the fitted parameters returned by training.
        lps: the per-epoch values returned by training.
    """

    def __init__(self, n_states=None, fit_model=None, fit_params=None, lps=None):
        self.n_states = n_states
        self.fit_model = fit_model
        self.fit_params = fit_params
        self.lps = lps

In [8]:
def train_hmms(model_class, num_states, emission_dim, input_dim, train_emissions, train_inputs,
               num_epochs=5000, learning_rate=1e-4, batch_size=32, seed=1, optimizer=None):
    """Build, initialise and fit one HMM with stochastic gradient descent.

    The training hyper-parameters are exposed as keyword arguments; their
    defaults reproduce the previously hard-coded settings, so existing calls
    behave exactly as before.

    Args:
        model_class: HMM class to instantiate. It is called as
            ``model_class(num_states, input_dim, emission_dim)``.
        num_states: number of hidden states.
        emission_dim: dimensionality of the emissions.
        input_dim: dimensionality of the inputs (design-matrix columns).
        train_emissions: emissions passed to ``fit_sgd``.
        train_inputs: inputs passed to ``fit_sgd``.
        num_epochs: number of SGD epochs.
        learning_rate: Adam learning rate; ignored when ``optimizer`` is given.
        batch_size: mini-batch size passed to ``fit_sgd``.
        seed: integer seed for the PRNG key used to initialise the parameters.
            Vary it to obtain different random initialisations.
        optimizer: optional ready-made optimizer to use instead of
            ``optax.adam(learning_rate)``.

    Returns:
        Tuple ``(model, fit_params, lps)``: the model instance, the fitted
        parameters, and the per-epoch values returned by ``fit_sgd``.
        NOTE(review): the recorded values decrease over epochs, so ``lps`` is
        presumably a loss rather than a log-probability -- confirm against
        ``fit_sgd``.
    """
    # Note the argument order: the model constructor takes input_dim before emission_dim
    model = model_class(num_states, input_dim, emission_dim)
    parameters, properties = model.initialize(key=jr.PRNGKey(seed))

    # Only build the default optimizer when the caller did not supply one
    if optimizer is None:
        optimizer = optax.adam(learning_rate)

    # Fit with SGD
    fit_params, lps = model.fit_sgd(params=parameters,
                                    props=properties,
                                    emissions=train_emissions,
                                    inputs=train_inputs,
                                    num_epochs=num_epochs,
                                    optimizer=optimizer,
                                    shuffle=False,
                                    batch_size=batch_size)

    return model, fit_params, lps

In [9]:
# for i in range(10):
#     model, fit_params, lps = train_hmms(E_HMM, 2, 1, 5, train_emissions, train_inputs)
#     print(f"Model {i} trained - {lps[:10]}")

In [10]:
# Sweep settings: inclusive range of hidden-state counts, and the model dimensions
min_state, max_state = 3, 3
emission_dim, input_dim = 1, 5

# Fit every model class for every state count and file the result in the cache
for model_class in (ET_HMM, E_HMM, T_HMM):
    for num_states in range(min_state, max_state + 1):
        print(f'Training: {model_class.__name__} - {num_states} states')

        model, fit_params, lps = train_hmms(
            model_class,
            num_states,
            emission_dim,
            input_dim,
            train_emissions,
            train_inputs,
        )

        # Quick look at the first ten recorded training values
        print(f"{lps[:10]}")

        # Bundle the fitted model with its results
        t_store = Model_Store()
        t_store.lps = lps
        t_store.fit_params = fit_params
        t_store.fit_model = model
        t_store.n_states = num_states

        # Cache layout: training_cache[<model class name>][<num_states>]
        training_cache[model_class.__name__][num_states] = t_store


Training: ET_HMM - 3 states
[2.0024195 1.9983617 1.9943178 1.9902885 1.9862753 1.982278  1.9782981
 1.9743353 1.9703901 1.9664623]
Training: E_HMM - 3 states
[2.0014584 1.9973913 1.9933379 1.9892995 1.9852773 1.9812715 1.9772825
 1.9733113 1.9693574 1.9654208]
Training: T_HMM - 3 states
[2.005565  2.004551  2.0035405 2.0025337 2.0015306 2.0005314 1.9995363
 1.9985449 1.9975579 1.9965748]


In [12]:
# Look up the cached 3-state ET_HMM entry and display its fitted parameters
et_hmm_3_state_store = training_cache["ET_HMM"][3]
et_hmm_3_state_store.fit_params

ParamsET_HMM(initial=ParamsStandardHMMInitialState(probs=Array([0.20569214, 0.33288547, 0.46142238], dtype=float32)), transitions=ParamsET_Transitions(transition_matrix=Array([[0.15079704, 0.7906471 , 0.05855593],
       [0.02953597, 0.7138907 , 0.2565734 ],
       [0.01489414, 0.9704076 , 0.01469822]], dtype=float32), transition_weights=Array([[-0.00295883, -0.02011733,  0.00127638,  0.00208977, -0.01400751],
       [-0.00447574,  0.00723497,  0.00549877, -0.01346807,  0.00572236],
       [ 0.005439  , -0.00175116, -0.00196063, -0.00147497,  0.00494916]],      dtype=float32)), emissions=ParamsET_Emissions(weights=Array([[[ 0.9441166 ,  0.00499397,  0.02183866,  0.05175149,
         -0.04561728]],

       [[-0.03443658, -0.0204086 ,  0.04581711,  0.02254143,
          0.9817551 ]],

       [[-0.02108398,  0.05295916, -0.00896955,  0.00308305,
          2.3185802 ]]], dtype=float32), covs=Array([[[43.99967  ]],

       [[ 4.0123854]],

       [[12.004267 ]]], dtype=float32)))

In [13]:
# training_cache["shuffle_idx"] = shuff_idx

# # Save the trained data so we don't have to train again and again
# with open('training_cache_post.pkl', 'wb') as f:
#     pkl.dump(training_cache, f)