In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntProgress
from IPython.display import display

import statsmodels.api as sm
from statsmodels.formula.api import ols

# Append base directory
import os,sys,inspect
rootname = "pub-2020-exploratory-analysis"
# Locate the project root from the working directory. Notebook cells are not
# backed by a source file inside the project, so asking `inspect` for the file
# of the current frame is unreliable here (recent ipykernel versions report a
# temporary file outside the project tree). Jupyter starts the kernel in the
# notebook's directory by default, which is what the original idiom resolved to.
thispath = os.path.abspath(os.getcwd())
if rootname not in thispath:
    raise ValueError(
        "Cannot locate project root: '" + rootname + "' is not part of the "
        "working directory '" + thispath + "'"
    )
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
# Avoid piling up duplicate entries when the cell is re-run
if rootpath not in sys.path:
    sys.path.append(rootpath)
print("Appended root directory", rootpath)

from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.metric.metric import MetricCalculator
from mesostat.utils.hdf5_io import DataStorage
from mesostat.stat.anova import as_pandas, as_pandas_lst, anova_homebrew

from lib.gallerosalas.data_fc_db_raw import DataFCDatabase
# from lib.common.metric_helper import metric_by_session

%load_ext autoreload
%autoreload 2

In [None]:
# Parameters handed to the database constructor.
params = {}
# params['root_path_data']  = gui_fpath("Path to data collection",  './')
# Location of the raw data. The default is machine-specific; set the
# ROOT_PATH_DATA environment variable to point elsewhere without editing the notebook.
params['root_path_data'] = os.environ.get('ROOT_PATH_DATA', '/media/alyosha/Data/TE_data/yasirdata_raw/')
# params['root_path_data'] = '/media/aleksejs/DataHDD/work/data/yasir/yasirdata_raw'

In [None]:
# Database object built from `params`; later cells query it for mice, sessions and neuronal data
dataDB = DataFCDatabase(params)

In [None]:
# Storage handle for the given HDF5 file name.
# NOTE(review): `ds` is not referenced by any other cell visible in this notebook - confirm it is still needed
ds = DataStorage('gallerosalas_result_individual_region.h5')

In [None]:
# Metric calculator configured with serial=True and verbose=False.
# NOTE(review): `mc` is not referenced by any other cell visible in this notebook - confirm it is still needed
mc = MetricCalculator(serial=True, verbose=False)

In [None]:
# Quick overview of what the database contains
print('mice', dataDB.mice)
print('nSessions', len(dataDB.sessions))
print('datatypes', dataDB.get_data_types())
# Channel count is queried for one hard-coded mouse only ('mou_5')
print('nChannel', dataDB.get_nchannels('mou_5'))

# Analysis of Variance

* Across sessions
    - Explained by performance
* Across channels, trials, timesteps
    - Explained by trial type

Things to understand:
* How to compare different rows?
* What models make sense?
* Try linear mixed models?

In [None]:
# Factor levels used to slice the data
#trialTypeNames = dataDB.get_trial_type_names()
trialTypeNames = ['Hit', 'CR', 'Miss', 'FA']
intervNames = ['TEX', 'DEL', 'REW']

# dfDict[datatype][mousename] is one long-format DataFrame per mouse, with the
# columns produced by `as_pandas` plus 'trialType', 'interval' and 'session'.
dfDict = {'bn_session': {}, 'bn_trial': {}}
for datatype in dfDict.keys():
    for mousename in dataDB.mice:
        sessions = dataDB.get_sessions(mousename)

        # Collect the pieces and concatenate once: `DataFrame.append` was removed
        # in pandas 2.0 and growing a frame inside a loop is quadratic.
        dfLst = []
        for session in sessions:
            for trialType in trialTypeNames:
                for intervName in intervNames:
                    # The REW interval is skipped for mou_6
                    if (mousename == 'mou_6') and (intervName == 'REW'):
                        continue

                    data = dataDB.get_neuro_data({'session' : session}, datatype=datatype,
                                                 trialType=trialType, intervName=intervName)[0]
                    # presumably data is (trials, timesteps, channels) - TODO confirm
                    data = np.mean(data, axis=1)  # Average over timesteps
                    dataDF = as_pandas(data, ('trials', 'channels'))
                    dataDF['trialType'] = trialType
                    dataDF['interval'] = intervName
                    dataDF['session'] = session
                    dfLst.append(dataDF)

        if dfLst:
            # Trial index is not used as a factor in the models below
            dfThis = pd.concat(dfLst, ignore_index=True).drop('trials', axis=1)
        else:
            # No data for this mouse: keep an empty frame instead of failing on the drop
            dfThis = pd.DataFrame()
        dfDict[datatype][mousename] = dfThis

In [None]:
def pandas_display_2digit(arg):
    """Display `arg` with pandas floats rendered with two decimals and thousands separators.

    The pandas 'display.float_format' option is changed only for the duration
    of the call; the previous value is restored afterwards, also when
    `display` raises.

    :param arg: Object to show via `display` (typically a DataFrame)
    """
    with pd.option_context('display.float_format', "{:,.2f}".format):
        display(arg)

In [None]:
# Formula for the per-mouse linear model: main effects of channel, trial type,
# interval and session, plus the listed pairwise interaction terms.
model = '''
    rez ~ C(channels)
    + C(trialType)
    + C(interval)
    + C(session)
    + C(trialType)*C(session)
    + C(trialType)*C(channels)
    + C(interval)*C(channels)
    + C(interval)*C(trialType)
'''

