In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntProgress
from IPython.display import display

# Append the project root directory to sys.path so that project packages (`lib.*`) are importable.
import os,sys,inspect
rootname = "pub-2020-exploratory-analysis"
# Resolve the project root from the kernel's working directory (normally the notebook's folder).
# `inspect.getfile(inspect.currentframe())` is unreliable in Jupyter: kernels may compile cells
# into temporary files outside the project tree, in which case `rootname` is not in that path.
thispath = os.getcwd()
if rootname not in thispath:
    raise RuntimeError(
        "Cannot locate project root '{}' in working directory '{}'; "
        "start the kernel from inside the project.".format(rootname, thispath))
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
# Append only once so that re-running this cell does not keep growing sys.path.
if rootpath not in sys.path:
    sys.path.append(rootpath)
print("Appended root directory", rootpath)

from mesostat.metric.metric import MetricCalculator
from mesostat.utils.qt_helper import gui_fname, gui_fnames, gui_fpath
from mesostat.utils.hdf5_io import DataStorage

from lib.sych.data_fc_db_raw import DataFCDatabase
import lib.analysis.pid as pid

%load_ext autoreload
%autoreload 2

In [None]:
# Data location for DataFCDatabase. Set the SYCH_DATA_ROOT environment variable to point
# elsewhere instead of editing this cell; the default is the path previously hard-coded here.
# Paths used on other machines / alternatives:
#   '/media/aleksejs/DataHDD/work/data/yaro/neuronal-raw-pooled'
#   gui_fpath('h5path', './')   # interactive folder picker
params = {}
params['root_path_data'] = os.environ.get('SYCH_DATA_ROOT',
                                          '/media/alyosha/Data/TE_data/yarodata/sych_preprocessed')

In [None]:
# Construct the project data database from the configuration in `params`
dataDB = DataFCDatabase(params)

In [None]:
# Display `dataDB.mice` (presumably the mouse names known to the database)
dataDB.mice

In [None]:
# HDF5 file name passed to the hypothesis PID compute and plot cells
h5outname = 'sych_result_higher_order_df.h5'

In [None]:
# Metric calculator shared by the PID cells; serial=False / nCore=4 presumably request parallel
# execution on 4 cores (tune nCore to the machine)
mc = MetricCalculator(serial=False, verbose=False, nCore=4)

# TODO
    
Hypothesis: chain inhibition should increase synergy
* Cpu -> iGP/GP/eGP -> VM/VL

Performance dependence
* Session-wise changes of redundancy/synergy as a function of performance
* Movement correlations of synergy/redundancy
    - Lick
    - Integral movement

# Hypotheses

In [None]:
# Each hypothesis is a tuple (interval, source regions, target regions).
# The interval is a trial-interval name such as "TEX" ("REW" is used further down).
# NOTE(review): H4 lists three sources; confirm that hypotheses_calc_pid supports more than two.
hypothesesDict = dict(
    # Feed-forward to prefrontal cortex
    H1_TEX=("TEX", ['M1_l', 'S1_bf'], ['PrL', 'LO', 'VO', 'M2', 'Cg1']),
    H1a_TEX=("TEX", ['S1_bf', 'S2'], ['PrL', 'LO', 'VO', 'M2', 'Cg1']),
    H1b_TEX=("TEX", ['M1_l', 'M2'], ['PrL', 'LO', 'VO', 'Cg1']),  # M2 is a source here, so it is not a target

    # Higher-order sensory/motor thalamus as target:
    # is there more synergy for (M1, S1) than for (M1, M2) or (S1, S2)?
    H2_TEX=("TEX", ['M1_l', 'S1_bf'], ['Po', 'VM']),
    H2a_TEX=("TEX", ['S1_bf', 'S2'], ['Po', 'VM']),
    H2b_TEX=("TEX", ['M1_l', 'M2'], ['Po', 'VM']),

    # Thalamus as source
    H3_TEX=("TEX", ['Po', 'VPM'], ['S1_bf', 'S2']),

    # Motor thalamus synchronization
    H4_TEX=("TEX", ['VM', 'VL', 'LDVL'], ['M1_l', 'M2']),
)

