In [1]:
# Standard libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, IntProgress
from IPython.display import display

# Append base directory
import os,sys,inspect
rootname = "chernysheva-tmaze-analysis-2020"
thispath = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
sys.path.append(rootpath)
print("Appended root directory", rootpath)

from scipy.stats import mannwhitneyu, binom
def rstest_twosided(x, y):
    """Two-sided Mann-Whitney U test between the samples ``x`` and ``y``.

    Thin wrapper around ``scipy.stats.mannwhitneyu`` with ``alternative='two-sided'``
    fixed. Returns whatever ``mannwhitneyu`` returns (test statistic and p-value).
    """
    return mannwhitneyu(x, y, alternative='two-sided')

# User libraries
from mesostat.metric.metric_non_uniform import MetricCalculatorNonUniform
from mesostat.stat.permtests import difference_test
from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.utils.arrays import numpy_merge_dimensions
from mesostat.utils.pandas_helper import outer_product_df, merge_df_from_dict

from src.lib.data_db import BehaviouralNeuronalDatabase
import src.lib.plots_lib as plots_lib
from src.lib.plots_pca  import PCAPlots
from src.lib.metric_wrapper import metric_by_selector
from src.lib.extra_metrics import num_non_zero_std, num_sample
from src.lib.significant_cells_lib import SignificantCells


%load_ext autoreload
%autoreload 2

Appended root directory /home/alyosha/work/git/chernysheva-tmaze-analysis-2020


In [2]:
# Configuration: locations of the DFF and deconvolved data, passed to the database below.
# NOTE(review): the paths are absolute and machine-specific; adjust them (or re-enable
# the GUI path picker) before running on another machine.
# tmp_path = root_path_data if 'root_path_data' in locals() else "./"
params = {}
#params['root_path_data']  = gui_fpath("Path to data files", "./")
params['root_path_dff'] = '/media/alyosha/Data/TE_data/mariadata/dff/'
params['root_path_deconv'] = '/media/alyosha/Data/TE_data/mariadata/deconv/'
# Alternative data location (different user/machine):
# params['root_path_dff'] = '/media/aleksejs/DataHDD/work/data/maria/dff/'
# params['root_path_deconv'] = '/media/aleksejs/DataHDD/work/data/maria/deconv/'

In [3]:
# Create the database object from the data paths configured in `params`
dataDB = BehaviouralNeuronalDatabase(params)

In [4]:
# Read the neuro files; the progress bars in the output report DFF and DECONV data being read
dataDB.read_neuro_files()

IntProgress(value=0, description='Read DFF Data:', max=27)

IntProgress(value=0, description='Read DECONV Data:', max=27)

In [None]:
# Read the behaviour files; trial types without any trials are reported in the output and skipped
dataDB.read_behavior_files()

IntProgress(value=0, description='Read Neuro Data:', max=27)

No trials found for Trial_LWhole_Mistake skipping
No trials found for Trial_RWhole_Mistake skipping
No trials found for Trial_LWhole_Mistake skipping


In [None]:
# For every datatype, build a mapping {filter name -> per-mouse significant cells}.
# The 'None' entry stands for "no cell filtering": the analysis functions below skip
# the channel selection when the value is None.
# NOTE(review): the .h5 file names are relative, so the files are presumably expected
# in the current working directory — confirm.
significantCellsSelectorDatatype = {}

for datatype in ['raw', 'deconv']:
    # Significant cells by mouse, loaded from the "_mt" and "_enc_reward" files respectively
    signCellsMaintenance = SignificantCells('significant_cells_'+datatype+'_mt.h5').get_cells_by_mouse()
    signCellsReward = SignificantCells('significant_cells_'+datatype+'_enc_reward.h5').get_cells_by_mouse()

    significantCellsSelectorDatatype[datatype] = {
        'None' : None,
        'Maintenance' : signCellsMaintenance,
        'Reward' : signCellsReward
    }

# 1. Clustering

Main idea:
* Select and cluster datapoints
* Color clustering by modality


