In [None]:
# Standard libraries
import json
import os
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interactive, IntProgress
from IPython.display import display
from scipy.stats import mannwhitneyu, wilcoxon, binom_test, combine_pvalues

# User libraries
from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.utils.arrays import numpy_merge_dimensions
from mesostat.utils.matrix import offdiag_1D
from mesostat.utils.pandas_helper import pd_query
from mesostat.visualization.mpl_matrix import imshow
from mesostat.visualization.mpl_colors import custom_grad_cmap

# Local libraries
from pfc_mem_proj.lib.data_db import BehaviouralNeuronalDatabase
import pfc_mem_proj.lib.table_lib as table_lib
from pfc_mem_proj.lib.metric_wrapper import metric_by_selector_all, metric_by_selector
from pfc_mem_proj.lib.extra_metrics import num_non_zero_std, num_sample
from pfc_mem_proj.lib.significant_cells_lib import SignificantCells
import pfc_mem_proj.lib.analysis.phases_intervals as phases_intervals

%load_ext autoreload
%autoreload 2

In [None]:
# Path configuration handed to BehaviouralNeuronalDatabase.
#
# The defaults are the original machine-specific locations. To run the notebook
# on another machine without editing it, set the environment variables
# PFC_MEM_ROOT_PATH_DFF and/or PFC_MEM_ROOT_PATH_DECONV before starting the kernel.
#
# An interactive alternative that was used previously:
#   params['root_path_data'] = gui_fpath("Path to data files", "./")
params = {
    'root_path_dff': os.environ.get(
        'PFC_MEM_ROOT_PATH_DFF',
        '/media/alyosha/DataNew/TE_data/mariadata/dff/',
    ),
    'root_path_deconv': os.environ.get(
        'PFC_MEM_ROOT_PATH_DECONV',
        '/media/alyosha/DataNew/TE_data/mariadata/deconv/',
    ),
}

In [None]:
# Build the database object from the path configuration in `params`.
dataDB = BehaviouralNeuronalDatabase(params)

In [None]:
# Presumably loads the neuronal recordings into dataDB — confirm in BehaviouralNeuronalDatabase.
dataDB.read_neuro_files()

In [None]:
# Presumably loads the behavioural recordings into dataDB — confirm in BehaviouralNeuronalDatabase.
dataDB.read_behavior_files()

# 1. Neuron-Time-Average

**Goal**: Explore phase and interval specificity of bulk signal

**Note**: Already done by Maria, no need to repeat

In [None]:
# Colormaps shared by the plotting cells below.
# Each row passed to custom_grad_cmap is one colour stop
# (presumably RGB on a 0-255 scale — confirm in mesostat).

# Stops: black -> light orange -> red-orange.
# An earlier variant used gen_cmap_3color with the stops
# [255, 255, 255], [168, 217, 49], [255, 0, 0].
confusionColorStops = np.array([[0, 0, 0], [255, 198, 126], [233, 79, 25]])
cmapConfusion = custom_grad_cmap(confusionColorStops)

# Stops: white -> red-orange.
significanceColorStops = np.array([[255, 255, 255], [233, 79, 25]])
cmapSignificance = custom_grad_cmap(significanceColorStops)

# 2. Time-Average

**Goal**: Explore phase and interval specificity of individual neurons

## 2.1. Average activity

Compute average signal for a matrix [Channel x Interval], plot

In [None]:
# Average firing rate per neuron, one call per data type.
# Only the 'semiphase' split is active; 'interval' and 'phase' were used previously.
for datatype in ('raw', 'deconv', 'zscore'):
    for phasetype in ('semiphase',):
        phases_intervals.plot_avg_firing_rate_by_neuron(
            dataDB, datatype, phasetype,
            haveWaiting=False,
            cmap=cmapConfusion,
        )

## 2.2 Active cells

* For each neuron, determine a baseline above which we consider the cell to be active
* For each neuron and each interval, compute the fraction of trials in which the cell is active

