In [None]:
import os
from os import path as op

import numpy as np
import pandas as pd
from scipy import stats
from scipy import linalg

import vlgp
from vlgp import util, simulation

import matplotlib as mpl
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
plt.style.use('dark_background')

import pickle

### Load vlgp data

In [None]:
# Trials prepared for vlgp, plus the behavioural table (RT, cue, choice) —
# presumably one CSV row per trial; TODO confirm.
# NOTE: pickle.load can execute arbitrary code, so only open trusted files.
with open('data/vlgpTrials.pickle', 'rb') as trials_file:
    trials = pickle.load(trials_file)

behav = pd.read_csv(
    'data/RT_cue_choice.csv',
    header=None,
    names=['RT', 'cue', 'choice'],
)

In [None]:
# Movement onset column: a fixed offset added to each trial's reaction time.
# NOTE(review): the offset is assumed to be in the same units as 'RT' (ms?) — confirm.
MVMT_OFFSET = 400
behav['mvmt'] = behav['RT'] + MVMT_OFFSET

In [None]:
behav.head(n=5)  # first five rows of the behavioural table

### Apply vlgp to 100 randomly selected trials with 4 latent dimensions

In [None]:
# Seed NumPy's global RNG so the trial subset is reproducible (and any global-RNG
# use inside vlgp.fit, if it has any — TODO confirm).
np.random.seed(0)

N_TRIALS_SUBSET = 100  # number of randomly selected trials passed to vlgp
N_FACTORS = 4          # dimensionality of the latent process

# Same shuffle-then-slice draw as before, but the full `trials` list is left
# untouched. Overwriting it made this cell non-idempotent: a re-run would
# resample from the previous 100-trial subset instead of from all trials.
trial_order = np.arange(len(trials))
np.random.shuffle(trial_order)
trials_subset = [trials[i] for i in trial_order[:N_TRIALS_SUBSET]]

fit = vlgp.fit(
    trials_subset,
    n_factors=N_FACTORS,
    max_iter=20,  # maximum number of iterations
    min_iter=10,  # minimum number of iterations
)

### Plotting

In [None]:
# Load the previously saved 4-factor vlgp fit (this reads the pickle, it does
# not write it). pickle.load can execute arbitrary code — trusted files only.
with open('data/vlgp_fit_4_dim.pickle', 'rb') as fit_file:
    fit = pickle.load(fit_file)

In [None]:
# Summarise the fit instead of dumping it: it holds every trial's arrays
# (e.g. fit['trials'][i]['mu']), which floods the cell output.
{key: type(value).__name__ for key, value in fit.items()}

In [None]:
# Rebind `trials` to the fitted trials stored in `fit`; the cells below read
# each trial's 'mu' and 'ID' from this list.
trials = fit['trials']

In [None]:
# Plot every trial's latent trajectory (all time points) in latent dims 1-3.
# NOTE(review): trial['mu'][0..2] are used as the three plotted dimensions, which
# assumes 'mu' is laid out (n_factors, T). If vlgp stores it as (T, n_factors),
# use trial['mu'][:, 0] etc. instead — confirm against the vlgp output.
# TODO: a trial-averaged trajectory is not plotted yet. `collection` keeps each
# trial's 'mu' for that; trials may need aligning first if their lengths differ.

f = plt.figure(figsize=(15, 10))
ax1 = f.add_subplot(3, 2, 1, projection='3d')

ax1.set_title('Latent dynamics')
ax1.set_xlabel('Dim 1')
ax1.set_ylabel('Dim 2')
ax1.set_zlabel('Dim 3')

collection = []
for trial_num, trial in enumerate(trials):
    mu = trial['mu']
    # Label only the first line so the legend has one "Single trials" entry;
    # labels starting with '_' are ignored by legend().
    ax1.plot(mu[0], mu[1], mu[2], '-', lw=0.5, c='C0', alpha=0.5,
             label='Single trials' if trial_num == 0 else '_nolegend_')
    collection.append(mu)

ax1.xaxis.set_ticklabels([])
ax1.yaxis.set_ticklabels([])
ax1.zaxis.set_ticklabels([])
ax1.grid(False)
# Without any labeled artist, legend() only warns and draws an empty box.
if collection:
    ax1.legend(loc='best')
plt.show()


In [None]:
# Split the fitted trials into fast (RT <= 25th percentile) and slow
# (RT >= 75th percentile) groups for plotting.
# NOTE(review): this assumes trial['ID'] equals the row index of `behav` for the
# same trial — confirm against how data/vlgpTrials.pickle was built.
quantiles = behav['RT'].quantile([0.25, 0.75])
fastRTs = behav[behav['RT'] <= quantiles[0.25]]
fastRTidx = list(fastRTs.index)
slowRTs = behav[behav['RT'] >= quantiles[0.75]]
slowRTidx = list(slowRTs.index)

# Sets give constant-time membership checks; the lists above are kept for
# any code that uses them directly.
fast_ids = set(fastRTidx)
slow_ids = set(slowRTidx)

slowTrials = [trial for trial in trials if trial['ID'] in slow_ids]
fastTrials = [trial for trial in trials if trial['ID'] in fast_ids]