In [11]:
from scipy import io
from pathlib import Path

import numpy as np
# rng = np.random.default_rng()

import jax.numpy as jnp
import jax.random as jr

import optax

from collections import defaultdict
import pickle as pkl

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from dynamax.hidden_markov_model import LinearRegressionHMM
from dynamax.utils.plotting import gradient_cmap

In [12]:
# Load the calibrated data cache.
# Rows along the first axis of `data`, in order:
#   0: final_resp_loc_jx
#   1: final_resp_loc_jy
#   2: final_resp_vel_jx
#   3: final_resp_vel_jy
#   4: final_resp_acc_jx
#   5: final_resp_acc_jy
#   6: stimulus_x
#   7: stimulus_y
#   8: coh
#   9: att

cache_file = Path("../analysis/cache/total_data.npy")
data = np.load(cache_file)

In [13]:
# Inspect the dimensions of the loaded array; the first axis indexes the variables listed above.
data.shape

(10, 13, 4, 6, 120)

In [14]:
# Pull the `coh` (row 8) and `att` (row 9) arrays out of `data` and reshape
# each of them to (13, 24, 120).
coh, att = (data[row].reshape(13, 24, 120) for row in (8, 9))

coh.shape, att.shape

((13, 24, 120), (13, 24, 120))

In [15]:
# Build the emission tensor and the design matrix from the raw rows of `data`.
resp_x, resp_y = data[0], data[1]
# Magnitudes of the (x, y) velocity and acceleration components
total_vel = np.sqrt(data[2]**2 + data[3]**2)
total_acc = np.sqrt(data[4]**2 + data[5]**2)

stim_x, stim_y = data[6], data[7]
flat_biases = np.ones_like(stim_y)  # constant column of ones (bias regressor)

# Stack the per-variable arrays on a new trailing axis, then flatten to (13, 24, 120, n_features)
emissions = np.stack((resp_x, resp_y, total_vel, total_acc), axis=-1).reshape(13, 24, 120, 4)
design_matrix = np.stack((stim_x, stim_y, flat_biases), axis=-1).reshape(13, 24, 120, 3)

# Boolean masks picking the samples with coh == 0 and coh == 1.
# The reshapes below rely on each mask selecting 60 samples per (13, 24) entry
# -- TODO confirm this holds for every entry, not only in total.
is_coh_0 = coh == 0
is_coh_1 = coh == 1

emissions_coh_0 = emissions[is_coh_0].reshape(13, 24, 60, 4)
design_matrix_coh_0 = design_matrix[is_coh_0].reshape(13, 24, 60, 3)

emissions_coh_1 = emissions[is_coh_1].reshape(13, 24, 60, 4)
design_matrix_coh_1 = design_matrix[is_coh_1].reshape(13, 24, 60, 3)

# New leading axis: index 0 holds the coh == 0 data, index 1 the coh == 1 data
emissions = np.stack((emissions_coh_0, emissions_coh_1), axis=0)
design_matrix = np.stack((design_matrix_coh_0, design_matrix_coh_1), axis=0)

emissions.shape, design_matrix.shape

((2, 13, 24, 60, 4), (2, 13, 24, 60, 3))

In [16]:
def train_hmms(model_class, num_states, emission_dim, input_dim, train_emissions, train_inputs,
               num_epochs=5000, learning_rate=1e-4, batch_size=8, shuffle=True, seed=1):
    """Build a `model_class` HMM and fit it with SGD (Adam).

    Args:
        model_class: HMM class; called as `model_class(num_states, input_dim, emission_dim)`.
        num_states: number of discrete hidden states.
        emission_dim: dimensionality of each emission.
        input_dim: dimensionality of each input (design-matrix row).
        train_emissions: emissions passed to `fit_sgd`.
        train_inputs: inputs passed to `fit_sgd`.
        num_epochs: number of SGD epochs (default 5000).
        learning_rate: Adam learning rate (default 1e-4).
        batch_size: mini-batch size (default 8).
        shuffle: whether `fit_sgd` shuffles the data (default True).
        seed: integer seed for the PRNG key used to initialise the parameters (default 1).

    Returns:
        Tuple `(model, fit_params, lps)`: the model instance and the two values
        returned by `model.fit_sgd` (fitted parameters and the training trace).
    """
    # Note the argument order: the model takes input_dim before emission_dim.
    model = model_class(num_states, input_dim, emission_dim)
    parameters, properties = model.initialize(key=jr.PRNGKey(seed))

    # Fit with SGD
    fit_params, lps = model.fit_sgd(params=parameters,
                                    props=properties,
                                    emissions=train_emissions,
                                    inputs=train_inputs,
                                    num_epochs=num_epochs,
                                    optimizer=optax.adam(learning_rate),
                                    shuffle=shuffle,
                                    batch_size=batch_size)

    return model, fit_params, lps

