In [None]:
# Standard libraries
from copy import deepcopy

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, IntProgress
from IPython.display import display
from scipy.stats import combine_pvalues

# User libraries
# from mesostat.stat.machinelearning import binary_classifier
from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.utils.pandas_helper import merge_df_from_dict
from mesostat.utils.arrays import numpy_merge_dimensions
from mesostat.visualization.mpl_matrix import imshow
from mesostat.visualization.mpl_violin import violins_labeled
from mesostat.utils.signals.resample import resample_stretch
from mesostat.stat.testing.quantity_test import test_quantity, rstest_twosided

from pfc_mem_proj.lib.data_db import BehaviouralNeuronalDatabase
import pfc_mem_proj.lib.plots_lib as plots_lib
import pfc_mem_proj.lib.table_lib as table_lib
import pfc_mem_proj.lib.plots_pca as plots_pca
from pfc_mem_proj.lib.metric_wrapper import metric_by_selector
from pfc_mem_proj.lib.extra_metrics import num_non_zero_std, num_sample
from pfc_mem_proj.lib.significant_cells_lib import SignificantCells
from pfc_mem_proj.lib.excel_export import write_excel_2D

%load_ext autoreload
%autoreload 2

In [None]:
# Path configuration handed to BehaviouralNeuronalDatabase in the next cell.
# NOTE(review): both paths are absolute and machine-specific - adjust them
# (or re-enable the gui_fpath prompt below) before running elsewhere.
# tmp_path = root_path_data if 'root_path_data' in locals() else "./"
params = {}
#params['root_path_data']  = gui_fpath("Path to data files", "./")
# presumably the folder holding the dF/F data - TODO confirm
params['root_path_dff'] = '/media/alyosha/DataNew/TE_data/mariadata/dff/'
# presumably the folder holding the deconvolved data - TODO confirm
params['root_path_deconv'] = '/media/alyosha/DataNew/TE_data/mariadata/deconv/'

In [None]:
# Create the database object from the path configuration in `params`;
# every analysis cell below uses this global `dataDB`.
dataDB = BehaviouralNeuronalDatabase(params)

In [None]:
# Read the neuronal data files into dataDB
# (presumably from the paths configured in `params` - TODO confirm).
dataDB.read_neuro_files()

In [None]:
# Read the behavioural data files into dataDB
# (presumably what fills dataDB.metaDataFrames['behaviorStates'] used below - TODO confirm).
dataDB.read_behavior_files()

In [None]:
# For every datatype, map a filter name to the per-mouse significant-cell
# selection loaded from an HDF5 file. The 'None' entry stands for "no filter":
# the test functions below receive None and then use all channels.
significantCellsSelectorDatatype = {}

for datatype in ['deconv']: #['raw', 'deconv']:
    # File names are relative to the current working directory
    signCellsMaintenance = SignificantCells('significant_cells_'+datatype+'_mt.h5').get_cells_by_mouse()
    signCellsReward = SignificantCells('significant_cells_'+datatype+'_enc_reward.h5').get_cells_by_mouse()

    significantCellsSelectorDatatype[datatype] = {
        'None' : None,
        'Maintenance' : signCellsMaintenance,
        'Reward' : signCellsReward
    }

# 1. Neuron-Time-Average

**Goal**: Attempt to predict L/R and C/M from the mean activity over each trial

**VERY IMPORTANT**: When we compare significant cells, we must compare exactly the same cells in each condition. Therefore, we must select all cells that are significant within that interval/phase