Analysis Strategies:
1. Select 1 interval, color by modality $LR\otimes CM$
2. Select modality, color by interval

Dynamics Strategies:
1. Static-noob-1: One point per timestep. Bad because biased towards trials with more timesteps, slower transients
    * Nothing to be seen on TSNE
2. Static-noob-2: Mean value per trial. Balanced. Probably averages out important info
    * Nothing to be seen on TSNE
3. Dynamic-slow-legendre. Multiplex channels and legendre temporal basis. Good for stretch hypothesis
4. Dynamic-noob-1: Window-sweep time, multiplex onto channels. Bias+too-many-dimensions
5. Dynamic-noob-2: As above + preprocess dim.reduction. Bias

In [None]:
from sklearn import manifold, datasets
from sklearn.decomposition import PCA

def fit_color_data(ax, X, labels, methodName, param):
    """Embed ``X`` in 2D and plot the embedding on ``ax`` with one colour per label.

    Parameters
    ----------
    ax : matplotlib axes to draw on
    X : 2D array with one row per data point
    labels : sequence with one label per row of ``X``
    methodName : 'tsne' or 'pca'
    param : perplexity handed to t-SNE; not used for PCA

    Raises
    ------
    ValueError
        If ``methodName`` is neither 'tsne' nor 'pca'.
    """
    if methodName == 'tsne':
        embedder = manifold.TSNE(n_components=2, init='pca', perplexity=param)
    elif methodName == 'pca':
        embedder = PCA(n_components=2)
    else:
        # An unknown name used to fall through and fail later on an unbound `Y`
        raise ValueError("Unknown embedding method " + repr(methodName) + ", expected 'tsne' or 'pca'")

    Y = embedder.fit_transform(X)

    # Coerce to an array so that `labels == label` is an elementwise mask even for plain lists
    labels = np.asarray(labels)
    for label in sorted(set(labels)):
        idxs = labels == label
        ax.plot(Y[idxs, 0], Y[idxs, 1], 'o', label=label)
    ax.legend()

## 1.1. Interval -> Modalities

In [None]:
def tsne_interval_by_behaviour(datatype, selector, methodName, param, temporalStrat='raw', signCellsSelector=None):
    """For every mouse, embed the data of one selector in 2D and colour it by behaviour.

    Data are queried from the global ``dataDB`` for all combinations of performance
    ('Correct', 'Mistake') and direction ('L', 'R'); each combination gets its own
    colour in the plot. One panel is drawn per mouse, and the figure is saved as a
    PDF in the current working directory.

    Parameters
    ----------
    datatype : value of the 'datatype' field of the database query
    selector : single-entry dict, e.g. ``{'interval' : 6}``, handed to ``dataDB.get_data_from_selector``
    methodName : embedding method, see ``fit_color_data``
    param : embedding parameter, see ``fit_color_data``
    temporalStrat : how time is handled within a trial
        'avg'      - one point per trial, averaged over axis 1 of the trial data
        'legendre' - temporal-basis coefficients (order 5) merged into the channel dimension
        otherwise  - every timestep of every trial is a separate point
    signCellsSelector : optional single-entry dict ``{filterName : {mousename : cell indices}}``;
        ``None`` (or a ``None`` value) disables cell filtering. Only the first entry is used.
    """
    mc = MetricCalculatorNonUniform(serial=True)

    if signCellsSelector is None:
        signCellsSelector = {'None' : None}

    signCellsName, signCellsMouseDict = list(signCellsSelector.items())[0]

    mice = sorted(dataDB.mice)
    nMice = len(mice)

    # squeeze=False keeps `ax` two-dimensional, so indexing also works with a single mouse
    fig, ax = plt.subplots(ncols=nMice, figsize=(6*nMice, 6), squeeze=False)
    for iMouse, mousename in enumerate(mice):
        axMouse = ax[0, iMouse]
        axMouse.set_title(mousename)

        labelsBigLst = []
        dataBigLst = []
        for performance in ['Correct', 'Mistake']:
            for direction in ['L', 'R']:
                queryDict = {'datatype' : datatype, 'mousename' : mousename, 'performance' : performance, 'direction' : direction}

                dataLst = dataDB.get_data_from_selector(selector, queryDict)

                # Some behaviours have no trials at all; stacking an empty list would raise
                if len(dataLst) == 0:
                    continue

                if signCellsMouseDict is not None:
                    idxCells = signCellsMouseDict[mousename]
                    dataLst = [d[idxCells] for d in dataLst]

                if temporalStrat == 'avg':
                    # Average over timesteps, ignoring trials without any timesteps
                    dataArr = np.array([np.mean(d, axis=1) for d in dataLst if d.shape[1] > 0])
                elif temporalStrat == 'legendre':
                    # Estimate legendre basis functions, add them as extra dimensions
                    mc.set_data(dataLst)
                    dataArr = mc.metric3D("temporal_basis", "rp", metricSettings={"basisOrder": 5})
                    dataArr = numpy_merge_dimensions(dataArr, 1, 3)
                else:
                    # Consider timesteps and trials equivalent
                    dataArr = np.hstack(dataLst).T

                # An empty block has the wrong dimensionality for vstack and carries no points
                if len(dataArr) == 0:
                    continue

                dataBigLst += [dataArr]
                labelsBigLst += [str((direction, performance))]*len(dataArr)

        if len(dataBigLst) == 0:
            # Nothing to embed for this mouse; leave its panel empty
            continue

        dataBigArr = np.vstack(dataBigLst)
        fit_color_data(axMouse, dataBigArr, np.array(labelsBigLst), methodName, param)

    selectorType, selectorValue = list(selector.items())[0]
    prefixParts = [datatype, selectorType, str(selectorValue), methodName, str(param), temporalStrat]
    if signCellsName != 'None':
        # Keep outputs of different cell filters apart; the unfiltered name is unchanged
        prefixParts += ['cellfilter', signCellsName]
    prefix = '_'.join(prefixParts)

    fig.suptitle(prefix)
    fig.savefig(prefix + ".pdf")
    plt.show()

