<a href="https://colab.research.google.com/github/VladGrigoras/CNS/blob/main/Blank_Lab5_Neural_Data_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



---

**If you are on Google Colab, before running anything, you may wish to connect to a GPU environment**

---



# Introduction

In this lab, you will get hands-on experience with neural data. 

The dataset consists of recordings of the parietal cortex of a monkey during a sensory-motor task. It is one of the datasets from the [Neural Latent Benchmark](https://neurallatents.github.io/) project, whose aim is to provide standardized, freely accessible neural datasets for developing data analysis and modeling methods.



---



The monkey starts each *trial* of the experiment holding a manipulandum (i.e. a joystick). The manipulandum, and therefore the animal's hand, is then "bumped". The animal's task is to return the manipulandum to its original central position. The bump can be directed toward one of 8 angles ($0^\circ,45^\circ,\dots,315^\circ$). This is repeated for hundreds of trials. Throughout, the position of the animal's hand is recorded.

![](https://drive.google.com/uc?export=view&id=1Jyt8sjdagyBY1_kl-rTJhTFHlvdmFRLM)

Neurons from area A2 are also recorded using a [microelectrode array](https://www.brainlatam.com/uploads/produto/utah-array-335.webp). Area A2 is located in the somatosensory cortex (parietal cortex), which carries out computations related to touch (e.g. feeling the pressure on your fingers when you grab an object). More recent work has suggested that its computations are not limited to touch [1] and that, among other things, kinematics (e.g. movement of the arm) is also represented in the neural activity of area A2.

![](https://drive.google.com/uc?export=view&id=14ZWPqYTF5L33umDcmDCWoWo0WwUV_GAo)




---



Today we will explore this hypothesis from a phenomenological standpoint. We will apply to this sensory area some of the models classically used in the motor literature. In particular, linear dynamical systems are used extensively to model and analyze neural activity related to arm movement. The mathematical theory builds upon what was introduced in lab 0, except that you will now *fit* these linear dynamical systems to neural data.

# Data imports

The neural data are stored in the Neurodata Without Borders (NWB) format, a widely used format for storing and sharing neural data. All the code has been written for you. Simply run the following cells, which download the part of the dataset relevant to this lab (one recording session). This can take a while (<5 min), so avoid restarting your runtime environment or rerunning these cells.

### Run this:

(no need to expand if you don't want to clutter your ipynb)

In [None]:
import seaborn as sns
sns.set()

# global defaults for plots - optional
sns.set_theme(style="ticks",
              palette="Set2",
              font_scale=1.0,
              rc={
              "axes.spines.right": False,
              "axes.spines.top": False,
          },
          )

In [None]:
# You should not evaluate this cell multiple times

from IPython.display import clear_output

# Download the data
!pip install git+https://github.com/neurallatents/nlb_tools.git
!pip install dandi
!dandi download https://gui.dandiarchive.org/dandiset/000127

#Install Jax optimization library
!pip install optax

#Install behavior decoding library
!pip install git+https://github.com/arthur-pe/hand-decoding.git#egg=hand_decoding

clear_output() #clear cell output

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import itertools
import scipy
import scipy.linalg   # svd, eig
import scipy.special  # factorial

import jax.numpy as jnp
import jax.random as random
import jax
import jax.scipy.linalg  # expm
import optax

import hand_decoding

from nlb_tools.nwb_interface import NWBDataset

np.random.seed(7)

plt.rcParams['figure.dpi'] = 120

In [None]:
#The part of the dataset of interest
dataset = NWBDataset("./000127/sub-Han/", "*train", split_heldout=False)

In [None]:
#Function to import the dataset -- note that the smoothed spikes ('spikes_smth_40') are not used in this lab
def get_data(smoothing=70):
  dataset.smooth_spk(smoothing, name='smth_40')
  continuous_task_variables = ['force', 'hand_pos', 'hand_vel', 'joint_ang', 
                              'joint_vel', 'muscle_len', 'muscle_vel', 'spikes', 
                              'spikes_smth_40', 'condition', 'direction']
  all_task_variables = continuous_task_variables
  task_variables_dict = {key: [] for key in all_task_variables}

  for id, cond in enumerate(unique_conditions):
      cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
                  (dataset.trial_info.split != 'none')
      cond_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-100, 500), ignored_trials=~cond_mask)

      for idx, trial in cond_data.groupby('trial_id'):
        for var in all_task_variables[:-2]:
          cond_tensor = trial[var]
          task_variables_dict[var].append(cond_tensor)
        task_variables_dict['condition'].append(cond[0])
        task_variables_dict['direction'].append(cond[1])

  for var in all_task_variables:
    task_variables_dict[var] = np.stack(task_variables_dict[var])

  return task_variables_dict

In [None]:
bump = [True] #[False] for voluntary movements, [False, True] for both
angle = [i*45.0 for i in range(8)]
unique_conditions = list(itertools.product(bump, angle)) #cartesian product

task_variables_dict = get_data(smoothing=40) #takes a while

In [None]:
#rebinning
bin_size = 10
neural_data = np.stack([task_variables_dict['spikes'][:,i:i+bin_size].sum(axis=1) 
      for i in range(0,task_variables_dict['spikes'].shape[1],bin_size)],axis=1)
hand_movement = np.stack([task_variables_dict['hand_pos'][:,i:i+bin_size].mean(axis=1) 
      for i in range(0,task_variables_dict['hand_pos'].shape[1],bin_size)],axis=1)

#Task variables
angle = task_variables_dict['direction']
bump = task_variables_dict['condition']

trial_dimension, time_dimension, neuron_dimension = neural_data.shape
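When the number of time samples is an exact multiple of the bin size, the list-comprehension rebinning above can also be written as a single `reshape` followed by a sum. Below is a minimal sketch on synthetic 1 ms spike counts. The array sizes are made up, and names carry a `_toy` suffix so they do not overwrite the notebook's variables.

```python
import numpy as np

gen = np.random.default_rng(0)
bin_size_toy = 10
# Synthetic 1 ms spike counts, shape (trial, time, neuron): 600 samples -> 60 bins of 10 ms
spikes_1ms = gen.poisson(0.02, size=(5, 600, 4))

# Loop-based rebinning, as in the cell above
binned_loop = np.stack([spikes_1ms[:, i:i + bin_size_toy].sum(axis=1)
                        for i in range(0, spikes_1ms.shape[1], bin_size_toy)], axis=1)

# Reshape-based rebinning: split time into (n_bins, bin_size) and sum within each bin
n_trials, n_time, n_neurons = spikes_1ms.shape
binned_reshape = spikes_1ms.reshape(n_trials, n_time // bin_size_toy,
                                    bin_size_toy, n_neurons).sum(axis=2)

print(binned_reshape.shape)                         # (5, 60, 4)
print(np.array_equal(binned_loop, binned_reshape))  # True
```

The two versions behave differently when the number of samples is not a multiple of the bin size. The reshape version then raises an error, whereas the loop silently keeps a shorter last bin.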

### Data format

The data thus imported are numpy arrays of shape:


```
neural_data (trial, time, neuron) : (364, 60, 65)
hand_movement (trial, time, position) : (364, 60, 2)
angle (trial) : (364)
```

*   `neural_data` contains binned spike counts. For a given trial and neuron, it is an array of the form [0,2,0,0,1,...] giving the number of spikes that occurred between 0-10 ms, 10-20 ms, ... The bump happens at 100 ms, i.e. at bin number 10.

* `hand_movement` contains the x-y position of the hand of the animal (or more precisely the manipulandum's position), for each trial and over time.

*   `angle` contains the angle of the bump (in degrees) for each trial.

# Simple statistics

A common first step when analyzing a new dataset is to compute basic statistics and make plots to get a sense of the data. Questions such as "are there near-silent neurons?", "is the behavior stereotyped?" or "are trials of similar durations?" are typically worth answering before applying complex models or data analysis methods to the data.

## Analysis of behavior

First, we plot the hand movement over trials, colored by the angle of the bump. As expected, the hand starts at $(x,y)=(0,0)$, is deflected toward the direction of the bump, and the animal pulls it back to $(0,0)$.

In [None]:
fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot()

cmap = matplotlib.colormaps['gist_rainbow'] #matplotlib.cm.get_cmap was removed in matplotlib 3.9

for id, pos in enumerate(hand_movement):
  ax.plot(pos[:,0], pos[:,1], c=cmap(angle[id]/360), alpha=0.6, linewidth=1)

ax.set_xlabel('x (cm)')
ax.set_ylabel('y (cm)')
ax.set_title('Hand movement')
plt.show()

Next we plot the average (over trials) velocity of the hand for each bump direction. To not clutter the plot, we only plot the average velocity for 3 bump directions. The shade represents 1 standard deviation.

In [None]:
fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot()

cmap = matplotlib.colormaps['gist_rainbow']

ts = np.arange(-10, time_dimension-10)*10

for a in np.unique(angle)[-3:]:
  hand_movement_of_angle = hand_movement[angle==a]
  # finite-difference velocity along time (cm/ms), then its magnitude (speed)
  velocity = np.linalg.norm(np.gradient(hand_movement_of_angle, bin_size, axis=1), axis=-1)
  ax.plot(ts, velocity.mean(axis=0),
          c=cmap(a/360), alpha=0.6, linewidth=1, label=r'$'+str(a)+r'^\circ$')
  
  std_velocity = velocity.std(axis=0)
  ax.fill_between(ts, velocity.mean(axis=0)-std_velocity, velocity.mean(axis=0)+std_velocity,
          color=cmap(a/360), alpha=0.1)

ax.axvline(0, color='red', label='bump', linestyle='--')

ax.legend()

ax.set_xlabel('time (ms)')
ax.set_ylabel('velocity (cm/ms)')
ax.set_title('Average hand velocity')
plt.show()

*  The bump seems to have less effect when directed toward the animal (dark blue). In particular, the animal returns to the central position faster as represented by the velocity returning to near zero at an earlier time. How may that affect the analysis of the neural data?

## Analysis of the neural data



### Neuron-wise average

We plot a histogram of the average number of spikes per trial for all neurons.

In [None]:
fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot()

ax.hist(neural_data.sum(axis=1).mean(axis=0), bins=20)

ax.set_xlabel('average number of spikes'), ax.set_ylabel('number of neurons')

plt.show()

*  What do you observe?

### Exercise 1 $-$ Time-wise average

Plot the average (over neurons and trials) number of spikes as a function of time.

In [None]:
#Your solution:

*  What do you observe?

### Tuning

Next, we will explore whether neurons are tuned to bump direction. To this end, we group the trials by bump direction. For three example neurons, we count the spikes between -50 ms and 200 ms and average this count over the trials of each direction.

In [None]:
Index_of_neurons_plotted = [0, 1, 2]

fig, axes = plt.subplots(1, len(Index_of_neurons_plotted), 
              figsize=(3*len(Index_of_neurons_plotted),3), constrained_layout=True)
axes = np.atleast_1d(axes) #also works when a single neuron is plotted

cmap = matplotlib.colormaps['gist_rainbow']

unique_angle = np.unique(angle)

#enumerate so that any neuron indices (not only 0, 1, 2) can be plotted
for ax_id, i in enumerate(Index_of_neurons_plotted):
  for a in unique_angle:
    #spike count from -50ms to 200ms (bins 5 to 29), averaged over the trials of this bump direction
    axes[ax_id].scatter(a, neural_data[angle==a,5:30,i].sum(axis=1).mean(), color=cmap(a/360))
  axes[ax_id].set_title('neuron '+str(i))
  axes[ax_id].set_xlabel('angle (degree)')

axes[0].set_ylabel('average number of spikes\n per trial')
  
plt.show()


*  Qualitatively comment on the tuning of these three neurons (you may want to look at other neurons).

As it turns out, single neurons are only weakly tuned to bump direction, but the activity of the population as a whole is highly stereotyped for a given bump direction. We will therefore build a model of the population rather than of single neurons. But before fitting this model, we should check with simpler statistics that its assumptions are reasonable.

### Exercise 2 $-$ Fano factor

As a reminder, the Fano factor is the ratio of the variance to the mean of the data. 

$$
F = \frac{\sigma^2}{\mu}
$$
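For reference, a Poisson process has $F = 1$ in any counting window, since its variance equals its mean. The sketch below checks this on synthetic counts (not the lab data). It contrasts them with overdispersed counts drawn from a negative binomial distribution, which is chosen here purely for illustration.

```python
import numpy as np

gen = np.random.default_rng(1)

def fano_factor(counts):
    """Variance-to-mean ratio of an array of spike counts."""
    return counts.var() / counts.mean()

mean_count = 2.0
poisson_counts = gen.poisson(mean_count, size=100_000)

# Negative binomial (numpy parametrisation): mean = n(1-p)/p, variance = mean/p, so F = 1/p
nb_n, nb_p = 4, 4 / (4 + mean_count)      # mean 2.0, F = 1.5
negbin_counts = gen.negative_binomial(nb_n, nb_p, size=100_000)

print(fano_factor(poisson_counts))   # close to 1.0
print(fano_factor(negbin_counts))    # close to 1.5
```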

We will compute the Fano factor in two ways. First for each neuron, then for each time point.

*  Compute the Fano factor over all trials and time for each neuron separately, and plot a histogram of the obtained Fano factors.

In [None]:
#Your solution:

### Running Fano factor

We will next compute the running Fano factor over the whole dataset. That is, we compute the Fano factor over all trials and neurons within a time window from $t$ to $t+\Delta t$ (here $\Delta t = 30$ ms, i.e. 3 bins), for every $t$.

In [None]:
#Running Fano factor
window_size = 3

running_fano_factor = []
for bin in range(time_dimension-window_size):
  mean = neural_data[:,bin:bin+window_size].mean()
  var = neural_data[:,bin:bin+window_size].var()
  running_fano_factor.append(var/mean)

fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot()

ax.axvline(0, label='bump', color='red', linestyle='--')
ax.legend()
ax.plot(np.arange(-10, time_dimension-window_size-10)*bin_size, running_fano_factor)

ax.set_xlabel('time (ms)'), ax.set_ylabel('running Fano factor')
plt.show()

*  Given what you have seen in class, what stochastic process may or may not be suited to model this data?

# Latent linear dynamical system model

Next, we will fit a linear dynamical system (LDS) to the neural data [2]. Rather than fitting a dynamical system whose dimension equals the number of neurons, we will look for *latent* variables underlying the recorded activity. The word *latent* is used in two ways in neuroscience:

*  To describe a small set of variables whose (linear) combination describes well the statistics of the population of neurons. 
*  To describe unobserved variables in probabilistic models. 

Both interpretations are combined in the model we will use here. The first applies because our LDS is low dimensional. The second applies because we will model the spikes (which is required, since the data are spikes) using an inhomogeneous Poisson process whose rate is a nonlinear function of a linear combination of the LDS variables.

These latent variables will provide hints as to the computations that may be happening in the neural population recorded. For example, they will turn out to be correlated with the animal's behavior. 

---


For any given trial $i$, the latent variable $\mathbf{x}^{(i)}(t)\in\mathbb{R}^m$ evolves according to the linear dynamical system,

\begin{aligned}
\frac{d}{dt}\mathbf{x}^{(i)}=W\mathbf{x}^{(i)}, &&\mathbf{x}^{(i)}(0)=\mathbf{x}_0^{(i)}.
\end{aligned}

That is, the *weight matrix* $W$ of the system is shared across all trials, but the initial condition is allowed to vary. We furthermore constrain the initial condition to be a linear combination of *basis initial conditions*,

\begin{aligned}
\mathbf{x}_0^{(i)} = \sum_{j=1}^r U_{ij}\mathbf{v}_j.
\end{aligned}

where $U\in\mathbb{R}^{K\times r}$, $K$ is the number of trials and $r$ the number of basis initial conditions (we also write $u^{(i)}_j = U_{ij}$ and $\mathbf{u}^{(i)}$ for the $i$th row of $U$). Here, as a guess, we choose $r=8$, since there are $8$ bump directions and the neural activity should be relatively stereotyped for a given bump direction. In practice, one would select the values of $r$ and $m$ through [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).

As you saw in the first week of the lab, this system has the solution,

\begin{aligned}
\mathbf{x}^{(i)}(t) = e^{tW}\mathbf{x}_0^{(i)}=\sum_{j=1}^r u_j^{(i)}e^{tW}\mathbf{v}_j.
\end{aligned}

and thus the latent trajectory of any given trial is a linear combination of the $r$ basis trajectories $e^{tW}\mathbf{v}_j$.
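Both claims above can be checked numerically. The first is that $e^{tW}\mathbf{x}_0$ solves the ODE. The second is that a trial's trajectory is the $U_{ij}$-weighted sum of the basis trajectories $e^{tW}\mathbf{v}_j$. The sketch below uses small random matrices whose sizes are arbitrary and unrelated to the fitted model. The `_toy` names are illustrative.

```python
import numpy as np
import scipy.linalg

gen = np.random.default_rng(2)
m_toy, r_toy = 4, 3                  # latent dimension, number of basis initial conditions
W_toy = gen.normal(size=(m_toy, m_toy)) / np.sqrt(m_toy) - np.eye(m_toy)
V_toy = gen.normal(size=(r_toy, m_toy))   # row j is the basis initial condition v_j
u_toy = gen.normal(size=r_toy)            # coefficients U_ij for a single trial i

x0_toy = u_toy @ V_toy                    # x_0 = sum_j U_ij v_j

def x_at(t):
    """Latent state e^{tW} x_0 at time t."""
    return scipy.linalg.expm(t * W_toy) @ x0_toy

t, dt = 0.7, 1e-5

# 1) e^{tW} x_0 satisfies dx/dt = W x (checked with a central finite difference)
dxdt = (x_at(t + dt) - x_at(t - dt)) / (2 * dt)
print(np.allclose(dxdt, W_toy @ x_at(t), atol=1e-6))   # True

# 2) Linearity: the trajectory is the U_ij-weighted sum of the basis trajectories e^{tW} v_j
basis_at_t = np.stack([scipy.linalg.expm(t * W_toy) @ V_toy[j] for j in range(r_toy)])
print(np.allclose(x_at(t), u_toy @ basis_at_t))        # True
```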

We interpret $\mathbf{x}^{(i)}(t)\in\mathbb{R}^{m}$ (with $m < N$, where $N$ is the number of recorded neurons) as the *latent activity* of the recorded population. We therefore need a map $M\in \mathbb{R}^{N \times m}$ from this latent state to the full neural state,

\begin{aligned}
\mathbf{\hat y}^{(i)}(t)=\exp(M\mathbf{x}^{(i)}(t))
\end{aligned}

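where $\mathbf{\hat y}^{(i)}(t)$ is the estimated firing rate of the recorded neurons at time $t$ during trial $i$. The element-wise exponential ensures that the firing rates are positive. The fitting code in this lab uses the softplus function $\log(1+e^z)$ instead, which also yields positive rates but grows only linearly for large inputs.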

---

The variable $\mathbf{\hat y}$ is continuous, but we only have access to neurons' spikes. We can however maximize the likelihood of the observed spikes given the estimated firing rate. For this, we can use the results of the previous section and notice that a Poisson process is a decent model for the spikes. It will be an inhomogeneous Poisson process, whose parameter is $\mathbf{\hat y}$. Such a model is called a PoissonLDS (or PLDS). 
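To make the generative side of the PLDS concrete, here is a minimal sketch that samples synthetic spike counts from the model. It computes latent trajectories $e^{tW}\mathbf{x}_0^{(i)}$, then rates $\exp(M\mathbf{x})$, then Poisson counts in each bin. All sizes and parameter scales are made-up illustrative values, not those of the lab. The helper `sample_plds` is hypothetical.

```python
import numpy as np
import scipy.linalg

def sample_plds(n_trials=20, n_bins=40, n_neurons=10, m=3, r=2, seed=3):
    """Sample spike counts from a Poisson LDS with random parameters (illustrative only)."""
    gen = np.random.default_rng(seed)
    ts = np.linspace(0.0, 4.0, n_bins)                          # evaluation times
    W = gen.normal(size=(m, m)) / np.sqrt(m) - 1.5 * np.eye(m)  # shifted to favour decaying modes
    V = gen.normal(size=(r, m))                 # basis initial conditions v_j (rows)
    U = gen.normal(size=(n_trials, r))          # per-trial coefficients
    M = 0.5 * gen.normal(size=(n_neurons, m))   # latent-to-neuron map

    expm_tW = np.stack([scipy.linalg.expm(t * W) for t in ts])  # (T, m, m)
    x0 = U @ V                                                  # (K, m) initial conditions
    x = np.einsum('tab,kb->kta', expm_tW, x0)                   # x[k, t] = e^{tW} x0[k]
    rates = np.exp(x @ M.T)                     # (K, T, N) expected spike counts per bin
    spikes = gen.poisson(rates)                 # inhomogeneous Poisson counts
    return x, rates, spikes

x_toy, rates_toy, spikes_toy = sample_plds()
print(x_toy.shape, rates_toy.shape, spikes_toy.shape)  # (20, 40, 3) (20, 40, 10) (20, 40, 10)
```

Fitting the model amounts to running this process in reverse. Given `spikes`, we search for the $W$, $\mathbf{v}_j$, $U$ and $M$ under which the observed counts are most likely.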

For a Poisson process, the numbers of spikes in disjoint time intervals are independent. If we further assume that neurons and trials are conditionally independent given the estimated rates, the likelihood of the whole dataset is the product of the probabilities of each bin,


\begin{aligned}
p(n | \mathbf{\hat y}):&=\prod_{i=1}^{K}\prod_{j=1}^{T}\prod_{k=1}^{N} p(n^{(i)}_{jk}|\mathbf{\hat y}^{(i)}_{jk})
\\&=\prod_{i=1}^{K}\prod_{j=1}^{T}\prod_{k=1}^{N} e^{-\mathbf{\hat y}^{(i)}_{jk}} \frac{(\mathbf{\hat y}^{(i)}_{jk})^{n^{(i)}_{jk}}}{n^{(i)}_{jk}!}
\end{aligned}

where $n^{(i)}_{jk}$ is the number of spikes observed in the $j$th bin of the $i$th trial for neuron $k$. The corresponding estimated rate $\mathbf{\hat y}^{(i)}_{jk}$ is expressed as an expected count per bin (we might for example evaluate $\mathbf{\hat y}$ at the midpoint of the bin). Optimizing directly on probabilities is impractical: since each $p(n^{(i)}_{jk}|\mathbf{\hat y}^{(i)}_{jk})$ lies in $[0,1]$, their product over thousands of bins is minuscule and underflows numerically. We instead take minus the log of this probability, which turns the product into a sum, and divide by the number of trials and time points,

\begin{aligned}
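L(W,\mathbf{v}_j, \mathbf{u}^{(i)}, M) &= \frac{1}{KT}\sum_{i=1}^{K}\sum_{j=1}^{T}\sum_{k=1}^{N} -\log \left( e^{-\mathbf{\hat y}^{(i)}_{jk}} \frac{(\mathbf{\hat y}^{(i)}_{jk})^{n^{(i)}_{jk}}}{n^{(i)}_{jk}!} \right)\\
&= \frac{1}{KT}\sum_{i=1}^{K}\sum_{j=1}^{T}\sum_{k=1}^{N} \left( \mathbf{\hat y}^{(i)}_{jk} - n^{(i)}_{jk}\log \mathbf{\hat y}^{(i)}_{jk} + \log n^{(i)}_{jk}! \right)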
\end{aligned}

Since $-\log$ is strictly decreasing, maximizing $p(n | \mathbf{\hat y})$ is equivalent to minimizing $L$. Various algorithms can perform this minimization. The simplest is gradient descent, which can be implemented easily with machine learning libraries such as PyTorch, JAX or TensorFlow.
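The sketch below illustrates the loss on a toy problem in JAX. It first checks the expanded Poisson negative log-likelihood $\hat y - n\log\hat y + \log n!$ against `jax.scipy.stats.poisson.logpmf`. It then runs plain gradient descent on a single constant rate. Setting $\partial L/\partial \hat y = 1 - \bar n/\hat y = 0$ shows the maximum-likelihood rate is the mean count $\bar n$, so the fit should recover it. The counts are made up. The lab's own model uses a softplus link and the AdamW optimizer from `optax` rather than this hand-written loop.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

def poisson_nll(rate, counts):
    """-log p(n | rate) = rate - n log(rate) + log(n!), with log(n!) = gammaln(n + 1)."""
    return rate - counts * jnp.log(rate) + gammaln(counts + 1.0)

counts = jnp.array([0.0, 2.0, 1.0, 0.0, 3.0, 1.0])

# 1) The expanded formula agrees with JAX's Poisson log-pmf
manual = poisson_nll(0.8, counts)
reference = -poisson.logpmf(counts, 0.8)
print(jnp.allclose(manual, reference, atol=1e-5))   # True

# 2) Gradient descent on the mean NLL of one constant rate,
#    parametrised as exp(log_rate) so that the rate stays positive
def mean_nll(log_rate):
    return poisson_nll(jnp.exp(log_rate), counts).mean()

grad_fn = jax.jit(jax.grad(mean_nll))
log_rate = 0.0
for _ in range(500):
    log_rate = log_rate - 0.1 * grad_fn(log_rate)

print(float(jnp.exp(log_rate)), float(counts.mean()))  # both close to 7/6
```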

---

This model might at first seem complex, but each of its elements helps extract what is needed to better understand the data. First, it maps an arbitrarily large neural population onto a small set of *latent* variables that explain it well. Second, it describes a large number of trials with a small set of basis trials. Finally, it switches from discrete variables (spikes) to continuous ones. These features allow exploring different aspects of the data, as we will see below.


**Optional asides.**  


---

*Tensors*


The tensor $[e^{Wt}\mathbf{v}_1, e^{Wt}\mathbf{v}_2, ...]\in \mathbb{R}^{r \times m \times T}$ is the core tensor of a [Tucker decomposition](https://en.wikipedia.org/wiki/Tucker_decomposition) of the estimated firing rate before the exponential is applied. The matrices $M$ and $U$ are maps,

$$
\mathbb{R}^{r \times m \times T}\xrightarrow{M}\mathbb{R}^{r \times N \times T}\xrightarrow{U}\mathbb{R}^{K \times N \times T}\xrightarrow{\text{exp}}\mathbb{R}^{K \times N \times T}_{+},
$$

Or informally,

$$
\\e^{Wt}\mathbf{v}_j\xrightarrow{M}Me^{Wt}\mathbf{v}_j\xrightarrow{U,\text{exp}}\mathbf{y}^{(i)}_{j},
$$

which commute by multilinearity,

$$
\mathbb{R}^{r \times m \times T}\xrightarrow{U}\mathbb{R}^{K \times m \times T}\xrightarrow{M}\mathbb{R}^{K \times N \times T}\xrightarrow{\text{exp}}\mathbb{R}^{K \times N \times T}_{+}.
$$

$$
\\e^{Wt}\mathbf{v}_j\xrightarrow{U}\mathbf{x}^{(i)}_j\xrightarrow{M,\text{exp}}\mathbf{y}^{(i)}_j,
$$

This implies that we can apply either $M$ or $U$ first. Applying $U$ first gives the latent neural activity $\mathbf{x}^{(i)}(t)$; applying $M$ first gives the *latent trial activity* $Me^{Wt}\mathbf{v}_j$. Low dimensional latent neuron activity usually better describes sensory areas, while low dimensional trial activity usually better describes motor areas [3]. This model, which allows flexibly varying the dimension of the two, is therefore particularly well suited to this sensory-motor dataset.
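The commutation claim can be checked numerically with `np.einsum`. Applying $M$ then $U$ to the core tensor gives the same pre-exponential tensor as applying $U$ then $M$. In the sketch below, random arrays of arbitrary size stand in for the fitted quantities, and the `_toy` names are illustrative.

```python
import numpy as np

gen = np.random.default_rng(4)
K_toy, r_toy, m_toy, N_toy, T_toy = 6, 3, 4, 5, 7
core = gen.normal(size=(r_toy, m_toy, T_toy))   # stands in for [e^{Wt}v_1, ..., e^{Wt}v_r]
M_toy = gen.normal(size=(N_toy, m_toy))
U_toy = gen.normal(size=(K_toy, r_toy))

# Route 1: M first (latent trial activity M e^{Wt} v_j), then U
trial_basis = np.einsum('na,jat->jnt', M_toy, core)        # (r, N, T)
route_1 = np.einsum('kj,jnt->knt', U_toy, trial_basis)     # (K, N, T)

# Route 2: U first (latent neural activity x^{(i)}(t)), then M
latents = np.einsum('kj,jat->kat', U_toy, core)            # (K, m, T)
route_2 = np.einsum('na,kat->knt', M_toy, latents)         # (K, N, T)

print(np.allclose(route_1, route_2))   # True
rates_from_tensor = np.exp(route_1)    # the exponential is applied only at the end
```

Only the two linear maps commute. The exponential is non-linear, so it has to stay last in both routes.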

---
*Gaussian observations*


The model considered above more usually includes a *Gaussian observation* such that $\ln (\mathbf{y}^{(i)}(t))\sim \mathcal{N}(M\mathbf{x}^{(i)}(t), \boldsymbol{\Sigma})$ (see [2]). Fitting such a model to data requires the [expectation-maximization algorithm](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) or variational methods. However, these algorithms can be computationally expensive and can make convergence more difficult. We therefore do not consider this case here.

## Fitting 

To fit the model to neural data we use [Jax](https://jax.readthedocs.io/). It is a library that provides a NumPy-like API (`jax.numpy`) and can automatically compute gradients of functions written with it.

In [None]:
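import jax.scipy.linalg
from jax.scipy.special import gammaln, xlogy

@jax.jit
def lds(params):
  #Note: uses the global evaluation times `ts` (defined in the fitting cell below)

  Wt = jnp.einsum('i,jk->ijk', ts, params['W']) #W*t, time x latent x latent

  Wt_exp = jax.vmap(jax.scipy.linalg.expm)(Wt) #e^(W*t), time x latent x latent

  #x[i,t] = sum_j U[i,j] * (V[j] @ e^(W*t)), trial x time x latent
  #(this applies e^(W^T t) to each v_j; the transpose is immaterial since W is learned)
  x = jnp.einsum('ij,jl,mlo->imo', params['U'], params['V'], Wt_exp) 

  return x

def poisson_log_likelihood(rates, spikes):
  #rates are expected spike counts per 10ms bin
  #log p(n | rate) = n*log(rate) - rate - log(n!), computed in log space to avoid underflow
  return xlogy(spikes, rates) - rates - gammaln(spikes + 1.0)

def L(params, y):

  #softplus(z) = log(1 + e^z) keeps the rates positive
  npll = -poisson_log_likelihood(jax.nn.softplus(lds(params) @ params['M'].T), y)

  #average over a random half of the trials (a new subset at every call, since L is not jitted)
  mean_loss = npll[np.random.rand(trial_dimension)<0.5].mean()

  return mean_loss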

In [None]:
def poisson_LDS_training(y, params, ts, opt_fn, opt_state, steps=100):

    losses = []
    for step in range(steps):
        loss, grads = jax.value_and_grad(L)(params, y)
        updates, opt_state = opt_fn(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        losses.append(loss)

        if step%100==0:
          print('SGD step:', step, ' loss:', loss)

    return jnp.stack(losses), params, np.array(lds(params)), opt_state

In [None]:
#Hyperparameters
m = 30
r = 8

#We discard -100ms to 100ms as there is a lag in neural activity
new_time_dimension = time_dimension - 20 
y = neural_data[:,20:]

#Duration of evaluation of the LDS
ts = jnp.linspace(0,new_time_dimension/10,new_time_dimension)

#Jax requires a key (seed) for every random variable; use a distinct key per parameter
rng = random.PRNGKey(0)
keys = random.split(rng, 4)

#Optimized parameters
W = random.normal(keys[0], (m, m))/np.sqrt(m) - jnp.eye(m)*3/np.sqrt(m)
U = random.normal(keys[1], (trial_dimension, r))/np.sqrt(r)
V = random.normal(keys[2], (r, m))/np.sqrt(m)
M = random.normal(keys[3], (neuron_dimension, m))/np.sqrt(m)

params = {'W' : W, 'U' : U, 'V' : V, 'M' : M}

adam = optax.adamw(learning_rate=0.01)

losses, params, x, _ = poisson_LDS_training(y, params, ts, steps=3000, 
                          opt_fn=adam.update, opt_state=adam.init(params))

clear_output()
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot()
ax.plot(losses)
ax.set_title('Loss')
ax.set_xlabel('iteration')
ax.set_ylabel('Poisson negative log likelihood')
plt.show()

## Hand decoding

A really exciting aspect of capturing neural activity with latent variables is that these latent variables are often correlated with behavior. Here, we illustrate this by linearly mapping the latent activity onto hand movement. Writing $\mathbf{x}^{(i)}\in\mathbb{R}^{T\times m}$ for the latent trajectory of trial $i$, we find the matrix $A\in \mathbb{R}^{m \times 2}$ such that,

$$
\sum_{i=1}^K \lVert\mathbf{x}^{(i)}A - h^{(i)}\rVert_F^2
$$

is minimized, where $h^{(i)}\in \mathbb{R}^{T\times 2}$ is the hand position over time at trial $i$.
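Stacking all trials into $X\in\mathbb{R}^{KT\times m}$ and $H\in\mathbb{R}^{KT\times 2}$ gives this least-squares problem the closed-form solution $A=(X^\top X)^{-1}X^\top H$. The `regularization` argument in the cell below suggests that the `hand_decoding` library adds a ridge penalty. We have not checked its internals, so what follows is only a sketch of ridge regression, $A=(X^\top X+\lambda I)^{-1}X^\top H$, on synthetic data. The helper `fit_linear_decoder` is hypothetical, has no intercept term, and reports a pooled $R^2$ rather than the library's trial-wise one.

```python
import numpy as np

def fit_linear_decoder(latents, hand, lam=1.0):
    """Ridge regression A = (X^T X + lam I)^{-1} X^T H, pooling all trials and time bins.

    latents: (K, T, m) latent trajectories; hand: (K, T, 2) hand positions.
    """
    X = latents.reshape(-1, latents.shape[-1])      # (K*T, m)
    H = hand.reshape(-1, hand.shape[-1])            # (K*T, 2)
    gram = X.T @ X + lam * np.eye(X.shape[1])
    return np.linalg.solve(gram, X.T @ H)           # (m, 2)

# Synthetic check: hand positions generated from known weights plus a little noise
gen = np.random.default_rng(5)
latents_toy = gen.normal(size=(50, 40, 8))
A_true = gen.normal(size=(8, 2))
hand_toy = latents_toy @ A_true + 0.1 * gen.normal(size=(50, 40, 2))

A_hat = fit_linear_decoder(latents_toy, hand_toy, lam=1.0)
residual = hand_toy - latents_toy @ A_hat
pooled_r2 = 1 - (residual ** 2).sum() / ((hand_toy - hand_toy.mean(axis=(0, 1))) ** 2).sum()
print(A_hat.shape, pooled_r2)   # (8, 2) and an R^2 close to 1
```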

In [None]:
#note: the hand-decoding library takes time x trial x neuron tensors
#you may interchange velocity_decoding with position_decoding
decoded_hand_pos = hand_decoding.velocity_decoding(x.transpose(1,0,2), 
    hand_movement[:,10:-10].transpose(1,0,2), regularization=1)[0].transpose(1,0,2)


print('R^2:', hand_decoding.trial_wise_r2(hand_movement[:,10:-10].transpose(1,0,2), 
                                          decoded_hand_pos.transpose(1,0,2))[0])

fig, axes = plt.subplots(1, 2, figsize=(8,4), tight_layout=True)

for id, pos in enumerate(decoded_hand_pos):
  axes[0].plot(pos[:,0], pos[:,1], 
      c=cmap(angle[id]/360), alpha=0.5, linewidth=1)

for id, pos in enumerate(hand_movement[:,:-10]):
  axes[1].plot(pos[:,0], pos[:,1], 
      c=cmap(angle[id]/360), alpha=0.5, linewidth=1)

#Plotting stuff
axes[0].set_xlabel('x (cm)'), axes[0].set_ylabel('y (cm)')
axes[0].set_title('Estimated hand movement')

axes[1].set_xlabel('x (cm)')
axes[1].set_title('True hand movement')

plt.setp(axes,  xlim=(np.min(hand_movement[:,:,0]*1.1), np.max(hand_movement[:,:,0]*1.1)),
                ylim=(np.min(hand_movement[:,:,1]*1.1), np.max(hand_movement[:,:,1]*1.1)))

plt.show()


Interesting, right? In fact, given sufficient data (not much more than this dataset), latent variables that generalize across recording sessions or even animals can be fitted. Only $U$ and $M$ then need to be refit for new trials or neurons. This is much easier, since they act only linearly on the latent variables: with $W$ and the $\mathbf{v}_j$ held fixed, each refit is a convex optimization problem.

*  What does this suggest regarding the role of area A2 in kinematics of the hand?

*  What stronger experiment might you do to probe the relationship between hand kinematics and area A2?

*  How might you improve on this model? (e.g. with respect to the variable timing of movement you saw earlier)

*  What might be a practical application of such a model?

## Latent activity

We now wish to understand *how* the model implements this trial-to-trial variability that seems to be correlated with behavior.

To get a sense of what the latent activity looks like, we plot the first three latent variables for all trials, that is $\mathbf{x}^{(i)}_j(t)$ for $i\in[364]$, $j\in [3]$. We color each curve according to the angle of the bump in that trial.

In [None]:
number_of_latents_plotted = 3

fig, axes = plt.subplots(1, number_of_latents_plotted, 
              figsize=(1+3*number_of_latents_plotted,number_of_latents_plotted),
              constrained_layout=True)

#time of each fitted bin in ms (a new name, so the global `ts` used by `lds` is not overwritten)
ts_ms = np.arange(100,500,10)

for j in range(number_of_latents_plotted):
  axes[j].set_title(r'$\mathbf{x}^{(i)}_{'+str(j+1)+'}$')
  for id, xs in enumerate(x):
    axes[j].plot(ts_ms, xs[:,j], c=cmap(angle[id]/360), alpha=0.4, linewidth=1)
  axes[j].set_xlabel('time (ms)')

fig.suptitle('Latent variables')
plt.show()

*  What do you observe? (hint: colors)
*  How might you check that the trial-to-trial variability for a given bump direction is lower than the bump direction-to-bump direction variability?


## PCA

However, looking at the latent variables one at a time is quite arbitrary. Instead, we can project them onto a 3-dimensional space. A common choice is the projection that maximizes the variance of the data it captures or, equivalently, the one that minimizes the error when projecting back to the full space. That is, we look for $V\in \mathbb{R}^{m\times 3}$ with orthonormal columns such that,

$$
\sum_{i=1}^K\sum_{t}\lVert\mathbf{x}^{(i)}(t)-VV^{\top}\mathbf{x}^{(i)}(t)\rVert^2
$$

is minimized. This is called [principal component analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis).
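The equivalence between maximizing variance and minimizing reconstruction error can be seen with an SVD of the centered data. The sketch below works on synthetic latents and centers them first, as PCA usually does. It computes the top-3 principal axes (called `P` here to avoid a clash with the notebook's `V`) and the fraction of variance they explain. It then checks that the reconstruction error equals the variance in the discarded components. All names are illustrative.

```python
import numpy as np
import scipy.linalg

gen = np.random.default_rng(6)
n_dims, n_pcs = 10, 3
# Correlated synthetic latents, shape (trial, time, latent)
latents_toy = gen.normal(size=(30, 40, n_dims)) @ gen.normal(size=(n_dims, n_dims))

X = latents_toy.reshape(-1, n_dims)
Xc = X - X.mean(axis=0)                                   # center each latent dimension
_, sing_vals, Vh_toy = scipy.linalg.svd(Xc, full_matrices=False)
P = Vh_toy[:n_pcs].T                                      # (m, 3), orthonormal columns

explained = (sing_vals[:n_pcs] ** 2).sum() / (sing_vals ** 2).sum()   # fraction of variance
recon_error = ((Xc - Xc @ P @ P.T) ** 2).sum()           # sum over samples of ||x - P P^T x||^2

print(explained)                                                # between 0 and 1
print(np.isclose(recon_error, (sing_vals[n_pcs:] ** 2).sum()))  # True
```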

In [None]:
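#PCA: SVD of the mean-centered latents, pooled over trials and time
x_mean = x.reshape(-1,m).mean(axis=0)
_, S, Vh = scipy.linalg.svd(x.reshape(-1,m) - x_mean, full_matrices=False)

x_on_pc = (x - x_mean) @ Vh[:3].T #projection onto the first 3 principal components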

fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot(projection='3d')

for id, xs in enumerate(x_on_pc):
    ax.plot(xs[:,0], xs[:,1], xs[:,2], c=cmap(angle[id]/360), alpha=0.4, linewidth=1.2)

ax.set_xlabel('PC1'), ax.set_ylabel('PC2'), ax.set_zlabel('PC3')
plt.show()

*  Compared to looking at a single latent variable at a time, what do you observe here?

*  Both $A$ derived earlier and the $V$ considered here are simply linear projections. How do their objectives differ:
  *  from a mathematical standpoint?
  *  from a neuroscientific standpoint?


## Initial condition

Trajectories of dynamical systems with no input depend only on their initial condition. Therefore, if the model's trajectories appear specific to bump direction, we should also expect the initial conditions to be. To explore this, we plot the first three columns of $U$, with trials sorted by bump direction. As a reminder, $U$ is of shape $K \times r$, where $K$ is the number of trials and $r$ the number of basis initial conditions.

In [None]:
number_of_Us_plotted = 3

U = np.array(params['U']) #Jax to np

fig, axes = plt.subplots(number_of_Us_plotted, 1,
              figsize=(4, 2*number_of_Us_plotted),
              constrained_layout=True)

sorted_id = np.argsort(angle)
sorted_angle = angle[sorted_id]
trial_rank = np.arange(len(sorted_angle)) #position of each trial once sorted by bump direction

U_sorted = U[sorted_id]

for i in range(number_of_Us_plotted):
  for a in np.unique(angle):
    axes[i].scatter(trial_rank[sorted_angle==a], U_sorted[sorted_angle==a,i],
                    color=cmap(a/360), s=10)

  axes[i].set_ylabel('$U_'+str(i)+'$')
      
axes[0].set_title('Coefficient of basis initial conditions')
axes[-1].set_xlabel('trial (sorted by bump direction)')

plt.show()

*  What do you observe? 
*  How does this compare to the tuning of single neurons that you explored earlier?

## Analytical standpoint

Finally, we may wish to gain insight into the fitted model by describing the dynamical system analytically. There are many approaches one could take; here we take the simplest one and look at the eigenvalues of $W$ before and after training.
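To build intuition for the eigenvalue plot, consider a toy $2\times 2$ matrix `W_toy` (a hypothetical example, not the fitted $W$) with a complex-conjugate pair of eigenvalues $a \pm ib$. The corresponding mode of $e^{tW}$ decays (if $a<0$) or grows as $e^{at}$, and rotates with angular frequency $b$. In the fitting cell, `ts = jnp.linspace(0, 4, 40)` spans the 40 fitted 10 ms bins. By our reading, one model time unit therefore corresponds to roughly 100 ms.

```python
import numpy as np
import scipy.linalg

# A 2x2 "decaying rotation": W = a*I + b*J, with J a 90-degree rotation
a, b = -0.5, 2.0          # illustrative decay rate and angular frequency (1 / model time unit)
W_toy = np.array([[a, -b],
                  [b,  a]])

eigvals = np.linalg.eigvals(W_toy)          # a + ib and a - ib
tau = -1.0 / eigvals.real                   # decay time constant of each mode, 1/|a| = 2
period = 2 * np.pi / np.abs(eigvals.imag)   # rotation period of each mode, 2*pi/b

# e^{tW} x0 shrinks by e^{at} while rotating by an angle b*t
x0_toy = np.array([1.0, 0.0])
t = 1.3
xt = scipy.linalg.expm(t * W_toy) @ x0_toy
expected = np.exp(a * t) * np.array([np.cos(b * t), np.sin(b * t)])

print(eigvals, tau, period)
print(np.allclose(xt, expected))     # True
```

Eigenvalues with real part close to zero therefore correspond to slowly decaying, long-lasting modes. A large imaginary part indicates fast oscillations.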

In [None]:
fig = plt.figure(figsize=(4,4), constrained_layout=True)
ax = fig.add_subplot()

L_W, _ = scipy.linalg.eig(W)
ax.scatter(L_W.real, L_W.imag, label='pre-training')

L_W, _ = scipy.linalg.eig(params['W'])
ax.scatter(L_W.real, L_W.imag, label='post-training')

ax.set_xlabel('Re'), ax.set_ylabel('Im')
ax.legend()
ax.axvline(0, color='black', linewidth=0.8)
ax.axhline(0, color='black', linewidth=0.8)

plt.show()

*  Pick a few eigenvalues of interest and comment on their meaning. 1) From a mathematical standpoint, and 2) from the standpoint of the population of neurons the model is trying to capture.

# Conclusion

We started with the hypothesis that phenomenological models typically applied to recordings of the motor cortex could be repurposed for area A2. This was motivated by the finding of [1] that kinematics might be represented in the activity of area A2. As a first step, we explored the behavioral and neural data to better understand them and to decide whether our model of choice, the PoissonLDS, was relevant at all. Then, we fitted that model to the neural data and observed that it captured variability in the data that was relevant to behavior. Indeed, hand movements could be decoded relatively well from the model's trajectories with a linear map, and the initial condition of the model seemed tuned to bump direction. Finally, we hinted at the fact that the model, unlike the data, can be analyzed *mathematically*. These steps, starting from our original hypothesis, give a glimpse into research involving neural recordings, should you decide to go down that path.

# References

[1] Chowdhury, R. H., Glaser, J. I., & Miller, L. E. (2020). Area 2 of primary somatosensory cortex encodes kinematics of the whole arm. *eLife*, 9.

[2] Macke, J. H., Buesing, L., Cunningham, J. P., Yu, B. M., Shenoy, K. V., & Sahani, M. (2011). Empirical models of spiking in neural populations. *Advances in Neural Information Processing Systems*, 24.

[3] Seely, J. S., Kaufman, M. T., Ryu, S. I., Shenoy, K. V., Cunningham, J. P., & Churchland, M. M. (2016). Tensor analysis reveals distinct population structure that parallels the different computational roles of areas M1 and V1. *PLoS Computational Biology*, 12(11), e1005164.