In [None]:
# Compute PID for every hypothesis on expert sessions / iGO trials, for both datatypes,
# presumably storing results in h5outname (read back by the plotting cell that follows).
# The nDropPCA=1 option is currently commented out.
pid.hypotheses_calc_pid(dataDB, mc, hypothesesDict, h5outname, #nDropPCA=1,
                        datatypes=['bn_session', 'bn_trial'], trialType='iGO', performance='expert')

In [None]:
# Plot the hypothesis PID results from h5outname ('bn_session' datatype only)
pid.hypotheses_plot_pid(dataDB, hypothesesDict, h5outname, datatypes=['bn_session'])

In [None]:
# Info3D for every hypothesis on expert sessions / iGO trials, with 4 bins.
# `intervDict` was passed here but is never defined in this notebook, so the cell raised a
# NameError on a fresh kernel. Each hypothesis tuple already carries its interval name, and the
# same function is called later in this notebook with only (dataDB, hypotheses).
pid.hypotheses_calc_plot_info3D(dataDB, hypothesesDict,
                                nBin=4, datatypes=['bn_session'], trialType='iGO', performance='expert')

# All-Distribution

**TODO**:
* [ ] Drop low-quality sessions
* [ ] Try composite p-values
* [+] For Info3D, drop PCA1
* [ ] Consider re-doing the analysis with PCA1 dropped
* Fraction of significant triplets per session
    * [+] Do regression on PID instead of Naive vs Expert
    * [ ] Binomial test: is the fraction of significant PIDs above chance?
    * [ ] Test whether the regression is explained by expert sessions having more iGO trials
* Most significant triplets
    * [ ] Plot p-value vs performance for the top 10 sessions

In [None]:
# Pre-computed multiregional PID results file. Alternatives used on other machines:
#   '/media/aleksejs/DataHDD/work/data/yaro/pid/sych_result_multiregional_pid_df3.h5'
#   gui_fname('h5path', './', '(*.h5)')   # interactive file picker
pwdAllH5 = os.path.join('/media/alyosha/Data/TE_data/yarodata/sych_preprocessed',
                        'sych_result_multiregional_pid_df3.h5')

In [None]:
# CDF of all stored PID results; minTrials=50 presumably excludes sessions with fewer than 50 trials
pid.plot_all_results_distribution(dataDB, pwdAllH5, plotstyle='cdf', minTrials=50)

In [None]:
# Fraction of significant triplets, per session (same minTrials=50 filter as above)
pid.plot_all_frac_significant_bysession(dataDB, pwdAllH5, minTrials=50)

In [None]:
# Scatter of the fraction of significant triplets against session performance.
# TODO: Linear fit + pval(H0: alpha=0)
pid.plot_all_frac_significant_performance_scatter(dataDB, pwdAllH5, minTrials=50)

TODO:
* Top 10 most synergistic connections
    - Compare rankings by magnitude, average p-value, and fraction of significant sessions
    - Add colorbars showing the fraction of sessions per mouse
* Top 10 regions most involved in synergy
    - Count the fraction of significant triplets in which the region is the target

In [None]:
# Summary table of the results stored in pwdAllH5; its 'datatype' and 'phase' columns are used
# for grouping in the cells below
summaryDF = pid.pid_all_summary_df(pwdAllH5)

In [None]:
# For every (datatype, phase) group of the summary, precompute the fraction of significant
# sessions for each triplet and PID component; results are keyed by the group label.
pidTypes = ['unique', 'red', 'syn']
mouseSignDict = {}
for groupLabel, dfGroup in summaryDF.groupby(['datatype', 'phase']):
    print(groupLabel)
    # The grouping columns are constant within a group, so strip them before passing on
    dfGroupValues = dfGroup.drop(columns=['datatype', 'phase'])
    mouseSignDict[groupLabel] = pid._get_pid_sign_dict(
        dataDB, groupLabel, dfGroupValues, pwdAllH5, pidTypes,
        minTrials=50, trialType='iGO',
    )