In [None]:
def test_mean_time_and_neurons(dataDB, datatype, selector, condition, queryDict,
                               signCellsMouseDict=None, havePlot=True):
    """
    For every mouse, compare the "mean" metric between the values of a
    behavioural condition and collect the results in one table.

    Args:
        dataDB: database object. This function reads ``dataDB.mice`` and
            ``dataDB.metaDataFrames['behaviorStates']`` and may call
            ``dataDB.get_nchannel``.
        datatype: value stored under the "datatype" key of every query.
        selector: selector dict forwarded unchanged to ``metric_by_selector``.
        condition: column of the behaviorStates frame; its distinct values are
            the groups that get compared. Exactly two values are expected,
            because the per-group results are passed positionally to
            ``rstest_twosided``.
        queryDict: extra query constraints. The keys "datatype", "mousename"
            and ``condition`` set by this function take precedence.
        signCellsMouseDict: optional mapping from mouse name to a collection
            of cells. It is forwarded as ``channelFilter`` and its per-mouse
            length is reported as "nNeurons". If None, no filter is applied
            and "nNeurons" comes from ``dataDB.get_nchannel``.
        havePlot: if True, draw one violin plot per mouse and show the figure.

    Returns:
        pandas.DataFrame with one row per mouse and the columns "mouse",
        one column per condition value (mean of the metric), "nTrials"
        (tuple of group sizes), "nNeurons" and "-log10(pval)".

    Side effects:
        Prints the combined p-value over mice (scipy ``combine_pvalues``
        with its default method) and displays the table.
    """
    mice = sorted(dataDB.mice)
    nMice = len(mice)
    # Materialise the condition values once so that the data columns and the
    # column labels of the result table are guaranteed to share one order
    condValues = list(set(dataDB.metaDataFrames['behaviorStates'][condition]))
    testResults = []

    if havePlot:
        # squeeze=False keeps the axes indexable even when there is one mouse
        fig, axGrid = plt.subplots(ncols=nMice, figsize=(4*nMice, 4), squeeze=False)
        ax = axGrid[0]

    for iMouse, mousename in enumerate(mice):
        if signCellsMouseDict is not None:
            nChannels = len(signCellsMouseDict[mousename])
        else:
            nChannels = dataDB.get_nchannel(mousename, datatype)

        means = []
        for condVal in condValues:
            queryDictCond = {**queryDict, "datatype": datatype, "mousename": mousename,
                             condition: condVal}
            means += [metric_by_selector(dataDB, queryDictCond, "mean", "r", selector, {},
                                         channelFilter=signCellsMouseDict)]

        meanVals = [np.mean(mu) for mu in means]
        nTrialsCond = tuple(len(mu) for mu in means)
        # Index 1 of the test result is used as the p-value
        negLogPVal = -np.log10(rstest_twosided(*means)[1])

        testResults += [[mousename, *meanVals, nTrialsCond, nChannels, negLogPVal]]
        if havePlot:
            ax[iMouse].violinplot(means)
            ax[iMouse].set_title(mousename)

    dfRez = pd.DataFrame(testResults, columns=["mouse", *condValues, "nTrials", "nNeurons", "-log10(pval)"])

    # Convert back to plain p-values before combining them over mice
    pVal = 10**(-dfRez['-log10(pval)'])
    print("-- Combined p-values", combine_pvalues(pVal)[1])
    display(dfRez)
    if havePlot:
        plt.show()

    return dfRez

In [None]:

# Run the neuron-and-time-averaged test for every semiphase and datatype.
# Each plan entry is (condition to discriminate, key of the fixed behavioural
# variable, value of the fixed behavioural variable).
testPlan = [
    ('direction', 'performance', 'Correct'),
    ('performance', 'direction', 'L'),
    ('performance', 'direction', 'R'),
]

rezDict = {}
for semiphase in ['M1', 'M2']:
    for datatype in ['raw', 'deconv']:
        # No significant-cell filter is applied in this sweep
        signCells = None
        selector = {"semiphase" : semiphase}

        for testCondition, fixedKey, fixedVal in testPlan:
            # Results are keyed by (datatype, semiphase, condition, fixed value)
            keys = (datatype, semiphase, testCondition, fixedVal)
            print(keys)
            dfRez = test_mean_time_and_neurons(dataDB, datatype, selector, testCondition,
                                               {fixedKey: fixedVal},
                                               signCellsMouseDict=signCells,
                                               havePlot=False)
            rezDict[keys] = dfRez
        
                       

In [None]:
# Show the merged per-run tables without pandas truncating rows or columns
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    # The name list has one entry per element of the rezDict key tuples
    # (presumably they become the leading columns of the merged table - TODO confirm)
    display(merge_df_from_dict(rezDict, ['datatype', 'semiphase', 'condition', 'query']))

# 1.1 Neuron average

1. Compute the average over neurons for each trial (the time axis is kept)
2. Resample all trials to fixed length (e.g. 100pt)
3. Test time-wise

In [None]:
def stack_dicts(dictList, keys):
    """
    Turn a list of dicts into a dict of lists.

    Args:
        dictList: sequence of dicts; each must contain every key in ``keys``.
        keys: iterable of keys to collect.

    Returns:
        dict mapping each key to the list of its values, in the order of
        ``dictList``.
    """
    stacked = {}
    for key in keys:
        stacked[key] = [entry[key] for entry in dictList]
    return stacked


