# part0: imports

In [1]:
import os, sys, pathlib
from pprint import pprint 
from importlib import reload
import logging
logging.basicConfig(level=logging.ERROR)
import warnings
warnings.simplefilter("ignore")


import pandas as pd
import numpy as np
import xarray as xr
from sklearn.decomposition import PCA
import scipy.linalg as linalg

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib
from matplotlib.ticker import MaxNLocator


from tools import utilityTools as utility
from tools import dataTools as dt
import pyaldata as pyal

%matplotlib inline

root = pathlib.Path("/data")

# Compare different epochs

the idea is to see whether canonical axes between 2 animals provide a higher VAF for time epochs in the trial that they have not been trained on, compared to, for example, M1-PMd axes in a single animal.

Check the VAF for a single session as a test

In [2]:
animal= 'Chewie'
fname = root / animal / "Chewie_CO_CS_2016-10-21.mat"

df = pyal.mat2dataframe(fname, shift_idx_fields=True)

preprocessing

In [3]:
def get_target_id(trial):
    return int(np.round((trial.target_direction + np.pi) / (0.25*np.pi))) - 1

df = pyal.select_trials(df, df.result== 'R')
df = pyal.select_trials(df, df.epoch=='BL')
df = pyal.remove_low_firing_neurons(df, "M1_spikes", 1)
df = pyal.remove_low_firing_neurons(df, "PMd_spikes", 1)

df = pyal.add_firing_rates(df, 'smooth')

df["target_id"] = df.apply(get_target_id, axis=1)  # add a field `target_id` with int values



applying PCA for 300ms post movement onset

In [4]:
df_ = pyal.restrict_to_interval(df, start_point_name='idx_movement_on', rel_start=0, rel_end=30)

M1_rates = np.concatenate(df_.M1_rates.values, axis=0)
M1_rates -=np.mean(M1_rates,axis=0)

M1_model = PCA(n_components=20, svd_solver='full');
M1_model.fit(M1_rates);
df_ = pyal.apply_dim_reduce_model(df_, M1_model, 'M1_rates', 'M1_pca');


PMd_rates = np.concatenate(df_.PMd_rates.values, axis=0)
PMd_rates -=np.mean(PMd_rates,axis=0)

PMd_model = PCA(n_components=20, svd_solver='full');
PMd_model.fit(PMd_rates);
df_ = pyal.apply_dim_reduce_model(df_, PMd_model, 'PMd_rates', 'PMd_pca');

to make life easy, just limit everything to 1 target

In [5]:
df_= pyal.select_trials(df_, df_.target_direction ==0)

CCA on m1-pmd

In [6]:
d0 = np.concatenate(df_['M1_pca'].values, axis=0)
d1 = np.concatenate(df_['PMd_pca'].values, axis=0)

n_samples = min ([d0.shape[0], d1.shape[0]])
d0 = d0[:n_samples,:]
d1 = d1[:n_samples,:]

A, B, r, _, _ = dt.canoncorr(d0, d1, fullReturn=True)

In [7]:
print(f'the CCs are:{r}')

the CCs are:[0.94779339 0.87838562 0.83821908 0.73555619 0.70545989 0.66681631
 0.61031016 0.5606342  0.51480629 0.46384192 0.42750414 0.42011013
 0.32919582 0.29094555 0.28618524 0.23505218 0.22347631 0.11351897
 0.10931743 0.06073484]


---

From Bence's code [here](https://github.com/BeNeuroLab/pysubspaces/blob/4b7c7512ccdff8fdbd1334cb6c3acc8f5bdbf3d6/pysubspaces/subspaces.py#L29).

In [8]:
def variance_in_subspace(X, W):
    """
    Variance in a given subspace
    Parameters
    ----------
    X : 2D np.ndarray
        n_samples x n_features data array
    W : 2D np.ndarray
        n_components x n_features projection matrix
    Returns
    -------
    variance in the subspace
    """
    return np.trace(W @ np.cov(X.T) @ W.T)

In [9]:
variance_in_subspace (d0, A.T), variance_in_subspace (d1, B.T)

(20.000000000000004, 20.000000000000004)

weirdly specific?! but what is the total variance? what does it mean?
- can I shuffle the weights in the prrojection matrix and repeat mulriple times to estimate a band of VAF by chance?
    - maybe, check the matlab code

---

Based on the Nat Comm paper and some playing: 
$$\%VAF=\frac{norm(X)-norm(X-XC^TA^T(A^TA)^{-1}AC)}{norm(X)}$$
where:
- $A$ is the CCA output, projection matrix
- $C$ is the `PCA_model.components_`
- $X$ is the data matrix, $T\times n$ with $T$ time points and $n$ neurons, and each neuron has zero mean
- $norm$ is sum of squared elements

In [10]:
C = M1_model.components_
X = M1_rates
norm = lambda m:np.sum(m**2)

VAF = norm(X) - norm(X-X@C.T@A.T@linalg.inv(A.T@A)@A@C)
VAF /= norm(X)

VAF

-12.804745925211172