# Session-wide
# Fit one model per (datatype, mouse) and show its ANOVA table (typ=1) without the mean_sq column
for datatype, dfByMouse in dfDict.items():
    for mousename in sorted(dataDB.mice):
        print(datatype, mousename)
        linModel = ols(model, data=dfByMouse[mousename]).fit()
        dfRez = sm.stats.anova_lm(linModel, typ=1).drop(columns='mean_sq')
        pandas_display_2digit(dfRez)

### Mouse-cumulative model

In [None]:
# dfDictCumulative[datatype] pools all mice into a single DataFrame: the
# 'session' column is dropped and a 'mousename' column is added instead.
dfDictCumulative = {}
for datatype in dfDict.keys():
    # Collect per-mouse frames and concatenate once: `DataFrame.append` was
    # removed in pandas 2.0 and repeated appending is quadratic.
    dfLst = []
    for mousename, dfThis in dfDict[datatype].items():
        dfNew = dfThis.drop(['session'], axis=1)
        dfNew['mousename'] = mousename
        dfLst.append(dfNew)

    # Row labels of the per-mouse frames are kept, as the original append did
    dfDictCumulative[datatype] = pd.concat(dfLst) if dfLst else pd.DataFrame()

In [None]:
# dfRezDict[datatype] holds the ANOVA table of the pooled (all-mice) model,
# extended with the fraction of the total sum of squares carried by each term.
dfRezDict = {}
for datatype, dfCumulative in dfDictCumulative.items():
    print(datatype)

    model = '''
        rez ~ C(channels)
        + C(trialType)
        + C(interval)
        + C(mousename)
        + C(mousename)*C(channels)
        + C(mousename)*C(trialType)
        + C(mousename)*C(interval)
        + C(trialType)*C(channels)
        + C(interval)*C(channels)
        + C(interval)*C(trialType)
    '''

    linModel = ols(model, data=dfCumulative).fit()
    dfRez = sm.stats.anova_lm(linModel, typ=1).drop(columns='mean_sq')

    # Share of the total sum of squares per term; the total still includes
    # the last (residual) row at this point
    sumSqTotal = np.sum(dfRez['sum_sq'])
    dfRez['r2'] = dfRez['sum_sq'] / sumSqTotal

    # Turn the term names from row labels into an 'index' column
    dfRez = dfRez.reset_index()

    # Remove the last row (residual) so only model terms remain
    dfRez = dfRez.drop(index=dfRez.index[-1])
    pandas_display_2digit(dfRez)

    dfRezDict[datatype] = dfRez

In [None]:
import seaborn as sns

In [None]:
# One bar chart per datatype: relative sum of squares (r2) of every model term
for datatype, dfRez in dfRezDict.items():
    # Total fraction of the sum of squares carried by the model terms
    print(datatype, np.sum(dfRez['r2']))

    fig = plt.figure()
    fig.suptitle(datatype)
    axBar = sns.barplot(data=dfRez, x='index', y='r2')
    plt.xticks(rotation=90)
    fig.tight_layout()
    fig.savefig(f'anova_{datatype}.svg')

    # NOTE(review): the y-limit is applied after the figure has been saved, so
    # the SVG keeps the automatic limits and only the displayed figure is
    # clipped to [0, 0.07] - confirm this is intended
    axBar.set_ylim(0, 0.07)
    plt.show()

### Across-sessions

In [None]:
# Per-session ANOVA: for every datatype (rows) and mouse (columns), fit a
# main-effects model to each session separately and plot how the mean square
# of each factor evolves over sessions.
datatypes = list(dfDict.keys())
mice = sorted(dataDB.mice)
nMice = len(mice)
nDatatypes = len(datatypes)

# squeeze=False keeps `ax` two-dimensional even for a single mouse or datatype
fig, ax = plt.subplots(nrows=nDatatypes, ncols=nMice,
                       figsize=(5*nMice, 5*nDatatypes), squeeze=False)

model = '''
rez ~ C(channels)+C(trialType)+C(interval)
'''

# Labels of the model terms, in the order they appear in the formula
names = ['channels', 'trialType', 'interval']

for iDatatype, datatype in enumerate(datatypes):
    for iMouse, mousename in enumerate(mice):
        # dfDict is keyed by datatype first, then by mouse
        dfThis = dfDict[datatype][mousename]
        print(datatype, mousename)
        sessions = dataDB.get_sessions(mousename)

        plotData = []
        for session in sessions:
            dfSession = dfThis[dfThis['session'] == session]
            linModel = ols(model, data=dfSession).fit()
            rezStat = sm.stats.anova_lm(linModel, typ=1)
            rezStat = rezStat.drop('Residual')
            plotData += [np.array(rezStat['mean_sq'])]

        # Transpose to one row per model term, one column per session
        plotData = np.array(plotData).T

        axThis = ax[iDatatype, iMouse]
        axThis.set_title(datatype + ' ' + mousename)
        axThis.set_ylabel('mean_sq')
        for name, x in zip(names, plotData):
            axThis.semilogy(x, label=name)

        axThis.set_xticks(np.arange(len(sessions)))
        axThis.set_xticklabels(sessions, rotation=90)
        axThis.legend()

# Make sure the output directory exists before saving
os.makedirs('pics', exist_ok=True)
plt.savefig('pics/ANOVA_by_session.png')
plt.show()
plt.close()
plt.close()