def test_mean_neurons(dataDB, datatype, mousename, selector, condition, queryDict, nTimeTrg=100):
    """
    For one mouse, compare the two values of a behavioural condition at every
    point of a stretched time axis.

    For each condition value the trial data is fetched, averaged over axis 0
    (presumably the channel axis - TODO confirm), and resampled with
    ``resample_stretch`` to ``nTimeTrg`` points so that all trials line up.

    Args:
        dataDB: database object; reads ``dataDB.metaDataFrames['behaviorStates']``
            and calls ``dataDB.get_data_from_selector``.
        datatype: value stored under the "datatype" key of the query.
        mousename: value stored under the "mousename" key of the query.
        selector: selector dict forwarded to ``get_data_from_selector``.
        condition: column of the behaviorStates frame whose two distinct
            values define the groups.
        queryDict: extra query constraints; these override the keys set here.
        nTimeTrg: number of points of the stretched time axis.

    Returns:
        (dataMuDict, dataStdDict, nLogPvals) where the two dicts map each
        condition value to the across-trial mean and standard error per time
        point, and nLogPvals holds -log10(p) of ``rstest_twosided`` per time point.

    Raises:
        ValueError: if the condition does not have exactly two distinct values.
    """
    condValues = set(dataDB.metaDataFrames['behaviorStates'][condition])

    # The time-wise test below pairs up exactly two groups; fail early and
    # clearly instead of with an unpacking error after the data is loaded
    if len(condValues) != 2:
        raise ValueError("Expected exactly 2 values for condition '" + str(condition)
                         + "', got " + str(len(condValues)))

    dataSRCond = []
    dataMuDict = {}
    dataStdDict = {}

    for condVal in condValues:
        queryDictCond = {"datatype": datatype, "mousename": mousename,
                         condition: condVal, **queryDict}

        dataLstPS = dataDB.get_data_from_selector(selector, queryDictCond)
        dataLstS = [np.mean(d, axis=0) for d in dataLstPS]
        # After the transpose: rows are time points, columns are trials
        dataSR = np.array([resample_stretch(d, nTimeTrg) for d in dataLstS]).T
        dataSRCond += [dataSR]

        dataMuDict[condVal] = np.mean(dataSR, axis=1)
        # Standard error of the mean over trials
        dataStdDict[condVal] = np.std(dataSR, axis=1) / np.sqrt(dataSR.shape[1])

    # One test per time point: row k of the first group against row k of the second
    pVals = [rstest_twosided(mu1, mu2)[1] for mu1, mu2 in zip(*dataSRCond)]
    nLogPvals = -np.log10(pVals)

    return dataMuDict, dataStdDict, nLogPvals


