In [8]:
# glm_hmm_utils.py

# Functions to assist with GLM-HMM model fitting
import sys
import ssm
import autograd.numpy as np
import autograd.numpy.random as npr


def load_data(animal_file):
    """
    Load the design matrix, choices and session ids from an ``.npz`` archive.

    Arrays are picked positionally, in the order the archive lists its keys:
    the first is returned as ``inpt``, the second as ``y`` and the third as
    ``session``. Any further arrays in the archive are ignored.

    :param animal_file: path to the ``.npz`` archive
    :return: tuple ``(inpt, y, session)``
    :raises IndexError: if the archive holds fewer than three arrays
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code - only load archives from a trusted source.
    # The archive is opened as a context manager so its file handle is
    # closed; arrays must therefore be read before leaving the block.
    with np.load(animal_file, allow_pickle=True) as container:
        data = [container[key] for key in container]
    inpt = data[0]
    y = data[1]
    session = data[2]
    return inpt, y, session


def load_cluster_arr(cluster_arr_file):
    """
    Load the cluster job array from an ``.npz`` archive.

    :param cluster_arr_file: path to the ``.npz`` archive
    :return: the first array stored in the archive
    :raises IndexError: if the archive is empty
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code - only load archives from a trusted source.
    # Read inside the `with` block so the archive's file handle gets closed.
    with np.load(cluster_arr_file, allow_pickle=True) as container:
        data = [container[key] for key in container]
    cluster_arr = data[0]
    return cluster_arr


def load_glm_vectors(glm_vectors_file):
    """
    Load a saved GLM fit from an ``.npz`` archive.

    Arrays are picked positionally, in the order the archive lists its keys:
    the first is the training log-likelihood, the second the recovered
    weights.

    :param glm_vectors_file: path to the ``.npz`` archive
    :return: tuple ``(loglikelihood_train, recovered_weights)``
    :raises IndexError: if the archive holds fewer than two arrays
    """
    # Read inside the `with` block so the archive's file handle gets closed.
    with np.load(glm_vectors_file) as container:
        data = [container[key] for key in container]
    loglikelihood_train = data[0]
    recovered_weights = data[1]
    return loglikelihood_train, recovered_weights


def load_global_params(global_params_file):
    """
    Load global-fit parameters from an ``.npz`` archive.

    :param global_params_file: path to the ``.npz`` archive
    :return: the first array stored in the archive
    :raises IndexError: if the archive is empty
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code - only load archives from a trusted source.
    # Read inside the `with` block so the archive's file handle gets closed.
    with np.load(global_params_file, allow_pickle=True) as container:
        data = [container[key] for key in container]
    global_params = data[0]
    return global_params


def partition_data_by_session(inpt, y, mask, session):
    '''
    Partition inpt, y, mask by session.

    Sessions are emitted in the order in which they first appear in
    ``session``.

    :param inpt: arr of size T x M
    :param y: arr of size T x D
    :param mask: 2-D arr with T rows (it is indexed as ``mask[idx, :]``)
    indicating if element is violation or not
    :param session: list or arr of size T containing session ids
    :return: list of inpt arrays, data arrays and mask arrays, where the
    number of elements in list = number of sessions and each array's first
    dimension is the number of trials in that session
    :raises AssertionError: if the per-session trial counts do not add up to
    the number of rows of ``inpt``
    '''
    # Coerce to an array so that `session == sess` below is an element-wise
    # comparison even when a plain Python list is passed; on a list the
    # comparison collapses to a single False and no trial is ever selected.
    session = np.asarray(session)
    inputs = []
    datas = []
    masks = []
    # np.unique sorts its output, so recover first-appearance order from the
    # returned first-occurrence indices instead.
    indexes = np.unique(session, return_index=True)[1]
    unique_sessions = [session[index] for index in sorted(indexes)]
    counter = 0
    for sess in unique_sessions:
        idx = np.where(session == sess)[0]
        counter += len(idx)
        inputs.append(inpt[idx, :])
        datas.append(y[idx, :])
        masks.append(mask[idx, :])
    assert counter == inpt.shape[0], "not all trials assigned to session!"
    return inputs, datas, masks


def load_session_fold_lookup(file_path):
    """
    Load the session-to-fold lookup table from an ``.npz`` archive.

    :param file_path: path to the ``.npz`` archive
    :return: the first array stored in the archive
    :raises IndexError: if the archive is empty
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code - only load archives from a trusted source.
    # Read inside the `with` block so the archive's file handle gets closed.
    with np.load(file_path, allow_pickle=True) as container:
        data = [container[key] for key in container]
    session_fold_lookup_table = data[0]
    return session_fold_lookup_table