In [None]:
# 1D projection: targets with the highest fraction of significant sessions, averaged over sources.
# mouseSignDict is keyed by the same (datatype, phase) groups, in the same order, as
# summaryDF.groupby, so iterate it directly instead of re-grouping (the group frames were unused).
for keyLabel, signDict in mouseSignDict.items():
    print(keyLabel)
    pid.plot_all_frac_significant_1D_top_n(dataDB, signDict, '_'.join(keyLabel), pidTypes, nTop=20)

In [None]:
# 2D projection: source pairs with the highest fraction of significant sessions, averaged over targets.
# Iterate the precomputed mouseSignDict (same keys and order as summaryDF.groupby) instead of
# re-grouping summaryDF, whose group frames were unused.
for keyLabel, signDict in mouseSignDict.items():
    print(keyLabel)
    pid.plot_all_frac_significant_2D_avg(dataDB, signDict, '_'.join(keyLabel), pidTypes)

In [None]:
# 3D projection: triplets with the highest fraction of significant sessions.
# Iterate the precomputed mouseSignDict (same keys and order as summaryDF.groupby) instead of
# re-grouping summaryDF, whose group frames were unused.
for keyLabel, signDict in mouseSignDict.items():
    print(keyLabel)
    pid.plot_all_frac_significant_3D_top_n(dataDB, signDict, '_'.join(keyLabel), pidTypes, nTop=20)

In [None]:
# Specific 2D projection: synergy fractions for all source pairs, given target 'VPL'.
# Iterate the precomputed mouseSignDict (same keys and order as summaryDF.groupby) instead of
# re-grouping summaryDF, whose group frames were unused.
for keyLabel, signDict in mouseSignDict.items():
    print(keyLabel)
    pid.plot_all_frac_significant_2D_by_target(dataDB, signDict, '_'.join(keyLabel), 'syn', 'VPL',
                                               vmax=1)

In [None]:
# Single ad-hoc hypothesis, (interval, sources, targets). Previously tried alternatives:
#   ("REW", ['VPL', 'DG_a'], ['VM'])
#   ("REW", ['Cpu_1', 'VPL'], ['VL'])
#   ("TEX", ['M2', 'S2'], ['VPL'])
hDict = dict(H_ALL=("REW", ['Rt', 'SuG'], ['Cpu']))

# Unlike the hypothesis cells above, no performance='expert' filter is applied here
pid.hypotheses_calc_plot_info3D(
    dataDB, hDict,
    datatypes=['bn_session', 'bn_trial'],
    trialType='iGO',
)

## All - Distribution - Nosession

In [None]:
import lib.analysis.pid_joint as pid_joint

In [None]:
# No-session PID results and the matching 'Rand' control file, both in the preprocessed data folder
pwdAllH5_2, pwdAllH5_2_Rand = (
    os.path.join('/media/alyosha/Data/TE_data/yarodata/sych_preprocessed', fname)
    for fname in ('sych_result_multiregional_pid_all_df.h5',
                  'sych_result_multiregional_pid_all_df_rand.h5')
)

In [None]:
# Summary table of the no-session PID results file
dfSummary = pid_joint.pid_all_summary_df(pwdAllH5_2)

In [None]:
# Summary table of the 'Rand' control results (parsed with parserName='Rand').
# NOTE(review): dfSummaryRand is not used by any cell in this notebook
dfSummaryRand = pid_joint.pid_all_summary_df(pwdAllH5_2_Rand, parserName='Rand')

In [None]:
# CDF plot of the no-session PID results
pid_joint.cdfplot(pwdAllH5_2, dfSummary)

In [None]:
# Statistical test on the average PID bits (see lib.analysis.pid_joint.test_avg_bits for the exact test)
pid_joint.test_avg_bits(dataDB, mc, pwdAllH5_2, dfSummary)

In [None]:
# Scatter of effect size vs PID bits
pid_joint.scatter_effsize_bits(pwdAllH5_2, dfSummary)