def test_plot_mean_neurons(dataDB, datatype, selector, condParamLst, nTimeTrg=100):
    """
    Run ``test_mean_neurons`` for every mouse and every (condition, query)
    pair, plot the outcome and export the un-clipped -log10(p) curves to Excel.

    Files written, relative to the working directory, where <suffix> is the
    datatype followed by the selector values joined with '_':
      * 'excel_out/lrcm_htest_<suffix>.xlsx' - one ``write_excel_2D`` call
        per (condition, query) pair; the directory must already exist
      * 'test_by_condition_mean_timestep_<suffix>.pdf' - top row: activity
        averaged over mice with an error band; bottom row: clipped
        -log10(p) curve of every mouse
      * 'test_by_mouse_mean_timestep_<suffix>.pdf' - one panel per mouse with
        the clipped -log10(p) curve of every (condition, query) pair

    Args:
        dataDB: database object; reads ``dataDB.mice`` and
            ``dataDB.metaDataFrames['behaviorStates']``.
        datatype: datatype forwarded to ``test_mean_neurons``.
        selector: selector dict forwarded to ``test_mean_neurons``.
        condParamLst: sequence of (condition, queryDict) pairs. The query
            values must be strings because they are joined into the plot label.
        nTimeTrg: number of points of the stretched time axis.

    Returns:
        None. Both figures are closed before returning.
    """
    mice = sorted(dataDB.mice)
    nCond = len(condParamLst)
    nMice = len(mice)

    plotSuffix = '_'.join([str(s) for s in [datatype] + list(selector.values())])
    x = np.arange(nTimeTrg)

    # squeeze=False keeps both axes grids 2D, so the indexing below also works
    # for a single condition or a single mouse
    figByCond, axByCond = plt.subplots(nrows=2, ncols=nCond, figsize=(4*nCond, 8), squeeze=False)
    figByMouse, axByMouseGrid = plt.subplots(ncols=nMice, figsize=(4*nMice, 4), squeeze=False)
    axByMouse = axByMouseGrid[0]

    try:
        # The context manager closes the workbook even if a test raises midway
        with pd.ExcelWriter('excel_out/lrcm_htest_' + plotSuffix + '.xlsx') as excelWriter:
            for iCond, (condition, queryDict) in enumerate(condParamLst):
                plotLabel = condition + '_for_' + '_'.join(list(queryDict.values()))
                condValues = set(dataDB.metaDataFrames['behaviorStates'][condition])

                dataMuDictLst = []
                dataStdDictLst = []
                nLogPvalsLst = []

                for iMouse, mousename in enumerate(mice):
                    dataMuDict, dataStdDict, nLogPvals = test_mean_neurons(dataDB, datatype, mousename, selector,
                                                                           condition, queryDict, nTimeTrg=nTimeTrg)
                    # Clip too significant changes to 3* for visual comparability
                    nLogPvalsClipped = np.clip(nLogPvals, None, 3)

                    dataMuDictLst += [dataMuDict]
                    dataStdDictLst += [dataStdDict]
                    nLogPvalsLst += [nLogPvals]

                    axByCond[1, iCond].plot(nLogPvalsClipped, label=mousename)
                    axByMouse[iMouse].plot(nLogPvalsClipped, label=plotLabel)

                    axByMouse[iMouse].set_xlabel('stretched time')
                    axByMouse[iMouse].set_title(mousename)

                dataMuDictStack = stack_dicts(dataMuDictLst, condValues)
                dataStdDictStack = stack_dicts(dataStdDictLst, condValues)

                for condVal in condValues:
                    mu = np.mean(dataMuDictStack[condVal], axis=0)
                    # NOTE(review): this is the root-mean-square of the per-mouse
                    # standard errors; the standard error of the mean over mice
                    # would divide by nMice instead of sqrt(nMice) - confirm intent
                    std = np.linalg.norm(dataStdDictStack[condVal], axis=0) / np.sqrt(nMice)
                    axByCond[0, iCond].plot(x, mu, label=condVal)
                    axByCond[0, iCond].fill_between(x, mu-std, mu+std, alpha=0.1)

                axByCond[1, iCond].set_ylim(None, 3.05)
                # Dashed line at -log10(p) = 2, i.e. p = 0.01
                axByCond[1, iCond].axhline(y=2, linestyle='--', color='pink')
                axByCond[0, iCond].set_title(plotLabel)
                axByCond[1, iCond].set_xlabel('stretched time')
                axByCond[0, iCond].legend()
                axByCond[1, iCond].legend()

                # Write to excel: rows are time points, columns are mice
                write_excel_2D(np.array(nLogPvalsLst).T, excelWriter, plotLabel,
                               colnames=np.array(mice))

        axByCond[0, 0].set_ylabel('Mouse-and-cell-average activity')
        axByCond[1, 0].set_ylabel('-log10(pval)')
        axByMouse[0].set_ylabel('-log10(pval)')

        for iMouse in range(nMice):
            axByMouse[iMouse].set_ylim(None, 3.05)
            axByMouse[iMouse].axhline(y=2, linestyle='--', color='pink')
            axByMouse[iMouse].legend()

        figByCond.savefig('test_by_condition_mean_timestep_' + plotSuffix + '.pdf')
        figByMouse.savefig('test_by_mouse_mean_timestep_' + plotSuffix + '.pdf')
    finally:
        # Close exactly these two figures; an argument-less plt.close() closes
        # whichever figure happens to be current
        plt.close(figByCond)
        plt.close(figByMouse)

In [None]:
# selector = {"phase" : "Maintenance"}

# Each entry is [condition to test, query that fixes the other behavioural variable]
condParamLst = [
    ["performance", {"direction": "L"}],
    ["performance", {"direction": "R"}],
    ["direction", {"performance": "Correct"}]
]

# NOTE(review): rezDict is reset here but never filled in this cell, so the
# results stored in it by the earlier cell are discarded - confirm this is intended
rezDict = {}
# test_plot_mean_neurons returns nothing; its output is the PDF and Excel files it writes
for semiphase in ['M1', 'M2']:
    selector = {"semiphase" : semiphase}
    for datatype in ['raw', 'deconv']:
        print(selector, datatype)
        test_plot_mean_neurons(dataDB, datatype, selector, condParamLst)