In [None]:
# Embedding label -> (method name, method parameter); for t-SNE the parameter is the perplexity
methodParamDict = {
    "pca"     : ("pca", None),
    "tsne10"    : ("tsne", 10),
    "tsne20"    : ("tsne", 20)
}

# Sweep over intervals, datatypes, embedding methods, temporal strategies and cell filters
for interval in [6, 7, 8]:
    for datatype in ['raw', 'deconv']:
        for methodKey, (methodName, param) in methodParamDict.items():
            for temporalStrat in ['raw', 'avg', 'legendre']:
                # The cell filters are stored per datatype in `significantCellsSelectorDatatype`;
                # the previously used name `significantCellsSelector` is not defined anywhere
                for signCellsName, signCells in significantCellsSelectorDatatype[datatype].items():
                    print("Significant Cells :", signCellsName, methodKey, temporalStrat, interval, datatype)

                    tsne_interval_by_behaviour(datatype, {'interval' : interval},
                                               methodName=methodName,
                                               param=param,
                                               temporalStrat=temporalStrat,
                                               signCellsSelector={signCellsName:signCells})

In [None]:
def test_average_L2_difference(dataDB, datatype, selector, condition, nTest=1000, pval=0.01, signCellsSelector=None):
    """Permutation-test, per mouse, the distance between the mean activity of two conditions.

    For each mouse the "mean" metric is computed for every value of ``condition`` found
    in ``dataDB.metaDataFrames['behaviorStates']``. The test statistic is the L2 norm of
    the difference between the axis-0 averages of the first two of these results.

    Parameters
    ----------
    dataDB : database object providing ``mice`` and ``metaDataFrames``
    datatype : value of the 'datatype' field of the database query
    selector : selector dict handed to ``metric_by_selector``, e.g. ``{"interval" : 6}``
    condition : column of the 'behaviorStates' table whose values are compared
    nTest : number of resamples handed to ``difference_test``
    pval : currently unused, kept for interface compatibility
    signCellsSelector : optional single-entry dict ``{filterName : per-mouse cell filter}``;
        ``None`` disables cell filtering. Only the first entry is used.

    Returns
    -------
    pandas.DataFrame with one row per mouse and the columns
    "mousename", "pSmall", "pLarge", "effSize", "fTrue", "fRand".

    Notes
    -----
    Only the first two condition values are compared; any further values are ignored.
    """
    if signCellsSelector is None:
        signCellsSelector = {'None' : None}

    # The same filter and condition values apply to every mouse, so resolve them once
    filterDict = list(signCellsSelector.values())[0]
    condValues = set(dataDB.metaDataFrames['behaviorStates'][condition])

    def dist_func(a, b):
        # Avg over trials -> Norm Dist over neurons
        # NOTE(review): assumes axis 0 of the metric result indexes trials — confirm
        return np.linalg.norm(np.mean(a, axis=0) - np.mean(b, axis=0))

    settings = {"haveMeans" : True, "haveEffectSize" : True}

    results = []
    for mousename in sorted(dataDB.mice):
        means = []
        for condVal in condValues:
            queryDictCond = {"datatype" : datatype, "mousename" : mousename, condition : condVal}
            means += [metric_by_selector(dataDB, queryDictCond, "mean", "rp", selector, {}, channelFilter=filterDict)]

        # nTest was previously ignored in favour of a hard-coded 1000 (which is also its default)
        results += [[
            mousename,
            *difference_test(dist_func, means[0], means[1], nTest, sampleFunction="permutation", settings=settings)
        ]]

    rezDF = pd.DataFrame(results, columns=["mousename", "pSmall", "pLarge", "effSize", "fTrue", "fRand"])
    return rezDF