def cross_validate(model, all_params, emissions, inputs):
    """Score each held-out sequence under the fitted model.

    Args:
        model: object exposing `marginal_log_prob(params, emissions, inputs=...)`.
        all_params: parameters forwarded to `marginal_log_prob`.
        emissions: indexable collection of emission sequences.
        inputs: indexable collection of input sequences, aligned with `emissions`.

    Returns:
        1-D numpy array with one marginal log probability (as a Python float)
        per entry of `emissions`.
    """
    scores = [
        float(model.marginal_log_prob(all_params, emissions[seq], inputs=inputs[seq]))
        for seq in range(len(emissions))
    ]
    return np.array(scores)

In [55]:
# A training cache to store all the models and their parameters
def nested_defaultdict():
    """Return a defaultdict whose missing keys are filled with further nested defaultdicts."""
    return defaultdict(nested_defaultdict)


# Arbitrarily deep cache: intermediate levels are created on first access
training_cache = nested_defaultdict()

# DS to store the models organized
class Model_Store:
    """Plain record bundling one fitted model with its training and evaluation results.

    Instances are created empty (`Model_Store()`) and filled in attribute by
    attribute. Every field defaults to None, so reading a field that was never
    assigned yields None instead of raising AttributeError.
    """
    subject_id: "Optional[int]" = None    # subject index, if the caller records one
    n_states: "Optional[int]" = None      # number of hidden states of the model
    fit_model: "Optional[object]" = None  # model instance returned by train_hmms
    fit_params: "Optional[object]" = None # fitted parameters returned by train_hmms
    lps: "Optional[object]" = None        # training trace returned by train_hmms
    valid_mllk: "Optional[object]" = None # held-out marginal log probabilities from cross_validate
    test_idx: "Optional[object]" = None   # indices of the held-out sequences

In [56]:
min_state, max_state = 2, 6
emission_dim, input_dim = 4, 3

# Seeded generator so the random train/test split is the same after a kernel restart
SPLIT_SEED = 0
N_TEST_TRIALS = 4
split_rng = np.random.default_rng(SPLIT_SEED)

# Loop bounds come from the data: leading axes are (coh level, subject, trial)
n_coh, n_subjects, n_trials = emissions.shape[:3]

# Training and validation loop
for coh_idx in range(n_coh):

    for subject_idx in range(n_subjects):
        # Split the data into training and testing randomly:
        # N_TEST_TRIALS trials are held out for testing, the rest are used for training
        train_idx = split_rng.choice(n_trials, n_trials - N_TEST_TRIALS, replace=False)
        test_idx = np.setdiff1d(np.arange(n_trials), train_idx)

        # Train data
        train_emissions = jnp.array(emissions[coh_idx][subject_idx][train_idx])
        train_inputs = jnp.array(design_matrix[coh_idx][subject_idx][train_idx])

        # Test data
        test_emissions = jnp.array(emissions[coh_idx][subject_idx][test_idx])
        test_inputs = jnp.array(design_matrix[coh_idx][subject_idx][test_idx])

        for model_class in [LinearRegressionHMM]:

            for num_states in range(min_state, max_state+1):

                model, fit_params, lps = train_hmms(model_class,
                                                    num_states,
                                                    emission_dim,
                                                    input_dim,
                                                    train_emissions,
                                                    train_inputs)

                # Bundle everything needed to inspect or re-evaluate this fit later
                t_store = Model_Store()
                t_store.subject_id = subject_idx
                t_store.n_states = num_states
                t_store.fit_model = model
                t_store.fit_params = fit_params
                t_store.lps = lps
                # Held-out score: one marginal log probability per test trial
                t_store.valid_mllk = cross_validate(model, fit_params, test_emissions, test_inputs)
                t_store.test_idx = test_idx

                print(f'Trainined: mllk - {t_store.valid_mllk.mean():.3f} sub - {subject_idx} {model_class.__name__} - {num_states} states')

                training_cache[coh_idx][subject_idx][num_states] = t_store