def load_animal_list(file):
    """
    Load the list of animals from an ``.npz`` archive.

    :param file: path to the ``.npz`` archive
    :return: the first array stored in the archive
    :raises IndexError: if the archive is empty
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code - only load archives from a trusted source.
    # Read inside the `with` block so the archive's file handle gets closed.
    with np.load(file, allow_pickle=True) as container:
        data = [container[key] for key in container]
    animal_list = data[0]
    return animal_list


def create_violation_mask(violation_idx, T):
    """
    Return indices of nonviolations and also a mask for inclusion (1 =
    nonviolation; 0 = violation)
    :param violation_idx: indices (into ``range(T)``) of the violation trials
    :param T: total number of trials
    :return: tuple ``(nonviolation_idx, mask)`` where ``nonviolation_idx`` is
    the array of indices not listed in ``violation_idx`` and ``mask`` is an
    integer array of shape (T, 1)
    :raises AssertionError: if ``len(violation_idx)`` plus the number of
    nonviolation indices differs from T (e.g. duplicate or out-of-range
    entries in ``violation_idx``)
    """
    # Vectorised membership test; unlike a Python-level `i not in ...` loop
    # it always yields a boolean array, so it also works when T == 0.
    mask = np.isin(np.arange(T), violation_idx, invert=True)
    nonviolation_idx = np.arange(T)[mask]
    # Convert the boolean mask to integers (True -> 1, False -> 0).
    mask = mask + 0
    assert len(nonviolation_idx) + len(
        violation_idx
    ) == T, "violation and non-violation idx do not include all dta!"
    return nonviolation_idx, np.expand_dims(mask, axis=1)


In [9]:
def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold,
                       iter, global_fit, init_param_file, save_directory):
    """
    Select the training trials for one fold and hand them to ``fit_glm_hmm``.

    Trials are kept when their session id (as a string) appears in column 0
    of ``session_fold_lookup_table`` on a row whose column 1 differs from
    ``fold``. The kept trials are split per session, the GLM weights in
    ``init_param_file`` are loaded for initialization and the random seed is
    set to ``iter`` before fitting.

    :param inpt: arr of size T x M
    :param y: arr of size T x D; entries equal to -1 mark violations
    :param session: arr of size T containing session ids
    :param mask: 2-D arr with T rows, passed through per session
    :param session_fold_lookup_table: 2-D arr, column 0 = session id,
    column 1 = fold
    :param K: number of states, forwarded to ``fit_glm_hmm``
    :param D: observation dimension, forwarded to ``fit_glm_hmm``
    :param C: number of output categories, forwarded to ``fit_glm_hmm``
    :param N_em_iters: forwarded to ``fit_glm_hmm``
    :param transition_alpha: forwarded to ``fit_glm_hmm``
    :param prior_sigma: forwarded to ``fit_glm_hmm``
    :param fold: fold that is held out
    :param iter: initialization number; used as random seed and in the save
    file name
    :param global_fit: forwarded to ``fit_glm_hmm``
    :param init_param_file: ``.npz`` file read with ``load_glm_vectors``
    :param save_directory: prefix for the output file name (concatenated
    as-is, so it should end with a path separator)
    :return: whatever ``fit_glm_hmm`` returns
    """
    print("Starting inference with K = " + str(K) + "; Fold = " + str(fold) +
          "; Iter = " + str(iter))
    sys.stdout.flush()
    sessions_to_keep = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    idx_this_fold = [str(sess) in sessions_to_keep for sess in session]
    # Boolean indexing returns copies, so the edit to this_y below does not
    # touch the caller's y.
    this_inpt, this_y, this_session, this_mask = inpt[idx_this_fold, :], \
                                                 y[idx_this_fold, :], \
                                                 session[idx_this_fold], \
                                                 mask[idx_this_fold]
    # Only do this so that errors are avoided - these y values will not
    # actually be used for anything (due to violation mask).
    # A boolean mask replaces exactly the -1 entries. Indexing with
    # `np.where(...)` as the row index would also treat the returned column
    # indices as rows and overwrite trial 0 whenever a violation exists.
    this_y[this_y == -1] = 1
    inputs, datas, masks = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)
    # Read in the GLM fit used for initialization (done regardless of
    # global_fit):
    _, params_for_initialization = load_glm_vectors(init_param_file)
    M = this_inpt.shape[1]
    npr.seed(iter)
    return fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                params_for_initialization,
                save_title=save_directory + 'glm_hmm_raw_parameters_itr_' +
                           str(iter) + '.npz')
    
def fit_glm_hmm(datas, inputs, masks, K, D, M, C, N_em_iters,
                transition_alpha, prior_sigma, global_fit,
                params_for_initialization, save_title):
    '''
    Instantiate a GLM-HMM, initialize its observation weights and report the
    shapes of the data that would be fitted.

    The EM fit itself is currently commented out (see the end of the
    function), so ``N_em_iters``, ``global_fit`` and ``save_title`` are not
    used and nothing is written to disk.

    :param datas: per-session observation arrays
    :param inputs: per-session input arrays
    :param masks: per-session mask arrays
    :param K: first positional argument to ``ssm.HMM``; also the number of
    times the initialization weights are tiled
    :param D: second positional argument to ``ssm.HMM``
    :param M: third positional argument to ``ssm.HMM``
    :param C: passed to the observation model as ``C``
    :param N_em_iters: number of EM iterations (unused while the fit is
    disabled)
    :param transition_alpha: passed to the transition model as ``alpha``
    :param prior_sigma: passed to the observation model as ``prior_sigma``
    :param global_fit: unused here
    :param params_for_initialization: GLM weights, tiled K times along a new
    leading axis and perturbed with Gaussian noise (std 0.2)
    :param save_title: output file name (unused while the fit is disabled)
    :return: ``(datas, inputs, masks)`` exactly as passed in
    '''
    # Model configuration: input-driven observations, sticky transitions
    # with kappa fixed to 0.
    observation_settings = dict(C=C, prior_sigma=prior_sigma)
    transition_settings = dict(alpha=transition_alpha, kappa=0)
    this_hmm = ssm.HMM(K,
                       D,
                       M,
                       observations="input_driven_obs",
                       observation_kwargs=observation_settings,
                       transitions="sticky",
                       transition_kwargs=transition_settings)

    # Start every state from the GLM weights, each with its own noise draw.
    tiled_weights = np.tile(params_for_initialization, (K, 1, 1))
    weight_noise = np.random.normal(0, 0.2, tiled_weights.shape)
    this_hmm.observations.params = tiled_weights + weight_noise

    print("=== fitting GLM-HMM ========")
    sys.stdout.flush()

    # Report the stacked shape of each per-session list.
    for label, per_session in (("datas", datas),
                               ("inputs", inputs),
                               ("masks", masks)):
        print(label + " shape: " + str(np.array(per_session).shape))

    return datas, inputs, masks


    # Fit this HMM and calculate marginal likelihood
    # lls = this_hmm.fit(datas,
    #                    inputs=inputs,
    #                    masks=masks,
    #                    method="em",
    #                    num_iters=N_em_iters,
    #                    initialize=False,
    #                    tolerance=10 ** -4)
    # # Save raw parameters of HMM, as well as loglikelihood during training
    # np.savez(save_title, this_hmm.params, lls)
    # return None

In [10]:
# 1_run_inference_global_fit_ibl.py

import sys
import os
import autograd.numpy as np


D = 1  # data (observations) dimension
C = 2  # number of output types/categories
N_em_iters = 300  # number of EM iterations

# When True, the row of the cluster job array to run is read from the first
# command-line argument; otherwise row 0 is used.
USE_CLUSTER = False

# NOTE(review): machine-specific absolute paths - adjust before running on
# another machine.
data_dir = '/home/rudra/Desktop/markov_models/data/ibl/data_for_cluster/'
results_dir = '/home/rudra/Desktop/markov_models/results/ibl_global_fit/'

if USE_CLUSTER:
    z = int(sys.argv[1])
else:
    z = 0

num_folds = 5
global_fit = True
# perform mle => set transition_alpha to 1
transition_alpha = 1
prior_sigma = 100

# Load external files:
cluster_arr_file = data_dir + 'cluster_job_arr.npz'
# Load cluster array job parameters:
cluster_arr = load_cluster_arr(cluster_arr_file)
# NOTE(review): `iter` shadows the Python builtin for the rest of the
# session; the name is kept because it is a module-level variable that
# other cells may read.
[K, fold, iter] = cluster_arr[z]

#  read in data and train/test split
animal_file = data_dir + 'all_animals_concat.npz'
session_fold_lookup_table = load_session_fold_lookup(
    data_dir + 'all_animals_concat_session_fold_lookup.npz')

inpt, y, session = load_data(animal_file)
#  append a column of ones to inpt to represent the bias covariate:
inpt = np.hstack((inpt, np.ones((len(inpt),1))))
y = y.astype('int')
# Identify violations for exclusion:
violation_idx = np.where(y == -1)[0]
nonviolation_idx, mask = create_violation_mask(violation_idx,
                                                inpt.shape[0])

#  GLM weights to use to initialize GLM-HMM
init_param_file = results_dir + '/GLM/fold_' + str(
    fold) + '/variables_of_interest_iter_0.npz'

# create save directory for this initialization/fold combination:
save_directory = results_dir + '/GLM_HMM_K_' + str(
    K) + '/' + 'fold_' + str(fold) + '/' + '/iter_' + str(iter) + '/'
# exist_ok avoids the check-then-create race and makes the cell safe to
# re-run.
os.makedirs(save_directory, exist_ok=True)

data, inputs, masks = launch_glm_hmm_job(inpt,
                                            y,
                                            session,
                                            mask,
                                            session_fold_lookup_table,
                                            K,
                                            D,
                                            C,
                                            N_em_iters,
                                            transition_alpha,
                                            prior_sigma,
                                            fold,
                                            iter,
                                            global_fit,
                                            init_param_file,
                                            save_directory)


Starting inference with K = 2; Fold = 0; Iter = 0
datas shape: (1618, 90, 1)
inputs shape: (1618, 90, 4)
masks shape: (1618, 90, 1)


In [13]:
# Stack the per-session lists returned by launch_glm_hmm_job into single
# arrays. NOTE(review): this presumably relies on every session having the
# same number of trials - confirm before reusing on other data.
data = np.array(data)
inputs = np.array(inputs)
masks = np.array(masks)

# Show the resulting shapes.
data.shape, inputs.shape, masks.shape

((1618, 90, 1), (1618, 90, 4), (1618, 90, 1))

In [16]:
# Flatten the first session's observations (data[0]) to 1-D for inspection.
data[0].reshape(-1)

array([1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,
       1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,
       0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,
       0, 1])

In [24]:
# Flatten the mask of the session at index 7 to 1-D for inspection
# (per create_violation_mask: 1 = nonviolation, 0 = violation).
masks[7].reshape(-1)

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1])

# The input values

```
def create_design_mat(choice, stim_left, stim_right, rewarded):
    # Create unnormalized_inpt: with first column = stim_right - stim_left,
    # second column as past choice, third column as WSLS
    
```

# Input columns, in order: stim_right - stim_left, past choice, WSLS, bias


```
wsls: vector of size T, entries are in {-1, 1}.  1 corresponds to
previous choice = right and success OR previous choice = left and
failure; -1 corresponds to previous choice = left and success OR previous choice = right and failure
```

In [25]:
# Flatten the first session's input matrix row by row, so the columns of
# each trial appear consecutively in the output.
inputs[0].reshape(-1)

array([-2.60674147e-01, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
        1.11351692e-03, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
       -2.09318779e+00, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
       -1.29780315e-01, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
       -2.60674147e-01, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
        1.11351692e-03, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
        1.11351692e-03, -1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
        5.24688844e-01, -1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
       -1.29780315e-01,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
       -1.29780315e-01, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
        2.09541483e+00, -1.00000000e+00, -1.00000000e+00,  1.00000000e+00,
        2.09541483e+00,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
        1.32007349e-01,  1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
       -2.09318779e+00, -