# 2 Time-Average

**Goal**: Attempt to predict LR/CM from average activities of individual cells.
* Count predictive cells in each mouse, phase/interval

In [None]:
def test_mean_time(dataDB, datatype, selector, condition, nTest=1000, pval=0.01, signCellsMouseDict=None):
    """
    For every mouse, test each cell's "mean" metric between the values of a
    behavioural condition and count the significant cells.

    Args:
        dataDB: database object; reads ``dataDB.mice`` and
            ``dataDB.metaDataFrames['behaviorStates']`` and may call
            ``dataDB.get_nchannel``.
        datatype: value stored under the "datatype" key of every query.
        selector: selector dict forwarded to ``metric_by_selector``.
        condition: column of the behaviorStates frame whose distinct values
            define the groups. Only the first two groups are passed to
            ``test_quantity``.
        nTest: unused; kept so that existing calls keep working.
        pval: threshold forwarded to ``test_quantity``.
        signCellsMouseDict: optional mapping from mouse name to a collection
            of cells. It is forwarded as ``channelFilter`` and its per-mouse
            length is reported as "nCellTot". If None, no filter is applied
            and "nCellTot" comes from ``dataDB.get_nchannel``.

    Returns:
        pandas.DataFrame with one row per mouse and the columns "mouse",
        "nCellSignificant", "nCellTot" and "-log10(pval)".

    Side effects:
        Displays the table and shows one panel per mouse with the sorted
        per-cell -log10(p) values.
    """
    mice = sorted(dataDB.mice)
    nMice = len(mice)
    condValues = set(dataDB.metaDataFrames['behaviorStates'][condition])
    mouseResults = []

    # squeeze=False keeps the axes indexable even when there is only one mouse
    fig1, axGrid = plt.subplots(ncols=nMice, figsize=(4*nMice, 4), tight_layout=True, squeeze=False)
    ax1 = axGrid[0]
    for iMouse, mousename in enumerate(mice):
        if signCellsMouseDict is not None:
            channelFilter = signCellsMouseDict
            nCells = len(signCellsMouseDict[mousename])
        else:
            channelFilter = None
            nCells = dataDB.get_nchannel(mousename, datatype)

        means = []
        for condVal in condValues:
            queryDictCond = {"datatype" : datatype, "mousename" : mousename, condition : condVal}
            means += [metric_by_selector(dataDB, queryDictCond, "mean", "pr", selector, {}, channelFilter=channelFilter)]

        pValByCell, nCellSignificant, negLogPValPop = test_quantity(means[0], means[1], pval)
        mouseResults += [[mousename, nCellSignificant, nCells, np.round(negLogPValPop, 2)]]

        ax1[iMouse].plot(sorted(-np.log10(pValByCell)))
        # Dashed line at -log10(p) = 2, i.e. p = 0.01
        ax1[iMouse].axhline(y=2, linestyle="--", color='r')
        ax1[iMouse].set_xlabel("cell index, sorted")
        ax1[iMouse].set_ylabel("-log10(pVal)")
        ax1[iMouse].set_title(mousename)

    rezDf = pd.DataFrame(mouseResults, columns=["mouse", "nCellSignificant", "nCellTot", "-log10(pval)"])
    display(rezDf)

    plt.show()

    return rezDf

In [None]:
# selector = {"phase" : "Maintenance"}

# Run the per-cell test for intervals 6 to 8, for each condition and each
# significant-cell filter. Results are keyed by
# (interval, datatype, condition, filter name).
rezDFDict = {}
for interval in range(6, 9):
    for datatype in ['deconv']: #['raw', 'deconv']:
        for condition in ["performance", "direction"]:
            # signCells is None for the 'None' filter, which means all cells are used
            for signCellsName, signCells in significantCellsSelectorDatatype[datatype].items():
    #             print(condition, useSignificant)
                keys = (interval, datatype, condition, signCellsName)
                selector = {"interval" : interval}
                print(keys)
                rezDF = test_mean_time(dataDB, datatype, selector, condition, signCellsMouseDict=signCells)
                rezDFDict[keys] = rezDF

In [None]:
# Show the merged per-run tables without pandas truncating rows or columns
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    # The name list has one entry per element of the rezDFDict key tuples
    # (presumably they become the leading columns of the merged table - TODO confirm)
    display(merge_df_from_dict(rezDFDict, ['interval', 'datatype', 'condition', 'filter']))