In [None]:
# Top-20 triplets (no-session results)
pid_joint.plot_triplets(dataDB, pwdAllH5_2, dfSummary, nTop=20)

In [None]:
# Top-20 single regions (no-session results)
pid_joint.plot_singlets(dataDB, pwdAllH5_2, dfSummary, nTop=20)

In [None]:
# 2D projection of the no-session results, averaged along axis 1 (avgAxis=1), no channels dropped.
# TODO: Unique: Sweep over pairs of (S1,T) vs S2, as opposed to (S1,S2) vs T
pid_joint.plot_2D_avg(dataDB, pwdAllH5_2, dfSummary, dropChannels=None, avgAxis=1)

In [None]:
# 2D source-pair map for target 'Rt'
pid_joint.plot_2D_bytarget(dataDB, pwdAllH5_2, dfSummary, 'Rt', dropChannels=None)

In [None]:
# Clustered synergy map for target 'VPL'.
# NOTE(review): channel index 21 is dropped here and in the consistency cells; document why (TODO)
pid_joint.plot_2D_bytarget_synergy_cluster(dataDB, pwdAllH5_2, dfSummary, 'VPL',
                                           dropChannels=[21], clusterParam=2, dropWeakChannelThr=0.1)

In [None]:
# Top-20 pairs by unique information
pid_joint.plot_unique_top_pairs(dataDB, pwdAllH5_2, dfSummary, nTop=20, dropChannels=None)

# Consistency

## 1. Across Mice

In [None]:
# Consistency of PID results across mice, separately for naive and expert sessions
# (same calls, in the same order, as the former copy-pasted pair).
# NOTE(review): channel index 21 is dropped without explanation (TODO document).
for performance in ['naive', 'expert']:
    pid_joint.plot_consistency_bymouse(pwdAllH5_2, dfSummary, dropChannels=[21], performance=performance,
                                       kind='fisher', limits=[0, 1])

In [None]:
# Consistency of PID results across trial phases, for naive and expert sessions and for both
# datatypes. The loop order reproduces the former four copy-pasted calls:
# bn_trial (naive, expert), then bn_session (naive, expert).
for datatype in ['bn_trial', 'bn_session']:
    for performance in ['naive', 'expert']:
        pid_joint.plot_consistency_byphase(pwdAllH5_2, dfSummary, dropChannels=[21], performance=performance,
                                           kind='fisher', limits=[0, 1], datatype=datatype)

In [None]:
# Consistency of PID results between iGO and iNOGO trials, for naive and expert sessions and for
# both datatypes. The loop order reproduces the former four copy-pasted calls:
# bn_trial (naive, expert), then bn_session (naive, expert).
for datatype in ['bn_trial', 'bn_session']:
    for performance in ['naive', 'expert']:
        pid_joint.plot_consistency_bytrialtype(pwdAllH5_2, dfSummary, dropChannels=[21],
                                               performance=performance, datatype=datatype,
                                               trialTypes=['iGO', 'iNOGO'], kind='fisher',
                                               fisherThr=0.1, limits=[0, 1])

## Sanity check: does the PID match the pre-computed results?

In [None]:
# Recompute PID for one (datatype, interval, trial type, mouse) combination so the result can be
# compared by eye against the pre-computed values.
# This was previously four nested loops that each `break` after their first iteration, so only the
# first element of every list was ever used; that combination is now selected explicitly.
# NOTE(review): `intervNames` is not defined in any cell above. Define it (presumably a list of
# interval names) before running this cell. TODO move it to the config cell.
datatype = 'bn_session'
intervName = next(iter(intervNames))
trialType = None
mousename = 'mvg_4'  # alternatively sweep over sorted(dataDB.mice)
print(datatype, intervName, trialType, mousename)

channelNames = dataDB.get_channel_labels(mousename)
dataLst = dataDB.get_neuro_data({'mousename': mousename}, datatype=datatype,
                                trialType=trialType,
                                zscoreDim='rs', intervName=intervName)

display(pid.pid(dataLst, mc, channelNames, ['S1_bf', 'VPL'], ['LGP'], nPerm=100, nBin=4))