In [None]:
# Run the permutation test for every (interval, datatype, condition, cell filter)
# combination and collect the per-mouse result tables under that key
rezDFDict = {}
for interval in [6, 7, 8]:
    for datatype in ['raw', 'deconv']:
        for condition in ['direction', 'performance']:
            for signCellsName, signCells in significantCellsSelectorDatatype[datatype].items():
                key = (interval, datatype, condition, signCellsName)
                print(key)

                rezDF = test_average_L2_difference(dataDB,
                                                   datatype,
                                                   {"interval": interval},
                                                   condition,
                                                   signCellsSelector={signCellsName:signCells})
                
                rezDFDict[key] = rezDF

In [None]:
# Display the merged results without truncating rows or columns; the list names
# the four components of the keys of `rezDFDict`
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(merge_df_from_dict(rezDFDict, ['interval', 'datatype', 'condition', 'filter']))

## 1.2. Modality -> Intervals

In [None]:
def tsne_behaviour_by_interval(dataDB, datatype, direction, performance, methodName, methodParam, signCellsSelector=None):
    """For every mouse, embed the data of one behaviour in 2D and colour it by interval.

    For intervals 6, 7 and 8 the "mean" metric is computed for the given direction and
    performance; the results are stacked, embedded with ``fit_color_data`` and labelled
    by ``interval + 1``. One panel is drawn per mouse and the figure is saved as a PDF
    in the current working directory.

    Parameters
    ----------
    dataDB : database object providing ``mice``
    datatype, direction, performance : values of the corresponding query fields
    methodName : embedding method, see ``fit_color_data``
    methodParam : embedding parameter, see ``fit_color_data``
    signCellsSelector : optional single-entry dict ``{filterName : per-mouse cell filter}``;
        ``None`` disables cell filtering. Only the first entry is used.
    """
    if signCellsSelector is None:
        signCellsSelector = {'None':None}

    signCellLabel, signCellDict = list(signCellsSelector.items())[0]

    mice = sorted(dataDB.mice)
    nMice = len(mice)

    # squeeze=False keeps `ax` two-dimensional, so indexing also works with a single mouse
    fig, ax = plt.subplots(ncols=nMice, figsize=(6*nMice, 6), squeeze=False)
    for iMouse, mousename in enumerate(mice):
        print("Doing mouse", mousename)

        queryDictMouse = {
            'datatype' : datatype,
            'mousename' : mousename,
            'performance' : performance,
            'direction' : direction}

        labelsBigLst = []
        dataBigLst = []

        for interval in range(6, 9):
            means = metric_by_selector(dataDB, queryDictMouse, "mean", "rp", {'interval' : interval}, {},
                                       channelFilter=signCellDict)

            dataBigLst += [means]
            # Labels are shifted by one relative to the interval index used in the query
            labelsBigLst += [str(interval+1)]*len(means)

        print([d.shape for d in dataBigLst])
        dataBigArr = np.vstack(dataBigLst)

        fit_color_data(ax[0, iMouse], dataBigArr, np.array(labelsBigLst), methodName, methodParam)
        ax[0, iMouse].set_title(mousename)

    # Save through the figure itself instead of relying on the pyplot "current figure"
    fig.savefig("tsne_modality_"+datatype+'_'+signCellLabel+'_'+performance+'_'+direction+".pdf")
    plt.show()