## 2.1 LRCM discrimination by phase/interval

In [None]:
def test_ncells_by_interval(dataDB, queryDict, datatype, condition, phaseType, phaseTypeRange, excelWriter,
                            nTest=1000, pval=0.01, plotYLim=None):
    """
    For every mouse and every value of a phase-like selector, compute the
    fraction of cells that ``test_quantity`` reports as significant between
    the values of a behavioural condition. Plot and export the fractions.

    Args:
        dataDB: database object; reads ``dataDB.mice`` and
            ``dataDB.metaDataFrames['behaviorStates']``.
        queryDict: extra query constraints. It must be non-empty with a string
            as first value, because that value becomes part of the output name.
        datatype: value stored under the "datatype" key of every query.
        condition: column of the behaviorStates frame whose distinct values
            define the groups. Only the first two groups are tested.
        phaseType: selector key, e.g. 'phase' or 'semiphase'.
        phaseTypeRange: selector values to sweep; also used as x tick labels.
        excelWriter: open writer passed on to ``write_excel_2D``. It is not
            closed here.
        nTest: unused; kept so that existing calls keep working.
        pval: threshold forwarded to ``test_quantity``.
        plotYLim: optional y-limits of the plot.

    Returns:
        None. Writes 'quantity_significant_cells_<suffix>.pdf' to the working
        directory and one ``write_excel_2D`` table named <suffix>, where
        <suffix> joins phaseType, datatype, condition and the first query value.
    """
    mice = sorted(dataDB.mice)
    condValues = set(dataDB.metaDataFrames['behaviorStates'][condition])
    mouseResults = []

    xDummy = np.arange(len(phaseTypeRange))

    fig, ax = plt.subplots(figsize=(10,10))
    try:
        for mousename in mice:
            print("Calculating mouse", mousename, 'for', datatype, phaseType, condition)

            freqCellSignificant = []
            for phaseTypeVal in phaseTypeRange:
                selector = {phaseType : phaseTypeVal}

                means = []
                for condVal in condValues:
                    queryDictCond = {**queryDict, **{"datatype" : datatype, "mousename" : mousename, condition : condVal}}
                    means += [metric_by_selector(dataDB, queryDictCond, "mean", "pr", selector, {})]

                # The first axis of the metric result is taken as the cell axis
                nCells = means[0].shape[0]
                _, nCellSignificant, _ = test_quantity(means[0], means[1], pval)
                freqCellSignificant += [nCellSignificant / nCells]

            ax.plot(xDummy, freqCellSignificant, label=mousename)
            mouseResults += [freqCellSignificant]

        # The ticks do not depend on the mouse, so set them once
        ax.set_xticks(xDummy)
        ax.set_xticklabels(phaseTypeRange)

        suffix = '_'.join([phaseType, datatype, condition, list(queryDict.values())[0]])

        ax.legend()
        if plotYLim is not None:
            ax.set_ylim(plotYLim)

        ax.set_xlabel(phaseType)
        ax.set_ylabel('Significant cell fraction')
        fig.savefig('quantity_significant_cells_'+suffix+'.pdf')
    finally:
        # Close this figure explicitly, also when a test or the save fails
        plt.close(fig)

    # Write results to excel: rows are mice, columns are the selector values
    write_excel_2D(np.array(mouseResults), excelWriter, suffix,
                   rownames=np.array(mice),
                   colnames=np.array(phaseTypeRange)
                  )

In [None]:
# Count the discriminating cells per phase and semiphase for both conditions,
# each time with the other behavioural variable fixed to one of its values.
# All tables go into one workbook; the 'excel_out' directory must already exist.
# The context manager closes the workbook even if one of the tests raises.
with pd.ExcelWriter('excel_out/lrcm_tests.xlsx') as excelWriter:
    for datatype in ['deconv']: #["raw", "deconv"]:
        for phaseType in ['phase', 'semiphase']:#["phase", "interval", "semiphase"]:
            ylim = [0, 0.5]
            # The key list depends only on phaseType, so fetch it once per phaseType
            ranges = dataDB.get_phasetype_keys(phaseType, "Mistake", haveWaiting=False)

            for condition in ["performance", "direction"]:
                if condition == "performance":
                    secondCond = "direction"
                    secondCondVals = ["L", "R"]
                else:
                    secondCond = "performance"
                    secondCondVals = ["Correct", "Mistake"]

                for condVal in secondCondVals:
                    queryDict = {secondCond : condVal}
                    test_ncells_by_interval(dataDB, queryDict, datatype, condition, phaseType,
                                            ranges, excelWriter, plotYLim=ylim)