Trainined: mllk - 37.832 sub - 0 LinearRegressionHMM - 2 states
Trainined: mllk - 37.322 sub - 0 LinearRegressionHMM - 3 states
Trainined: mllk - 52.132 sub - 0 LinearRegressionHMM - 4 states
Trainined: mllk - 59.008 sub - 0 LinearRegressionHMM - 5 states
Trainined: mllk - 50.015 sub - 0 LinearRegressionHMM - 6 states
Trainined: mllk - 66.002 sub - 1 LinearRegressionHMM - 2 states
Trainined: mllk - 40.728 sub - 1 LinearRegressionHMM - 3 states
Trainined: mllk - 77.138 sub - 1 LinearRegressionHMM - 4 states
Trainined: mllk - 69.787 sub - 1 LinearRegressionHMM - 5 states
Trainined: mllk - 59.140 sub - 1 LinearRegressionHMM - 6 states
Trainined: mllk - -21.259 sub - 2 LinearRegressionHMM - 2 states
Trainined: mllk - -13.101 sub - 2 LinearRegressionHMM - 3 states
Trainined: mllk - 20.074 sub - 2 LinearRegressionHMM - 4 states
Trainined: mllk - 11.399 sub - 2 LinearRegressionHMM - 5 states
Trainined: mllk - 33.004 sub - 2 LinearRegressionHMM - 6 states
Trainined: mllk - 37.164 sub - 3 Linea

In [57]:
# # Save the trained data so we don't have to train again and again
# with open('all_subject_coh_train_report.pkl', 'wb') as f:
#     pkl.dump(training_cache, f)

In [3]:
# Convert to radians
# design_matrix[:,:,0] = np.deg2rad(design_matrix[:,:,0])
# design_matrix[:,:,3] = np.deg2rad(design_matrix[:,:,3])

# observation = np.deg2rad(observation)

# Append a constant bias column of ones along the last axis.
# NOTE: `design_matrix` is rebound from itself, so every re-run of this cell
# appends one more column.
bias_template = design_matrix[:, :, :1]
flat_biases = np.ones_like(bias_template)
design_matrix = np.concatenate((design_matrix, flat_biases), axis=-1)

In [4]:
# Shuffle data: draw one random permutation of the leading axis and apply it
# to both arrays so that their rows stay aligned.
# NOTE(review): `observation` is not defined in the visible part of the notebook.
n_rows = len(design_matrix)
shuff_idx = np.random.permutation(n_rows)

design_matrix = design_matrix[shuff_idx]
observation = observation[shuff_idx]

In [5]:
# Split it 80:20 for training and testing:
# the first 80% of the rows go to training, the remainder to testing.

train_idx = int(0.8 * len(design_matrix))

train_design, test_design = design_matrix[:train_idx], design_matrix[train_idx:]
train_obs, test_obs = observation[:train_idx], observation[train_idx:]

In [6]:
# convert to jax arrays (observations become the emissions, the design matrix the inputs)
train_emissions, train_inputs = (jnp.array(arr) for arr in (train_obs, train_design))

In [7]:
# A training cache to store all the models and their parameters
# (two levels: the first key creates an inner dict on first access)
training_cache = defaultdict(dict)

# DS to store the models organized
class Model_Store:
    """Plain record bundling one fitted model with its training results.

    Instances are created empty (`Model_Store()`) and filled in attribute by
    attribute. Every field defaults to None, so reading a field that was never
    assigned yields None instead of raising AttributeError.
    """
    n_states: "Optional[int]" = None      # number of hidden states of the model
    fit_model: "Optional[object]" = None  # model instance returned by train_hmms
    fit_params: "Optional[object]" = None # fitted parameters returned by train_hmms
    lps: "Optional[object]" = None        # training trace returned by train_hmms

In [8]:
def train_hmms(model_class, num_states, emission_dim, input_dim, train_emissions, train_inputs):
    """Instantiate `model_class` and fit it with SGD on the given data.

    The model is built as `model_class(num_states, input_dim, emission_dim)`,
    initialised from a fixed PRNG key, and trained for 5000 epochs with Adam
    (learning rate 1e-4), batches of 32 and no shuffling.

    Returns:
        Tuple `(model, fit_params, lps)`: the model instance and the two values
        returned by `model.fit_sgd`.
    """
    hmm = model_class(num_states, input_dim, emission_dim)
    init_params, init_props = hmm.initialize(key=jr.PRNGKey(1))

    # Fit with SGD; every setting is collected in one keyword bundle
    sgd_settings = dict(
        params=init_params,
        props=init_props,
        emissions=train_emissions,
        inputs=train_inputs,
        num_epochs=5000,
        optimizer=optax.adam(1e-4),
        shuffle=False,
        batch_size=32,
    )
    fitted, trace = hmm.fit_sgd(**sgd_settings)

    return hmm, fitted, trace

