We can model our control/switching problem as a switching linear regression model whose discrete state follows a hidden Markov model (HMM).

In this problem, we want to model the output dynamics:
$x_t = A x_{t-1} + B u_{t-1} + \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, \sigma)$

We assume that the joystick/state dynamics are fixed.

That is, $A$ is time-invariant.

Therefore, we can use a switching linear regression by noting that the residuals can be written as:
$x_t - A x_{t-1} = B u_{t-1} + \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, \sigma)$

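This means that the output variables of the switching regression,
$y_t \mid x_t, z_t \sim \mathcal{N}\left(H(z_t)\, x_t + F(z_t), \Sigma\right)$,
where $H(z_t)$ are regression weights that change according to the discrete state $z_t$ (here $x_t$ denotes the generic regression input, i.e. the $u_{t-1}$ term above),
can actually be written as:
$y_t = x_t - A x_{t-1}$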

In kinematic terms, these difference terms correspond to either velocity, or velocity and acceleration, in N-dimensional Cartesian coordinates.

*What matters is choosing the correct combination of $x_{t-1}$ terms.


# HMM Steps:
Initiate
Fit
Filter
Smooth
Decode

In [7]:
#Import libraries for HMM
import jax.numpy as jnp
import jax.random as jr

from itertools import count
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context("notebook")
from sklearn.metrics import r2_score
import scipy
from dynamax.hidden_markov_model import LinearRegressionHMM
from dynamax.utils.plotting import CMAP, COLORS, white_to_color_cmap
#Import data handling of monkey pac-man
import PacTimeOrig.DataHandling as DH
import PacTimeOrig.Methods.utils as pacutils

#Import standard libraries
import numpy as np
import pandas as pd
import os
from scipy.io import loadmat
from scipy.io import savemat

In [8]:
# Helper functions for plotting
def plot_gaussian_hmm(hmm, params, emissions, states,  title="Emission Distributions", alpha=0.25):
    """Draw each state's 2-D emission density over the observed emissions.

    A square grid spanning 1.1x the largest absolute emission value is built.
    For every state k, the density ``exp(log_prob)`` of
    ``hmm.emission_distribution(params, k)`` is contoured on that grid. The
    emissions assigned to state k (``states == k``) are scattered in the state's
    colour. Finally, a thin black line joins all emissions in time order.

    Only the first two emission columns are plotted.
    """
    half_width = 1.1 * abs(emissions).max()
    axis_pts = jnp.linspace(-half_width, half_width, 100)
    XX, YY = jnp.meshgrid(axis_pts, axis_pts)
    # Flatten the mesh into an (N, 2) array of evaluation points.
    grid = jnp.stack([XX.ravel(), YY.ravel()], axis=1)

    plt.figure()
    for k in range(hmm.num_states):
        color = COLORS[k]
        density = jnp.exp(hmm.emission_distribution(params, k).log_prob(grid))
        plt.contour(XX, YY, density.reshape(XX.shape), cmap=white_to_color_cmap(color))
        in_state = states == k
        plt.plot(emissions[in_state, 0], emissions[in_state, 1], "o",
                 mfc=color, mec="none", ms=3, alpha=alpha)

    # Trajectory of the emissions through time.
    plt.plot(emissions[:, 0], emissions[:, 1], "-k", lw=1, alpha=alpha)

    ax = plt.gca()
    ax.set_xlabel("$y_1$")
    ax.set_ylabel("$y_2$")
    ax.set_title(title)
    ax.set_aspect(1.0)
    plt.tight_layout()


def plot_gaussian_hmm_data(hmm, params, emissions, states, xlim=None,
                           title="Simulated data from an HMM"):
    """Plot every emission dimension against time over the state sequence.

    One panel is drawn per emission dimension. Each panel shows:

    * the state sequence as a background colour band;
    * the emissions as a solid black line;
    * the mean of the active state (``params.emissions.means[states]``) as a
      dotted line.

    Args:
        hmm: model exposing ``emission_dim``.
        params: parameters exposing ``emissions.means`` indexed by state.
        emissions: array indexed as ``emissions[:, d]`` for each dimension d.
        states: integer state index per time step.
        xlim: optional (left, right) x-limits; defaults to the full time range.
        title: title of the top panel. The default preserves the previously
            hard-coded text.
    """
    num_timesteps = len(emissions)
    emission_dim = hmm.emission_dim
    means = params.emissions.means[states]
    lim = 1.05 * abs(emissions).max()

    # Plot the data superimposed on the generating state sequence.
    # squeeze=False always returns a 2-D Axes array. Without it, a single
    # emission dimension yields a bare Axes object and ``axs[d]`` raises.
    fig, axs = plt.subplots(emission_dim, 1, sharex=True, squeeze=False)
    axs = axs[:, 0]

    for d in range(emission_dim):
        axs[d].imshow(states[None, :], aspect="auto", interpolation="none", cmap=CMAP,
                      vmin=0, vmax=len(COLORS) - 1, extent=(0, num_timesteps, -lim, lim))
        axs[d].plot(emissions[:, d], "-k")
        axs[d].plot(means[:, d], ":k")
        axs[d].set_ylabel("$y_{{t,{} }}$".format(d+1))

    if xlim is None:
        plt.xlim(0, num_timesteps)
    else:
        plt.xlim(xlim)

    axs[-1].set_xlabel("time")
    axs[0].set_title(title)
    plt.tight_layout()


def dat_simulator(most_likely_states,vel,inputs,params):
    '''Compute the predicted data under the model along a decoded state path.

    For each time step t < len(vel) the prediction is
    ``weights[z_t] @ inputs[t] + biases[z_t]``, where ``weights``/``biases``
    come from ``params.emissions`` and ``z_t = most_likely_states[t]``.

    Parameters
    ----------
    most_likely_states : sequence of int
        State index per time step.
    vel : array-like
        Only its length is used; it sets how many time steps are predicted.
    inputs : array-like, (T, input_dim)
        Regression inputs, one row per time step.
    params : object
        Exposes ``emissions.weights`` indexed as [state, out, in] and
        ``emissions.biases`` indexed as [state, out].

    Returns
    -------
    np.ndarray, (len(vel), emission_dim)
        Predicted emissions. Empty input yields an empty (0, emission_dim)
        array; the previous ``np.stack`` on an empty list raised ValueError.
    '''
    emit = params.emissions
    n_steps = len(vel)
    states = np.asarray(most_likely_states, dtype=int)[:n_steps]
    weights = np.asarray(emit.weights)[states]   # (T, emission_dim, input_dim)
    biases = np.asarray(emit.biases)[states]     # (T, emission_dim)
    x = np.asarray(inputs)[:n_steps]             # (T, input_dim)
    # Batched matrix-vector product: one weight matrix per time step.
    return (weights @ x[:, :, None])[:, :, 0] + biases

In [10]:
# May need to try rescaling everything....
from tqdm import tqdm


# LOAD DATA STRUCTURE
# For every (switch type, trial) pair in the simulated data set, fit a 2-state
# linear-regression HMM. The regression maps the displacements A - x and B - x
# onto the velocity (or acceleration) of x. Each fit records:
#   r2     - R^2 of the decoded-state predictions against the observed velocity
#   mcorr  - best correlation between a smoothed state probability and shift[1, :]
#   deltaW - mean squared gap between that best state probability and shift[1, :]
#   LLdat  - negated last entry of the trace returned by fit_em
dat = loadmat('/Users/user/PycharmProjects/PacManMain/data/Simulation/datout.mat')
DataFold = '/Users/user/PycharmProjects/PacManMain/data/HMMOUTPUT/'
accel = 0                  # 1 -> regress on the gradient of the velocity instead
n_types, n_trials = 5, 8   # switch types x trials in datout
SEED = 0                   # base seed: each fit gets its own reproducible PRNG key

r2 = np.zeros([n_types, n_trials])
mcorr = np.zeros([n_types, n_trials])
LLdat = np.zeros([n_types, n_trials])
deltaW = np.zeros([n_types, n_trials])
base_key = jr.PRNGKey(SEED)

for switype in tqdm(range(n_types)):
    for trial in tqdm(range(n_trials), leave=False):
        x = dat['datout'][0][0]['x'][trial][switype]
        A = dat['datout'][0][0]['A'][trial][switype]
        B = dat['datout'][0][0]['B'][trial][switype]
        shift = dat['datout'][0][0]['shiftfunc'][trial][switype]

        # Regression inputs: displacement from x to each of A and B.
        erA = jnp.array(A - x)
        erB = jnp.array(B - x)

        # Regression outputs: finite-difference derivative of each column of x.
        vel = np.array([np.gradient(x[:, 0], 1), np.gradient(x[:, 1], 1)]).transpose()
        if accel == 1:
            vel = np.array([np.gradient(vel[:, 0], 1), np.gradient(vel[:, 1], 1)]).transpose()
        vel = jnp.array(vel)
        inputs = np.hstack([erA, erB])

        # The key was jr.PRNGKey(np.random.randint(1, 101)) from an unseeded
        # NumPy RNG, so results were not reproducible. fold_in derives a
        # distinct, deterministic key for every (switype, trial) fit.
        keys = jr.fold_in(base_key, switype * n_trials + trial)
        hmm = LinearRegressionHMM(2, 4, 2)  # (num_states, input_dim, emission_dim)
        test_params, param_props = hmm.initialize(keys)
        test_params, lps = hmm.fit_em(test_params, param_props, vel, inputs=inputs,
                                      num_iters=300, verbose=False)

        most_likely_states = hmm.most_likely_states(test_params, vel, inputs=inputs)
        stateposterior = hmm.filter(test_params, vel, inputs=inputs)  # not used below
        smoothposterior = hmm.smoother(test_params, vel, inputs=inputs)
        probs = smoothposterior.smoothed_probs
        switch_sig = shift[1, :].transpose()

        # Model-predicted velocity along the most likely state path.
        simvel = dat_simulator(most_likely_states, vel, inputs, test_params)

        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.plot(switch_sig)
        ax1.plot(probs, linestyle='dashed')
        ax2.plot(vel)
        ax2.plot(simvel, linestyle='dashed')
        # plt.title only labelled the current axes (ax2); title the whole figure.
        fig.suptitle('type:' + str(switype) + '_' + 'trial:' + str(trial))
        fig.tight_layout()
        fig.savefig(DataFold + 'type:' + str(switype) + '_' + 'trial:' + str(trial) + '.png')
        plt.close(fig)

        reshaped_X = vel.transpose().reshape(-1, 1)     # observed
        reshaped_Y = simvel.transpose().reshape(-1, 1)  # model prediction
        fit_failed = bool(np.isnan(reshaped_Y).any())
        # r2_score(y_true, y_pred): the observed velocity is the reference.
        # The arguments were previously swapped.
        r2[switype, trial] = np.nan if fit_failed else r2_score(reshaped_X, reshaped_Y)

        # Find the posterior state that best matches the switch signal.
        s1corr = np.corrcoef(probs[:, 0], switch_sig)[0, 1]
        s2corr = np.corrcoef(probs[:, 1], switch_sig)[0, 1]
        mcorr[switype, trial] = np.nan if fit_failed else np.max([s1corr, s2corr])
        # Same selection as `np.max([s1corr, s2corr]) == s1corr`: ties pick
        # state 0, and any NaN correlation falls through to state 1.
        best = 0 if s1corr >= s2corr else 1
        probout = probs[:, best]
        deltaW[switype, trial] = np.mean(np.power(probout - switch_sig, 2))

        LLdat[switype, trial] = -lps[-1]
        mdic = {'probout': [probout]}
        savemat(DataFold + '/HMMprob' + str(switype) + '_' + str(trial) + '.mat', mdic)

  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:00<00:05,  1.26it/s][A
 25%|██▌       | 2/8 [00:01<00:04,  1.32it/s][A
 38%|███▊      | 3/8 [00:02<00:03,  1.28it/s][A
 50%|█████     | 4/8 [00:03<00:03,  1.25it/s][A
 62%|██████▎   | 5/8 [00:04<00:02,  1.12it/s][A
 75%|███████▌  | 6/8 [00:05<00:01,  1.01it/s][A
 88%|████████▊ | 7/8 [00:06<00:01,  1.04s/it][A
100%|██████████| 8/8 [00:07<00:00,  1.04it/s][A
 20%|██        | 1/5 [00:07<00:30,  7.68s/it]
  0%|          | 0/8 [00:00<?, ?it/s][A
 12%|█▎        | 1/8 [00:00<00:05,  1.35it/s][A
 25%|██▌       | 2/8 [00:01<00:04,  1.35it/s][A
 38%|███▊      | 3/8 [00:02<00:04,  1.20it/s][A
 50%|█████     | 4/8 [00:03<00:03,  1.20it/s][A
 62%|██████▎   | 5/8 [00:04<00:02,  1.22it/s][A
 75%|███████▌  | 6/8 [00:04<00:01,  1.22it/s][A
 88%|████████▊ | 7/8 [00:05<00:00,  1.21it/s][A
100%|██████████| 8/8 [00:06<00:00,  1.22it/s][A
 40%|████      | 2/5 [00:14<00:21,  7.01s/it]
  0%

In [6]:
# Bundle the summary matrices into one MATLAB file. S1corr/S2corr hold
# whatever s1corr/s2corr were last set to (i.e. a single trial's values),
# not a per-trial matrix.
DataFold = '/Users/user/PycharmProjects/PacManMain/data/HMMOUTPUT/'
mdic = dict(
    S1corr=[s1corr],
    S2corr=[s2corr],
    mcorr=[mcorr],
    Rsquared=[r2],
    LLdat=[LLdat],
    deltaW=[deltaW],
)
savemat(DataFold + '/HMMOUTPUT.mat', mdic)

Array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan], dtype=float32)