**Complications**:
* How exactly should a "sufficiently active" cell be defined?
* How should the firing threshold be computed? Is a constant threshold OK, or does it need to be cell-specific?
* Do we need to control for different interval durations across trials?
* Do we need to control for different interval durations across intervals?

**NOTE**
* Run section 2.3 before this section, as it extracts the labels of significant cells

In [None]:
# Count active trials per neuron for every way of splitting the trial.
# NOTE(review): the positional values 0.18 and 0.2 are presumably activity
# thresholds — confirm their meaning in phases_intervals.
for datatype in ('deconv',):
    for phasetype in ('interval', 'phase', 'semiphase'):
        phases_intervals.plot_count_active_trials_by_neuron(
            dataDB, datatype, phasetype, 0.18, 0.2,
            haveWaiting=False,
        )

Questions of interest:
* [+] Is phase-specific activity explained by
    - how frequently the cell is active?
    - the magnitude of activity when it is active?
    - both?
* Why are there more active cells during MT vs ENC/RET?
* TODO:
    - Plot avg activity only for active trials
    - Color neurons by phase-specificity
    - Perform some kind of test to numerically explore

In [None]:
%matplotlib inline
phases_intervals.plot_activity_vs_active_frequency(dataDB, 'deconv', 'phase', 'm060', 'Correct', thrAct=0.2, haveWaiting=False)

## 2.3 Significant activity
* For each cell and each interval, test whether activity in that interval is significantly higher than the average activity of that cell

In [None]:
# Significance of per-neuron firing rate at confidence threshold 0.01.
# Only 'deconv' data split by 'phase' is active; 'raw' data and the
# 'semiphase' / 'interval' splits were used previously.
for datatype in ('deconv',):
    for phasetype in ('phase',):
        print(datatype, phasetype)
        phases_intervals.plot_significant_firing_rate_by_neuron(
            dataDB, datatype, phasetype,
            confThr=0.01,
            haveWaiting=False,
            cmapConfusion=cmapConfusion,
        )

## 2.4 Store significant neurons
For each interval, store neurons significantly active in that interval

In [None]:
# Pairwise phase comparisons: each entry pairs two phase indices with the
# labels under which the results are stored.
#
# Alternatives that were used previously:
#   * each phase vs. the remainder of the trial (maintenance vs. encoding+retrieval),
#     for both 'raw' and 'deconv':
#       plot_save_significantly_firing_neurons(dataDB, datatype, 'phase',
#                                              [[1], [0, 2]], ['mt', 'enc_ret'],
#                                              confThr=0.01, haveAll=True)
#   * all three pairs:
#       phasePairIdxs = [[[0], [1]], [[0], [2]], [[1], [2]]]
#       phasePairNames = [['enc', 'mt'], ['enc', 'ret'], ['mt', 'ret']]
#   * encoding vs. maintenance only:
#       phasePairIdxs = [[[0], [1]]]
#       phasePairNames = [['enc', 'mt']]
phasePairNames = [['enc', 'mt'], ['mt', 'ret']]
phasePairIdxs = [[[0], [1]], [[1], [2]]]

# Only 'deconv' is processed here; 'raw' was included previously.
for datatype in ('deconv',):
    for pairName, pairIdxs in zip(phasePairNames, phasePairIdxs):
        phases_intervals.plot_save_significantly_firing_neurons(
            dataDB, datatype, 'phase', pairIdxs, pairName,
            confThr=0.01,
            haveAll=True,
        )

In [None]:
# Inspect the stored table of significant cells.
# The HDF5 file is read from the current working directory, key 'df'.
#
# pd_query (imported with the other mesostat helpers in the import cell at the
# top of the notebook) can be used to pull out the cell indices of one row, e.g.:
#   row = pd_query(df, {'mousename' : 'm060', 'performance' : 'All', 'direction' : 'All'})
#   np.array(row['cells'])
df = pd.read_hdf('significant_cells_deconv_enc.h5', 'df')
df