## 2.3 Significant Cell Confusion matrices

In [None]:
def text_different_one(data2D, i):
    """
    Test whether sample ``i`` differs from all other samples pooled together.

    Args:
        data2D: sequence of 1D samples (list, tuple or numpy array); the
            samples may have different lengths.
        i: index of the sample to test; negative indices count from the end.

    Returns:
        -log10 of the p-value (second element) returned by ``rstest_twosided``.

    Raises:
        IndexError: if ``i`` is out of range.
    """
    # Normalise the index so that negative values select the intended sample
    iThis = range(len(data2D))[i]
    dataThis = data2D[iThis]
    # Pick the remaining samples by index. Joining two slices with `+` only
    # concatenates for lists; on a numpy array it adds element-wise or fails.
    dataOther = np.hstack([sample for j, sample in enumerate(data2D) if j != iThis])
    _, p = rstest_twosided(dataThis, dataOther)
    return -np.log10(p)


def text_different(data2D):
    """
    Apply ``text_different_one`` to every sample of ``data2D``.

    Args:
        data2D: sequence of samples, passed on unchanged.

    Returns:
        numpy array with one ``text_different_one`` result per sample, in order.
    """
    negLogPVals = []
    for iSample in range(len(data2D)):
        negLogPVals.append(text_different_one(data2D, iSample))
    return np.array(negLogPVals)


# Calculate confusion matrix
def significance_confusion_matrix(sign2D):
    """
    Count, for every pair of phases, the cells that are significant in both.

    Args:
        sign2D: 2D numpy array of shape (nCell, nPhase). Non-zero entries
            count as "significant".

    Returns:
        Float array of shape (nPhase, nPhase). Entry [i, j] is the number of
        cells significant in both phase i and phase j, so the diagonal holds
        the per-phase counts.
    """
    # Unpacking the shape rejects anything that is not two-dimensional
    nCell, nPhase = sign2D.shape
    # Casting through bool reproduces the "non-zero means True" rule of a
    # logical AND; the product of the 0/1 matrix with itself then counts the
    # cells for which both columns are set
    signFloat = sign2D.astype(bool).astype(float)
    return signFloat.T @ signFloat


def confusion_matrices_LRCM(queryDict, selector):
    """
    For every mouse, plot how many cells single out each of the four
    (performance, direction) combinations, and how those cell sets overlap.

    For each cell and each combination, the cell's values for that
    combination are tested against its pooled values for the other three
    (``text_different``). A cell counts as significant where -log10(p) > 2.

    NOTE: this function uses the notebook-level global ``dataDB``.

    Args:
        queryDict: base query, e.g. {"datatype": "deconv"}. The keys
            "mousename", "performance" and "direction" are set here.
        selector: selector dict forwarded to ``metric_by_selector``. Its keys
            and values also form the output file name.

    Returns:
        None. Writes '<selector>_significant_cells_confusion.pdf' to the
        working directory, e.g. 'interval_5_significant_cells_confusion.pdf'
        for {"interval": 5}.
    """
    settings = {"zscoreChannel" : False, "serial" : True, "metricSettings" : {}}
    mice = sorted(dataDB.mice)

    # squeeze=False keeps the axes indexable even when there is only one mouse
    fig, axGrid = plt.subplots(ncols=len(mice), figsize=(4*len(mice), 4), tight_layout=True, squeeze=False)
    ax = axGrid[0]

    try:
        for iMouse, mousename in enumerate(mice):
            print('doing mouse', mousename)

            keys = []
            rezLst = []
            for performance in ['Correct', 'Mistake']:
                for direction in ['L', 'R']:
                    keys += [str((performance, direction))]
                    queryDictThis = {**queryDict, **{'mousename' : mousename, 'performance' : performance, 'direction' : direction}}
                    rezLst += [metric_by_selector(dataDB, queryDictThis, 'mean', 'pr', selector, settings)]

            # The first axis of each result is taken as the channel axis
            nChannel = len(rezLst[0])
            # Despite the name these are -log10(p) values, shape (nChannel, 4)
            pVals2D = np.array([text_different([rez[iCh] for rez in rezLst]) for iCh in range(nChannel)])

            # Calculate confusion matrix; -log10(p) > 2 means p < 0.01
            sign2D = pVals2D > 2
            confMat = significance_confusion_matrix(sign2D)

            # Plot confusion matrix
            imshow(fig, ax[iMouse], confMat, limits=[0, nChannel], title=mousename, haveColorBar=True, cmap='jet',
                   haveTicks=True, xTicks=keys, yTicks=keys)

        # Build the name from whatever the selector holds instead of requiring
        # an "interval" key; {"interval": 5} still yields 'interval_5'
        selectorLabel = '_'.join(str(part) for item in selector.items() for part in item)
        fig.savefig(selectorLabel + '_significant_cells_confusion.pdf')
    finally:
        # Close this figure explicitly, also when a test or the save fails
        plt.close(fig)