In [None]:
# Plot the t-SNE embedding (parameter 20) coloured by interval for every
# combination of datatype, performance, direction and cell filter
for datatype in ['raw', 'deconv']:
    for performance in ['Correct', 'Mistake']:
        for direction in ['L', 'R']:
            for signCellsName, signCells in significantCellsSelectorDatatype[datatype].items():
                key = (datatype, performance, direction, signCellsName)
                print(key)
                
                tsne_behaviour_by_interval(dataDB, datatype, direction, performance, "tsne", methodParam=20,
                                           signCellsSelector={signCellsName:signCells})

# 5. Trajectory via PCA

Main Point:
* Select 1 interval
* compute 2D PCA
* plot trajectories by trial
* color by modality ($LR\otimes CM$)

Dynamics Strategy:
0. Naive-0: Plot individual PCA components as a function of time
1. Naive-1: Exact PCA coordinates
    * Problem: that is not a trajectory, just a blob that oscillates around zero
2. Adept-1: Add inertia by time-accumulation
3. Adept-2: Add inertia by gaussian-filtering
    
TASKS
* **TODO**: Average over trials for 1 mouse, only 1 session
* **TODO**: How are averaging and cumulative accumulation possible simultaneously?
* **TODO**: Add curve distance as function of timestep, shuffle test

