# Running the EM algorithm to infer transition probabilities and GLM weights per state

The details of the model are described thoroughly in the Methods section of the paper.


## Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import GoalSelection.training_metrics as tm
from pathlib import Path
import glmhmm.glm as glm
import glmhmm.glm_hmm as glm_hmm
import FlexiVexi_glm.design_matrix as dm
import FlexiVexi_glm.visualize as viz
import glmhmm.utils as uti

## Declarations

In [None]:
# Root folder of the behavioural data.
DATA = Path('/Volumes/sjones/projects/FlexiVexi/behavioural_data')

# Session to analyse.
MOUSE = 'FNT103'
DATE = '2024-08-02'

# One [x, y] pair per port (presumably port positions in arena
# coordinates — confirm against the task geometry).
PORTS = [
    [0.6, 0.35],
    [-0.6, 0.35],
    [0, -0.7],
]

# Passed as `bias=` to the design-matrix builder and the plotting helpers.
BIAS = True

exp_data = tm.build_exp_data(MOUSE, DATE)


## Fit mouse-wide GLM
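The main purpose of this fit is to sanity-check the data and to build the mouse-wide design matrix. The original note said no bias term is used here, but the cell passes `bias = BIAS` with `BIAS = True`, so a bias column is in fact included.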


In [None]:
# Build the mouse-wide design matrix. The -19 is passed straight through to
# the builder (NOTE(review): its meaning is not visible here — confirm in
# FlexiVexi_glm.design_matrix).
X, y, row_identity, design_concat = dm.design_matrix_per_mouse(
    MOUSE, -19, bias=BIAS
)

# Plain (single-state) GLM fit on the whole dataset, then plot its weights.
GLM = dm.build_GLM(design_concat, y)
w_init = GLM.init_weights()
w, phi = GLM.fit(X, w_init, y)
fig, ax = viz.plot_model_weights(MOUSE, GLM, bias=BIAS)


## Now, use the GLMHMM class

In [None]:
def build_GLMHMM(design, y, states, observations='bernoulli', return_params=False):
    '''
    Construct a ``glm_hmm.GLMHMM`` sized from a design matrix.

    Notation: n data/time points, d features (inputs to the design matrix),
    c classes (possible observations), k hidden states. The design matrix x
    is n x d, the observations y are n x c, and the weights w map x to y
    (d x c or d x 1). For a Bernoulli choice, c is 2.

    Parameters
    ----------
    design : pandas.DataFrame
        Design matrix with one row per data point. The feature count is
        ``len(design.columns) - 1``, i.e. one column is not counted.
    y : array-like
        Observations. Not used to size the model; kept for interface
        compatibility.
    states : int
        Number of hidden states k.
    observations : str, optional
        Observation model. Only ``'bernoulli'`` (c = 2) is supported.
    return_params : bool, optional
        If True, also return the dict ``{'n', 'd', 'c', 'k'}``.

    Returns
    -------
    GLMHMM, or (GLMHMM, params) when ``return_params`` is True.

    Raises
    ------
    ValueError
        If ``observations`` is not a supported observation model.
    '''
    # Number of classes c for each supported observation model.
    n_classes = {'bernoulli': 2}
    if observations not in n_classes:
        # Fail fast: without a known c the model cannot be sized (the old
        # code only printed a hint and then hit an unassigned `c`).
        raise ValueError(
            f"Unsupported observations={observations!r}; "
            f"expected one of {sorted(n_classes)}"
        )
    c = n_classes[observations]

    n = len(design)
    # One column of `design` is excluded from the feature count.
    # NOTE(review): presumably a non-regressor column — confirm against
    # dm.design_matrix_per_mouse.
    d = len(design.columns) - 1
    k = states

    GLMHMM = glm_hmm.GLMHMM(n, d, c, k)
    params = {'n': n, 'd': d, 'c': c, 'k': k}

    if return_params:
        return GLMHMM, params
    return GLMHMM

# Three-state GLM-HMM sized from the mouse-wide design matrix.
GLMHMM, params = build_GLMHMM(
    design_concat,
    y,
    states=3,
    return_params=True,
)

Generates parameters A, w, and pi for a GLM-HMM. Can be used to generate true parameters for simulated data
or to initialize parameters for fitting. I don't see why a uniform prior wouldn't make sense here.

Parameters:

- weights : list, optional  
    Contains the name of the desired distribution (string) and optionally the associated parameters 
    (see the init_params.py script for details). The default is ['uniform',-1,1,1].  

- transitions : list, optional  
    Contains the name of the desired distribution (string). The default is ['dirichlet',5,1].  

- state_priors : string, optional  
    Contains the name of the desired distribution (string). The default is None, or 'uniform'.  
    

Returns:

A : kxk matrix of transition probabilities.
w : mxc matrix of weights.
pi : kx1 vector of state probabilities for t=1.

In [None]:
# Draw one parameter set with uniformly-initialised weights. Note that this
# rebinds `w` (previously the single-GLM weights from the fit above).
A, w, pi = GLMHMM.generate_params(
    weights=['uniform', -1, 1, 1],
)


Let's first run 2 initializations, drawing the initial weights from a uniform distribution.

In [None]:
inits = 2  # number of independent initializations of the EM fit

# Per-initialization results. 250 must match the length of the
# log-likelihood trace returned by GLMHMM.fit (NOTE(review): presumably its
# number of EM iterations — confirm, since a mismatch breaks the assignment).
lls_all = np.zeros((inits, 250))
A_all = np.zeros((inits, params['k'], params['k']))
w_all = np.zeros((inits, params['k'], params['d'], params['c']))
# Fitted initial-state distributions, one per initialization, so the one at
# `bestix` can be recovered later (`pi0` alone only keeps the last one).
pi_all = []

for i in range(inits):
    # Fresh random parameters; weights use generate_params' default
    # (uniform) distribution. pi_init is generated but not passed to fit.
    A_init, w_init, pi_init = GLMHMM.generate_params()
    lls_all[i, :], A_all[i, :, :], w_all[i, :, :], pi0 = GLMHMM.fit(y, X, A_init, w_init)
    pi_all.append(pi0)
    print(f'initialization {i + 1} complete')

In [None]:
# Pick the initialization whose log-likelihood trace (row of lls_all) is
# best according to uti.find_best_fit.
bestix = uti.find_best_fit(lls_all) # find the initialization that led to the best fit


In [None]:
# Display the index of the best initialization.
bestix

`w_all` is a tensor shaped initialisations x states x regressors x classes (possible observations)

In [None]:
# Expected (inits, k, d, c), as allocated in the fitting loop.
w_all.shape

In [None]:
# Per-state weights of the best initialization: shape (k, d, c).
weights_end = w_all[bestix]
weights_end.shape

In [None]:
# Number of rows of the design matrix the GLM-HMM was sized for.
params['n']

In [None]:
# Regressor names, in the row order of each state's (d, c) weight matrix;
# used as x tick labels.
# NOTE(review): must contain one entry per feature (params['d']) — confirm
# against the columns produced by dm.design_matrix_per_mouse.
xlabels = [
    'Cue identity',
    'History of last choice 1',
    'History of last choice 2',
    'History of last choice 3',
    'History of last choice 4',
    'History of last choice 5',
    'Last rewarded choice',
    'Distance to 0',
    'Distance to 1',
    'bias'
]
trials = params['n']
states = params['k']

# squeeze=False keeps `ax` indexable even with a single state (plain
# plt.subplots(1) returns a bare Axes); sharex aligns the regressor axis
# across states so only the bottom panel needs tick labels.
fig, ax = plt.subplots(states, squeeze=False, sharex=True)
ax = ax[:, 0]
for i in range(states):
    # One line per class column of the (d, c) weight matrix of state i.
    ax[i].plot(weights_end[i, :, :])
    # Zero reference via axhline, instead of plotting the string labels as
    # x data on an axis that already holds numeric x values.
    ax[i].axhline(0, color='k', linestyle='--')

ax[-1].set_xticks(np.arange(len(xlabels)))
ax[-1].set_xticklabels(xlabels, rotation=90)
fig.suptitle(f'GLMHMM weights for {MOUSE}, {trials} trials, {states} states')




The fit mostly picks up a spatial bias.

## Being a bit smarter about the weights

We will now use the GLM fit weights, plus a bit of noise, as the starting point for explaining the mouse behaviour. The generate_params method calls a function in init_params.py that takes the following parameters:

- A low and a high estimate used to initialise the GLM weights (these are also approximated, not computed analytically)
- The X and y of the complete GLM
- A bias term. As far as I can see, this bias argument is not used, but the system will add a bias term itself, so I shouldn't add one!

In [None]:
# Weight-initialisation spec for generate_params: start from the full GLM.
glm_prior_params = [
    'GLM',  # distribution name: initialise from a GLM fit
    -0.2,   # low estimate
    1.2,    # high estimate
    X,      # design matrix of the complete GLM
    y,      # observations of the complete GLM
    1,      # bias flag (reportedly unused by init_params.py)
]

In [None]:
inits = 2  # number of independent initializations of the EM fit

# Per-initialization results. 250 must match the length of the
# log-likelihood trace returned by GLMHMM.fit (NOTE(review): presumably its
# number of EM iterations — confirm, since a mismatch breaks the assignment).
lls_all = np.zeros((inits, 250))
A_all = np.zeros((inits, params['k'], params['k']))
w_all = np.zeros((inits, params['k'], params['d'], params['c']))
# Fitted initial-state distributions, one per initialization, so the one
# matching the best fit can be recovered (`pi0` alone only keeps the last).
pi_all = []

for i in range(inits):
    # Weights initialised from the GLM fit spec in glm_prior_params.
    # pi_init is generated but not passed to fit.
    A_init, w_init, pi_init = GLMHMM.generate_params(weights=glm_prior_params)
    lls_all[i, :], A_all[i, :, :], w_all[i, :, :], pi0 = GLMHMM.fit(y, X, A_init, w_init)
    pi_all.append(pi0)
    print(f'initialization {i + 1} complete')

In [None]:
# Visualise the per-state weights (w_all) and log-likelihoods (lls_all)
# of the GLM-initialised fits for this mouse.
viz.plot_model_weights_states(
    MOUSE,
    w_all,
    lls_all,
    GLMHMM,
    bias=BIAS,
)