In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntProgress
from IPython.display import display

# Append base directory
import os,sys,inspect
rootname = "pub-2020-exploratory-analysis"
thispath = os.getcwd()
# Alternative (commented out): derive the path from the executing file instead of the cwd
# thispath = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
if rootname not in thispath:
    # Fail with an actionable message instead of str.index's "substring not found"
    raise ValueError("Working directory '%s' is not inside the '%s' project" % (thispath, rootname))
# Root = cwd truncated right after the first occurrence of rootname
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
# Guard so that re-running this cell does not keep appending duplicate sys.path entries
if rootpath not in sys.path:
    sys.path.append(rootpath)
    print("Appended root directory", rootpath)
else:
    print("Root directory already on sys.path", rootpath)

from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.metric.metric import MetricCalculator
from mesostat.utils.hdf5_io import DataStorage

from lib.gallerosalas.data_fc_db_aud_raw import DataFCDatabase
import lib.analysis.coactivity as coactivity
from lib.common.visualization import merge_image_sequence_movie

# Reload modified modules before executing each cell, so edits to the imported
# project libraries take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
params = {}
# Root directory of the raw auditory data collection. The default is machine-specific;
# set the YASIR_AUD_RAW_PATH environment variable to point elsewhere without editing the notebook.
# Other known locations:
#   '/home/alyosha/data/yasirdata_aud_raw/'
#   '/media/aleksejs/DataHDD/work/data/yasir/yasirdata_aud_raw'
# Interactive alternative: gui_fpath("Path to data collection", './')
params['root_path_data'] = os.environ.get('YASIR_AUD_RAW_PATH',
                                          '/media/alyosha/Data/TE_data/yasirdata_aud_raw/')

In [None]:
# Build the dataset handle from params (root_path_data set above); presumably indexes mice/sessions there
dataDB = DataFCDatabase(params)

In [None]:
# HDF5 results file: filled by compute_store_corr_mouse below and listed again in section 2
resultsFileName = 'gallerosalas_result_coactivity.h5'
ds = DataStorage(resultsFileName)

In [None]:
# Metric backend passed to all plotting helpers below: quiet, serial execution
mc = MetricCalculator(verbose=False, serial=True)

In [None]:
# Quick sanity summary of the loaded database
print('mice', dataDB.mice)
print('nSessions', len(dataDB.sessions))
print('datatypes', dataDB.get_data_types())
# NOTE(review): channel count is reported for mou_5 only — presumably identical for all mice; verify
print('nChannel', dataDB.get_nchannels('mou_5'))

In [None]:
from mesostat.utils.pandas_helper import pd_query, pd_is_one_row

In [None]:
# Group channel indices by the brain area their label maps to in channelAreasDF
areas = sorted(set(dataDB.channelAreasDF['Area']))
areaDict = {}
for area in areas:
    areaDict[area] = []

for chIdx, label in enumerate(dataDB.get_channel_labels()):
    # Look up the single channelAreasDF row whose original label matches this channel
    matchRow = pd_is_one_row(pd_query(dataDB.channelAreasDF, {'LOrig': label}))[1]
    areaDict[matchRow['Area']].append(chIdx)

fig, ax = plt.subplots(figsize=(4, 4))
dataDB.plot_area_clusters(fig, ax, areaDict, haveLegend=True)

# 1. Significance

## 1.1. Correlation plots


## 1.2. PCA exploration

## 1.3. Highly uncorrelated channels

In [None]:
# Parameter sweeps forwarded as **kwargs to the coactivity plotting functions.
# NOTE(review): 'None' is the *string* 'None', not the None object — presumably the
# label for "all trial types pooled"; confirm against lib.analysis.coactivity.
trialTypeSweep = ['None', 'Hit', 'CR', 'Miss', 'FA']

argSweepDict = dict(
    datatype=['bn_trial', 'bn_session'],
    trialType=list(trialTypeSweep),
    intervName='auto',
)

# Same sweep without the datatype axis, used by the plot_corr_mousephase_subpre calls
argSweepDictSubpre = dict(
    trialType=list(trialTypeSweep),
    intervName='auto',
)

In [None]:
# Parameter combinations excluded from computation and plots.
# bn_trial is baseline-normalized per trial; presumably the PRE interval is that baseline
# and so carries no information after normalization — confirm.
exclQueryLst = []
exclQueryLst.append({'datatype': 'bn_trial', 'intervName': 'PRE'})
# exclQueryLst.append({'mousename': 'mou_6', 'intervName': 'REW'})  # No reward for this mouse

In [None]:
# Compute per-mouse correlations for both normalizations and store them in ds.
# NOTE(review): skipExisting=False presumably recomputes entries already in the file on every run.
argSweepDictMouse = dict(datatype=['bn_trial', 'bn_session'], intervName='auto')

coactivity.compute_store_corr_mouse(
    dataDB, ds, dataDB.get_trial_type_names(),
    exclQueryLst=exclQueryLst,
    skipExisting=False,
    **argSweepDictMouse)

In [None]:
# Per-mouse correlation plots keyed by 'intervName', no PCA components removed.
# NOTE(review): channels 16 and 26 are always dropped — reason not documented here; confirm.
coactivity.plot_corr_mouse(dataDB, mc, 'corr', 'intervName',
                           exclQueryLst=exclQueryLst,
                           dropChannels=[16, 26],
                           nDropPCA=None,
                           haveMono=False,
                           haveBrain=True,
                           **argSweepDict)

In [None]:
# Per-mouse correlation plots keyed by 'trialType', no PCA components removed
coactivity.plot_corr_mouse(dataDB, mc, 'corr', 'trialType',
                           exclQueryLst=exclQueryLst,
                           dropChannels=[16, 26],
                           nDropPCA=None,
                           haveMono=False,
                           haveBrain=True,
                           **argSweepDict)

In [None]:
# Mouse/phase correlation plots using the sweep without the datatype axis
coactivity.plot_corr_mousephase_subpre(dataDB, mc, 'corr',
                                       exclQueryLst=exclQueryLst,
                                       dropChannels=[16, 26],
                                       nDropPCA=None,
                                       **argSweepDictSubpre)

In [None]:
# Mouse/phase correlation plots (nDropPCA left at the function's default)
coactivity.plot_corr_mousephase_submouse(dataDB, mc, 'corr',
                                         exclQueryLst=exclQueryLst,
                                         dropChannels=[16, 26],
                                         **argSweepDict)

**Drop the first principal component and explore the result**

In [None]:
# As above, keyed by 'intervName', with the first PCA component dropped (nDropPCA=1)
coactivity.plot_corr_mouse(dataDB, mc, 'corr', 'intervName',
                           exclQueryLst=exclQueryLst,
                           dropChannels=[16, 26],
                           nDropPCA=1,
                           haveMono=False,
                           haveBrain=True,
                           **argSweepDict)

In [None]:
# As above, keyed by 'trialType', with the first PCA component dropped (nDropPCA=1)
coactivity.plot_corr_mouse(dataDB, mc, 'corr', 'trialType',
                           exclQueryLst=exclQueryLst,
                           dropChannels=[16, 26],
                           nDropPCA=1,
                           haveMono=False,
                           haveBrain=True,
                           **argSweepDict)

In [None]:
# Mouse/phase plots (sweep without datatype axis) with the first PCA component dropped
coactivity.plot_corr_mousephase_subpre(dataDB, mc, 'corr',
                                       exclQueryLst=exclQueryLst,
                                       dropChannels=[16, 26],
                                       nDropPCA=1,
                                       **argSweepDictSubpre)

In [None]:
# Mouse/phase plots with the first PCA component dropped
coactivity.plot_corr_mousephase_submouse(dataDB, mc, 'corr',
                                         exclQueryLst=exclQueryLst,
                                         nDropPCA=1,
                                         dropChannels=[16, 26],
                                         **argSweepDict)

## 1.4. Plot correlation movies

In [None]:
# Render correlation frames (mouse x trial type) for both datatypes; the cells below merge them.
# nDropPCA is left at its default (the 'dropPCA_None' prefixes below suggest that default is None).
coactivity.plot_corr_movie_mousetrialtype(
    dataDB, mc, 'corr',
    datatype=['bn_trial', 'bn_session'],
    trialType='auto',
    haveDelay=True,
    exclQueryLst=exclQueryLst)

In [None]:
# Merge the bn_trial frames into a movie.
# NOTE(review): 0..160 is presumably the frame index range, and deleteSrc=True
# presumably removes the source PNGs afterwards — confirm.
moviePrefix = 'corr_mouseTrialType_dropPCA_None_bn_trial_'
merge_image_sequence_movie(moviePrefix, '.png', 0, 160,
                           deleteSrc=True, trgPathName=None)

In [None]:
# Merge the bn_session frames into a movie (frame range 0..160; source PNGs presumably deleted)
moviePrefix = 'corr_mouseTrialType_dropPCA_None_bn_session_'
merge_image_sequence_movie(moviePrefix, '.png', 0, 160,
                           deleteSrc=True, trgPathName=None)

In [None]:
# Same frames as the earlier movie cell, but with the first PCA component dropped (nDropPCA=1)
coactivity.plot_corr_movie_mousetrialtype(
    dataDB, mc, 'corr',
    datatype=['bn_trial', 'bn_session'],
    trialType='auto',
    nDropPCA=1,
    haveDelay=True,
    exclQueryLst=exclQueryLst)

In [None]:
# Merge the nDropPCA=1 bn_trial frames into a movie (frame range 0..160; source PNGs presumably deleted)
moviePrefix = 'corr_mouseTrialType_dropPCA_1_bn_trial_'
merge_image_sequence_movie(moviePrefix, '.png', 0, 160,
                           deleteSrc=True, trgPathName=None)

In [None]:
# Merge the nDropPCA=1 bn_session frames into a movie (frame range 0..160; source PNGs presumably deleted)
moviePrefix = 'corr_mouseTrialType_dropPCA_1_bn_session_'
merge_image_sequence_movie(moviePrefix, '.png', 0, 160,
                           deleteSrc=True, trgPathName=None)

# 2. Consistency

In [None]:
from mesostat.utils.pandas_helper import pd_append_row, pd_pivot, pd_is_one_row, pd_query, pd_first_row
def incr_row(row, incr):
    """Return `row` as a dict with `incr` appended to every key (e.g. 'mousename' -> 'mousename1')."""
    return {k + incr: v for k, v in dict(row).items()}


def merge_rows(r1, r2):
    """Combine two rows into one Series: keys of r1 suffixed '1', keys of r2 suffixed '2'."""
    return pd.Series({**incr_row(r1, '1'), **incr_row(r2, '2')})


# List stored datasets and keep the per-mouse correlation results
df = ds.list_dsets_pd()
dfMouse = pd_query(df, {'name': 'corr_mouse'})
display(dfMouse)

# Storage bookkeeping columns, not sweep parameters to compare on
metaCols = ['name', 'datetime', 'shape', 'dset']

# For each datatype, correlate every pair of stored results (flattened to vectors) and
# assemble an outer-product table: one row per pair, parameters suffixed '1'/'2', 'value' = correlation.
dfQuadDict = {}
for datatype in ['bn_session', 'bn_trial']:
    dfDataType = pd_query(dfMouse, {'datatype': datatype})
    print(datatype, len(dfDataType))

    # Load vectors from the filtered rows themselves so vector k always belongs to row k
    # of dfDataType. Indexing a list built over all of dfMouse by the position within
    # this subset would pair rows with data of a different datatype.
    rowParams = [row.drop(metaCols) for _, row in dfDataType.iterrows()]
    ccVecs = [ds.get_data(dset).flatten() for dset in dfDataType['dset']]

    records = []
    for r1, v1 in zip(rowParams, ccVecs):
        for r2, v2 in zip(rowParams, ccVecs):
            rnew = merge_rows(r1, r2)
            rnew['value'] = np.corrcoef(v1, v2)[0, 1]
            records.append(rnew)

    # Build the frame once: DataFrame.append in a loop is quadratic and was removed in pandas 2.0
    dfQuadDict[datatype] = pd.DataFrame(records)

In [None]:
from mesostat.visualization.mpl_matrix import plot_df_2D_outer_product

# One similarity matrix per datatype for mouse mou_5 compared against itself,
# rows/columns ordered by (interval, trial type); each figure is also saved as SVG.
for datatype in dfQuadDict:
    dfThis = dfQuadDict[datatype]

    fig, ax = plt.subplots()
    dfSameMouse = pd_query(dfThis, {'mousename1': 'mou_5', 'mousename2': 'mou_5'})
    intervOrder = {
        'intervName1': dataDB.get_interval_names(),
        'intervName2': dataDB.get_interval_names(),
    }
    plot_df_2D_outer_product(ax, dfSameMouse,
                             ['intervName1', 'trialType1'],
                             ['intervName2', 'trialType2'],
                             'value',
                             orderDict=intervOrder,
                             vmin=-1, vmax=1)

    plt.savefig(f'coactivity_similarity_{datatype}.svg')
    plt.show()

## 2.1. PCA consistency over mice
### 2.1.1. Angle-based consistency

Tasks
  * Explained variance by phase/session/mouse/trialType
     * Do not separate phases; it's meaningless. Compute PCA over all timesteps, then compare projection differences between phases
     * Implement HAC correction

  * Global PCA shifts vs session

Approaches:
  * Evaluate PCA over all data, select the strongest components, and plot the components as a function of each confound
  * Evaluate PCA separately for each confound, then compare the resulting PCAs
  
**Plots**:
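* Cosine-squared matrix $C^2_{ij} = \left(\sum_k R^{1}_{ik}R^{2}_{jk}\right)^2$, where $R^l$ is the PCA-transform of dataset $l$ (rows are principal components), so $C_{ij}$ is the cosine of the angle between component $i$ of the first basis and component $j$ of the second
* Consistency metric $E = \sum_{ij} e^1_i e^2_j C^2_{ij}$, where $e^l$ are the eigenvalues of dataset $l$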

**Problem**:
The consistency metric $E$ has all necessary ingredients (angles, eigenvalues), but it is not mathematically clear that it behaves the desired way. Solid theory is required for this metric to be useful.

**Alternative approach**:
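Try consistency metric $H\left(\frac{C^2_{ij}}{N}\right)$, where $H$ is the Shannon entropy and $N$ is the number of components. For two complete orthonormal bases, every row and column of $C^2$ sums to 1, so the entries of $C^2/N$ sum to 1 and form a probability distribution. Low entropy means each component of one basis couples to only a few components of the other, so this metric should be great at measuring the sparsity of basis coupling. The challenge is to include eigenvalue priority into this metric, since spread of weak eigenvalues is not as relevant as spread of strong ones.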

In [None]:
# PCA basis alignment between mice during the DEL interval, for each normalization
for dt in ('bn_session', 'bn_trial'):
    coactivity.plot_pca_alignment_bymouse(dataDB, intervName='DEL', trialType=None, datatype=dt)

### 2.1.2. Eigenvalue-based consistency

* Let $x_1$, $x_2$ be some datasets
* Let $R_1$, $R_2$ be the corresponding PCA-transforms 
* Find total variances
    - $V_1 = \sum_i eig_i(x_1) = tr(cov(x_1)) = \sum_i cov_{ii}(x_1)$
    - $V_2 = \sum_i eig_i(x_2) = tr(cov(x_2)) = \sum_i cov_{ii}(x_2)$
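* Find explained variances
    - $e^1 = eval(cov(x_1)) = diag(cov(R_1 x_1))$
    - $e^2 = eval(cov(x_2)) = diag(cov(R_2 x_2))$
* Find explained variances using the wrong bases
    - $e^{12} = diag(cov(R_2 x_1))$
    - $e^{21} = diag(cov(R_1 x_2))$
* Find representation errors in explained variance ratios
    - $\epsilon_1 = \frac{\sum_i |e^1_i - e^{12}_i|}{2 V_1}$
    - $\epsilon_2 = \frac{\sum_i |e^2_i - e^{21}_i|}{2 V_2}$
    - The factor 2 normalizes the error: since $R_2$ is orthonormal, $\sum_i e^{12}_i = tr(R_2\, cov(x_1) R_2^T) = V_1$, so the numerator is at most $2 V_1$ and $\epsilon_1 \in [0, 1]$ (likewise for $\epsilon_2$)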



* TODO: iter trialType=[hit, cr, all]
* TODO: iter perf=[naive,expert,all]

In [None]:
# Eigenvalue-based PCA consistency (section 2.1.2) with all components kept
coactivity.plot_pca_consistency(dataDB)

In [None]:
# Same as above; dropFirst=1 presumably excludes the first principal component — confirm
coactivity.plot_pca_consistency(dataDB, dropFirst=1)

## 2.2. PCA consistency over phases
### 2.2.1. Angle-based consistency

In [None]:
# PCA basis alignment between the TEX and REW intervals, for each normalization
for dt in ('bn_session', 'bn_trial'):
    coactivity.plot_pca_alignment_byphase(dataDB, trialType=None, datatype=dt, intervNames=['TEX', 'REW'])