In [None]:
# Write one confusion-matrix PDF per interval 5 to 9, using the deconvolved data
for iInterv in range(5, 10):
    confusion_matrices_LRCM({"datatype" : "deconv"}, {"interval" : iInterv})

# 3. Scalar metric tests

**Goal**: Evaluate predictive power for several metrics
* Stretch hypothesis: Legendre Basis (try several bases individually up to 6)
* Synchronization hypothesis: AvgCorr, H

## Table

In [None]:
%%time
# Turn off the database's verbose flag (presumably silences progress output - TODO confirm)
dataDB.verbose = False

#settings = {"serial" : True, "metricSettings" : {"metric" : num_non_zero_std}}
# NOTE(review): "max_lag" is passed although the metric below is "mean" -
# confirm whether this setting has any effect
settings = {"serial" : True, "metricSettings" : {"max_lag" : 1}}
# Query values to sweep over; the sweep over direction is currently disabled
sweepDict = {
    "datatype": ["deconv"],
    "performance": ["Correct", "Mistake", "All"],
#     "direction": ["L", "R", "All"]
}
# selector = {"phase" : "Maintenance"}
selector = {"interval" : 9}

# Build the table for the "direction" condition using the "mean" metric;
# the result is the cell output, so this call must stay the last statement
table_lib.table_discriminate_behavior(dataDB, selector, "direction",
                                      sweepDict,
                                      "mean",
                                      trgDimOrder="r",
                                      settings=settings,
                                      multiplexKey="mousename",
                                      channelFilter=None)

## Violins

In [None]:
def binary_test_phase(dataDB, queryDict, condition, selector, metricName, settings):
    """
    Compare a scalar metric between the values of a behavioural condition and
    draw the groups as labelled violins with a significance test.

    Args:
        dataDB: database object; reads ``dataDB.metaDataFrames['behaviorStates']``.
        queryDict: base query; the key ``condition`` is set here for each group.
        condition: column of the behaviorStates frame whose distinct values
            define the groups. At least two are needed, because groups 0 and 1
            are tested against each other.
        selector: selector dict forwarded to ``metric_by_selector``.
        metricName: metric forwarded to ``metric_by_selector``; also used as a
            plot label.
        settings: settings dict forwarded to ``metric_by_selector``.

    Returns:
        None. Creates a new figure.
    """
    # Materialise the values once so that data and labels share one order
    condValues = list(set(dataDB.metaDataFrames['behaviorStates'][condition]))

    rezLst = []
    for condVal in condValues:
        queryDictCond = {**queryDict, condition: condVal}
        # Forward the caller's settings; an empty dict used to be passed here,
        # which silently ignored the `settings` argument
        rezLst += [metric_by_selector(dataDB, queryDictCond, metricName, "r", selector, settings)]

    fig, ax = plt.subplots()
    violins_labeled(ax, rezLst, condValues, condition, metricName,
                    joinMeans=True, haveLog=False, sigTestPairs=[(0,1)])
    
    
#settings = {"serial" : True, "metricSettings" : {"metric" : num_sample}}
settings = {"serial" : True, "metricSettings" : {}}
# Restrict the comparison to deconvolved data of trials with direction "L"
queryDict = {"datatype" : "deconv", "direction" : "L"}

# Compare the "mean" metric between the performance values within interval 9
binary_test_phase(dataDB, queryDict, "performance", {"interval" : 9}, "mean", settings)
# NOTE(review): the call below is stale - its arguments do not match the
# current (dataDB, queryDict, condition, selector, metricName, settings) signature
#binary_test_phase(dataDB, queryDict, "direction", ["L", "R"], "Maintenance", settings)