In [None]:
def pca_plots_wrapper(datatype, selector, paramDict, signCellsSelector=None):
    """Produce and save all PCA figures for one datatype, selector and cell filter.

    Uses the global ``dataDB``. Every figure has one panel per mouse. The figures are:
      * trial-average scatter, PCA1 vs timestep and PCA2 vs timestep
      * for every strategy in the outer product of ``paramDict``: a PCA plot and a
        distance plot

    All figures are written as PDF files to the current working directory.

    Parameters
    ----------
    datatype : value of the 'datatype' field of the database query
    selector : single-entry dict, e.g. ``{"interval" : 6}``, handed to ``PCAPlots``
    paramDict : dict of lists with the keys 'cropStrategy', 'trialStrategy' and
        'accStrategy'; all combinations are plotted, except trial-averaging without
        a crop strategy
    signCellsSelector : optional single-entry dict ``{filterName : {mousename : channel filter}}``;
        ``None`` disables cell filtering. Only the first entry is used.
    """
    if signCellsSelector is None:
        signCellsSelector = {'None':None}

    signCellLabel, signCellDict = list(signCellsSelector.items())[0]
    selectorKey, selectorVal = list(selector.items())[0]

    # Create all possible parameter combinations
    # Exclude trial-averaging when no crop strategy employed as it makes no sense to average trials of different length
    paramProdDF = outer_product_df(paramDict)
    badrows = paramProdDF[(paramProdDF['trialStrategy'] == 'avg') & (paramProdDF['cropStrategy'].isnull())]
    paramProdDF = paramProdDF.drop(badrows.index).reset_index(drop=True)

    # One plain dict per strategy, reused for plotting, titles and file names
    strategies = [dict(row) for _, row in paramProdDF.iterrows()]

    mice = sorted(dataDB.mice)

    def makefigure(title):
        """Create a titled one-row figure with one panel per mouse; return (fig, 1D axes array)."""
        ncols = len(mice)
        # squeeze=False guarantees an array of axes even when there is a single mouse
        fig, ax = plt.subplots(nrows=1, ncols=ncols, figsize=(4*ncols, 4), squeeze=False, tight_layout=True)
        fig.suptitle(title)
        return fig, ax[0]

    figAvg, axAvg = makefigure("Trial-average")
    figTime1, axTime1 = makefigure("PCA1 vs Timestep")
    figTime2, axTime2 = makefigure("PCA2 vs Timestep")
    figsPCA = [makefigure("PCA :: " + str(strategy)) for strategy in strategies]
    figsDist = [makefigure("Dist :: " + str(strategy)) for strategy in strategies]

    for iMouse, mousename in enumerate(mice):
        print("Doing mouse:", mousename)

        channelFilterMouse = signCellDict[mousename] if signCellDict is not None else None

        queryDict = {'datatype' : datatype, 'mousename' : mousename}
        pp = PCAPlots(dataDB, selector, queryDict, channelFilterMouse=channelFilterMouse)
        pp.set_stretch_timesteps(100)
        pp.plot_time_avg_scatter(axAvg[iMouse])
        pp.plot_pca_vs_time(axTime1[iMouse], 0)
        pp.plot_pca_vs_time(axTime2[iMouse], 1)

        for idx, strategy in enumerate(strategies):
            print("--Doing strategy:", strategy)

            pp.plot_pca(figsPCA[idx][1][iMouse], strategy)
            pp.plot_distances(figsDist[idx][1][iMouse], strategy)

    # Saving stuff: write each figure through its own handle, so the result
    # does not depend on which figure is currently active in pyplot
    fnamePrefix = datatype+"_cellfilter_" + signCellLabel + "_" +selectorKey+"_"+str(selectorVal)

    figAvg.savefig("Trial_average_"+fnamePrefix+".pdf")
    figTime1.savefig("PCA1_vs_time_"+fnamePrefix+".pdf")
    figTime2.savefig("PCA2_vs_time_"+fnamePrefix+".pdf")

    for idx, strategy in enumerate(strategies):
        rowKey = '_'.join([k+'_'+str(v) for k,v in strategy.items()])

        figsPCA[idx][0].savefig("PCA1_vs_PCA2_"+fnamePrefix+'_'+rowKey+".pdf")
        figsDist[idx][0].savefig("manifold_dist_vs_time_"+fnamePrefix+'_'+rowKey+".pdf")

    plt.show()

In [None]:
# Strategies to sweep over in the PCA plots; the trailing comments list the
# alternative values that are currently disabled
paramDict = {
    "cropStrategy" :["stretch", "cropmin"], #[None, "cropmin", "stretch"],
    "trialStrategy" : [None], #[None, "concat", "avg"],
    "accStrategy" : [None], #[None, "cumulative", "gaussfilter"]
}

for datatype in ['raw', 'deconv']:
    for interval in [6,7,8]:
        for signCellsName, signCells in significantCellsSelectorDatatype[datatype].items():
            # Only the filtered cell sets are plotted; the unfiltered 'None' entry is skipped
            if signCellsName != 'None':
            
                key = (datatype, interval, signCellsName)
                print(key)

                # pca_plots_wrapper('deconv', {"phase" : "Maintenance"}, paramDict)
                pca_plots_wrapper(datatype,
                                  {"interval" : interval},
                                  paramDict,
                                  signCellsSelector={signCellsName:signCells})

In [None]:
# Quick visual check of plots_lib.plot_coloured_1D on a synthetic sine curve;
# y**2 is passed as the fourth argument (presumably the colour values — confirm)
x = np.linspace(0, 1, 100)
y = np.sin(5*x)
z = x  # NOTE(review): `z` is not used anywhere below

fig, ax = plt.subplots(figsize=(4,4))
plots_lib.plot_coloured_1D(ax, x, y, y**2)
plt.show()