In [None]:
# Compare interval index 0 against interval index 1 and store the results under
# the labels 'enc_base' and 'enc_reward', for both raw and deconvolved data.
# NOTE(review): the meaning of ranges=[2, 4] is not visible here — confirm in phases_intervals.
for datatype in ('raw', 'deconv'):
    phases_intervals.plot_save_significantly_firing_neurons(
        dataDB, datatype, 'interval',
        [[0], [1]],
        ['enc_base', 'enc_reward'],
        confThr=0.01,
        ranges=[2, 4],
        haveAll=True,
    )

## 2.5 Load and test significant cells

In [None]:
# For each data type, map a selector name to the per-mouse significant cells
# loaded from the stored HDF5 files. The 'None' entry carries no selection.
# NOTE(review): the phase-pair loop in section 2.4 currently runs for 'deconv'
# only, so the 'raw' *_mt.h5 file must already exist from an earlier run — confirm.
significantCellsSelectorDatatype = {}

for datatype in ('raw', 'deconv'):
    fileStem = 'significant_cells_' + datatype
    signCellsMaintenance = SignificantCells(fileStem + '_mt.h5').get_cells_by_mouse()
    signCellsReward = SignificantCells(fileStem + '_enc_reward.h5').get_cells_by_mouse()

    significantCellsSelectorDatatype[datatype] = dict([
        ('None', None),
        ('Maintenance', signCellsMaintenance),
        ('Reward', signCellsReward),
    ])

## 2.6 Are 'Maintenance Cells' more active than 'Encoding Cells'?

1. For each cell, compute phase-avg over time
2. For each cell, for each trial, compute ratio of ENC vs MT
3. For each cell, average the ratio over trials

$$R_{i,j,ENC} = \int_{ENC} R_{i,j}(t)dt$$
$$R_{i,j,MT} = \int_{MT} R_{i,j}(t)dt$$
$$\phi_{i,j} = \frac{R_{i,j,ENC}}{R_{i,j,MT}}$$
$$\bar \phi_{i} = \frac{1}{N_{trial}} \sum_j \phi_{i,j}$$


In [None]:
# Plot the ENC/MT ratio for the 'deconv' data type
# (presumably the per-cell ratio defined in the formulas above — confirm in phases_intervals).
phases_intervals.plot_ratio_enc_mt(dataDB, 'deconv')

# 3.1 Table - Discriminate Phases by Metric

In [None]:
%%time
dataDB.verbose = False

phases = ['Encoding', 'Maintenance', 'Retrieval']
#settings = {"serial" : True, "metricSettings" : {"metric" : num_non_zero_std}}
settings = {"serial" : True, "metricSettings" :{"max_lag" : 1}}
sweepDict = {
    #"mousename" : sorted(list(dataDB.mice)),
    "datatype": ["deconv"],#, "raw", "high", "deconv"],
    "performance": ["Correct", "Mistake", "All"],
    "direction": ["L", "R", "All"]
}

table_lib.table_discriminate_time(dataDB, sweepDict,
                                         {"phase" : phases},
                                         "mean",
                                         trgDimOrder="r",
                                         settings=settings, multiplexKey="mousename")

# 3.2 Violins - Metric by Phase and Interval

In [None]:
# Violin plots of the "mean" metric for raw and deconvolved data.
# Alternative metric settings:
#   settings = {"serial" : True, "metricSettings" : {"metric" : num_non_zero_std}}
settings = {
    "serial": True,
    "metricSettings": {"max_lag": 1},
}

for datatype in ('raw', 'deconv'):
    # Only the 'semiphase' split is active; 'interval' was used previously.
    for phaseType in ('semiphase',):
        print("datatype", datatype)
        # No significant-cell selector is applied here. To restrict the plot to
        # significant cells, iterate over
        #   significantCellsSelectorDatatype[datatype].items()
        # and pass signCellsSelector={signCellsName: signCells} instead of None.
        phases_intervals.plot_violins_by_phase(
            dataDB, datatype, phaseType, "mean", settings,
            haveWaiting=False,
            signCellsSelector=None,
        )