In [9]:
# for i in range(10):
#     model, fit_params, lps = train_hmms(E_HMM, 2, 1, 5, train_emissions, train_inputs)
#     print(f"Model {i} trained - {lps[:10]}")

In [10]:
min_state, max_state = 2, 6
emission_dim, input_dim = 1, 5

# Every state count to try (inclusive range), shared by all model classes
state_counts = range(min_state, max_state + 1)

# NOTE(review): ET_HMM, E_HMM and T_HMM are not defined in the visible part of
# the notebook -- they must be in scope before this cell runs.
for model_class in (ET_HMM, E_HMM, T_HMM):

    for num_states in state_counts:
        print(f'Training: {model_class.__name__} - {num_states} states')

        model, fit_params, lps = train_hmms(
            model_class,
            num_states,
            emission_dim,
            input_dim,
            train_emissions,
            train_inputs,
        )

        # Show the first ten entries of the training trace
        print(f"{lps[:10]}")

        # Bundle the fit and file it under [class name][number of states]
        t_store = Model_Store()
        t_store.fit_model = model
        t_store.fit_params = fit_params
        t_store.lps = lps
        t_store.n_states = num_states

        training_cache[model_class.__name__][num_states] = t_store


Training: ET_HMM - 2 states
[1.9862871 1.9824057 1.9785368 1.9746825 1.9708424 1.9670181 1.9632099
 1.9594175 1.9556417 1.9518824]
Training: ET_HMM - 3 states
[1.9957613 1.9918696 1.98799   1.9841245 1.9802731 1.9764369 1.9726163
 1.9688113 1.9650222 1.9612494]
Training: ET_HMM - 4 states
[1.9962769 1.9923809 1.9884971 1.9846271 1.9807719 1.9769317 1.973107
 1.9692982 1.9655054 1.9617283]
Training: ET_HMM - 5 states
[1.9953004 1.9914081 1.9875284 1.9836622 1.9798106 1.9759741 1.9721533
 1.9683479 1.964559  1.9607859]
Training: ET_HMM - 6 states
[1.9969393 1.9930414 1.989156  1.9852844 1.9814271 1.9775847 1.9737579
 1.969947  1.9661516 1.9623725]
Training: E_HMM - 2 states
[1.9859469 1.9820644 1.9781945 1.9743389 1.970498  1.9666725 1.9628631
 1.95907   1.955293  1.9515327]
Training: E_HMM - 3 states
[1.994876  1.9909754 1.9870872 1.9832131 1.9793533 1.9755088 1.9716797
 1.9678665 1.9640694 1.9602884]
Training: E_HMM - 4 states
[1.9923513 1.9884495 1.9845606 1.9806852 1.9768249 1.972979

In [13]:
# Inspect the fitted parameters stored in the cache for the 2-state ET_HMM
training_cache["ET_HMM"][2].fit_params

ParamsET_HMM(initial=ParamsStandardHMMInitialState(probs=Array([0.65643466, 0.34356534], dtype=float32)), transitions=ParamsET_Transitions(transition_matrix=Array([[0.93811536, 0.06188458],
       [0.94132054, 0.05867945]], dtype=float32), transition_weights=Array([[-0.00764516, -0.00025098, -0.01133804, -0.00038295, -0.00130757],
       [-0.00110481,  0.01022496, -0.00978904,  0.00814341,  0.00280931]],      dtype=float32)), emissions=ParamsET_Emissions(weights=Array([[[ 0.216029  ,  0.02790617,  0.01633381,  0.03509545,
          1.3987256 ]],

       [[ 0.06578622, -0.02190064,  0.04457312,  0.00982139,
          0.39937398]]], dtype=float32), covs=Array([[[ 3.292194]],

       [[11.952794]]], dtype=float32)))

In [14]:
# training_cache["shuffle_idx"] = shuff_idx

# # Save the trained data so we don't have to train again and again
# with open('training_cache.pkl', 'wb') as f:
#     pkl.dump(training_cache, f)