In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntProgress
from IPython.display import display

import statsmodels.api as sm
from statsmodels.formula.api import ols

# Append base directory
import os,sys,inspect
rootname = "pub-2020-exploratory-analysis"

# Candidate locations of this notebook. inspect.getfile(currentframe()) only
# resolves to the notebook folder in older IPython, where it returns a relative
# '<ipython-input-...>' name that abspath() joins onto the working directory.
# Newer kernels (e.g. ipykernel 6+) return a temporary file such as
# '/tmp/ipykernel_<pid>/<hash>.py', which never contains rootname, so the
# kernel's working directory is tried first and the frame file is a fallback
# (useful when this code is run as a plain script).
candidatePaths = [
    os.getcwd(),
    os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))),
]
thispath = next((path for path in candidatePaths if rootname in path), None)
if thispath is None:
    raise ValueError("Could not find '{}' in any of {}; start the notebook from inside the repository".format(
        rootname, candidatePaths))

rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
if rootpath not in sys.path:  # avoid growing sys.path every time the cell is re-run
    sys.path.append(rootpath)
print("Appended root directory", rootpath)

from mesostat.utils.qt_helper import gui_fnames, gui_fpath

from lib.sych.data_fc_db_raw import DataFCDatabase
import lib.analysis.model_based_analysis as mba

%load_ext autoreload
%autoreload 2

In [None]:
# Root folder of the preprocessed data. Set the SYCH_DATA_PATH environment
# variable to point elsewhere instead of editing this cell; the default is the
# original author's local path. To pick the folder interactively instead, use
# gui_fpath('h5path', './').
params = {}
params['root_path_data'] = os.environ.get(
    'SYCH_DATA_PATH', '/media/alyosha/Data/TE_data/yarodata/sych_preprocessed')

In [None]:
# Database object over the preprocessed recordings; presumably located via
# params['root_path_data'] (confirm in DataFCDatabase)
dataDB = DataFCDatabase(
    params,
)

In [None]:
# Quick summary of the database contents: mice, data types, trial-type names
for dbAttribute in (dataDB.mice, dataDB.dataTypes, dataDB.trialTypeNames):
    print(dbAttribute)

# Model-Based Analysis

* Precond:
    - Drop 2nd order poly

* Linear models
    - AR(1)-Ridge
    - MAR(1)-Ridge
    - HAR-Ridge
    - Phases (PRE/TEX/REW/NONE)
    - HAR-Reward (Last, Last 3, Last 10)
    - Behaviour Fitting (Whisk/Lick)

* Nonlinear Models
    
* Hidden Variable Models
    - Consider accumulator models such as the HGF, or something even simpler

* Performance measures
    - L2
    - AIC/BIC/BF
    - R2/Cross-validation
    - Cross-temporal-correlation across sessions
    
**TODO**:
* Find or compute the non-selected trials (too long or too short)
* Preprocess session data by setting non-selected trials to NaN
* Write function to extract all timepoints of relevant trials

In [None]:
# Sessions available for mouse 'mvg_4' (shown as the cell output)
sessions = dataDB.get_sessions(
    'mvg_4'
)
sessions

## Preprocessing

In [None]:
# Load one example session to develop the preprocessing steps on
(data, trialStartIdxs, interTrialStartIdxs,
 fps, trialTypes) = dataDB.get_data_raw('mvg_4_2017_11_10_a')

In [None]:
# Keep only the first 48 channels (columns)
data = data[:, 0:48]

In [None]:
# Find trials with optogenetic manipulation
# Flag trials with optogenetic manipulation; used below as a boolean mask
# over trials (it is negated with ~)
trialIdxsOptogen = mba.optogen_trial_idxs(
    data, trialStartIdxs, fps,
)

In [None]:
# Start indices of the trials kept for analysis: every trial without optogenetics
keepTrialMask = ~trialIdxsOptogen
trialStartIdxsSelected = trialStartIdxs[keepTrialMask]

In [None]:
# Set undesirable trials to NAN
# Blank out the optogenetically manipulated trials with NaN so they do not
# contribute to the fits below
data = mba.set_trials_nan(
    data,
    trialIdxsOptogen,
    trialStartIdxs,
    fps,
)

In [None]:
# One timestamp per frame (presumably seconds, assuming fps is frames/second),
# plus the channel names used for plot labels
nTimeStep = data.shape[0]
times = np.arange(nTimeStep) / fps
labels = dataDB.get_channel_labels()

In [None]:
# Polynomial trend fit over the whole session (third argument 3; check dff_poly
# for whether this is the degree). dataFitted is the trend, dataDFF presumably
# the signal normalised by it (dF/F).
dataFitted, dataDFF = mba.dff_poly(
    times, data, 3
)

In [None]:
%matplotlib notebook
mba.plot_fitted_data(times, data, dataFitted, dataDFF, 34, labels)

In [None]:
%matplotlib inline
for iCh in range(48):
    mba.plot_fitted_data(times, data, dataFitted, dataDFF, iCh, labels)

## 0. Baseline

In [None]:
# Prediction targets: the normalised signal at the timesteps belonging to the
# selected (non-optogenetic) trials, as returned by get_trial_timestep_indices
idxsTrg = mba.get_trial_timestep_indices(
    trialStartIdxsSelected, fps
)
dataTrg = dataDFF[idxsTrg]

In [None]:
%matplotlib inline
mba.plot_rmse_bychannel(dataTrg, {'baseline': np.zeros(dataTrg.shape)})

## 1. AR(1) Model

In [None]:
# AR(1): sources come from get_src_ar1 (presumably each target's previous
# timestep); the fit is done per channel (see fit_predict_bychannel)
dataSrcAr1 = mba.get_src_ar1(dataDFF, idxsTrg)
dataTrgAr1 = mba.fit_predict_bychannel(
    dataSrcAr1, dataTrg
)

In [None]:
# Per-channel error of the AR(1) predictions (linear scale)
predictionsAr1 = {'ar(1)': dataTrgAr1}
mba.plot_rmse_bychannel(dataTrg, predictionsAr1, haveLog=False)

In [None]:
# Overlay the AR(1) prediction on its source signal for one example channel,
# over the first target samples. The time axis uses the session's own frame
# rate (fps) rather than a hard-coded 20, and is clipped to the number of
# available samples.
# NOTE(review): the 'raw' curve is the AR(1) source (dataSrcAr1), not the
# target dataTrg; confirm which series the comparison is meant to show.
iChExample = 37
nStepPlot = min(200, len(dataSrcAr1))
tPlot = np.arange(nStepPlot) / fps  # samples are concatenated across trials

fig, ax = plt.subplots()
ax.plot(tPlot, dataSrcAr1[:nStepPlot, iChExample], label='raw')
ax.plot(tPlot, dataTrgAr1[:nStepPlot, iChExample], label='ar(1)')
ax.set_title('AR(1) fit, channel {}'.format(iChExample))
ax.set_xlabel('time, s (concatenated trial samples)')
ax.legend()
plt.show()

## 2. HAR Model

In [None]:
# HAR sources; the list is forwarded to get_src_har unchanged (presumably
# averaging window lengths in timesteps; confirm in get_src_har)
harWindows = [160, 3 * 160]
dataSrcHAR = mba.get_src_har(dataDFF, idxsTrg, harWindows)

In [None]:
# Fit the HAR model per channel and compare its error with AR(1) (log scale)
dataTrgHAR = mba.fit_predict_bychannel(dataSrcHAR, dataTrg, alpha=0.01)

predictionsHAR = {'ar(1)': dataTrgAr1, 'HAR': dataTrgHAR}
mba.plot_rmse_bychannel(dataTrg, predictionsHAR, haveLog=True)

## 3. MAR(1) Model

In [None]:
# MAR(1): multivariate fit on the AR(1) sources of all channels. The third
# argument is passed positionally; presumably a regularisation strength
# (compare alpha in the HAR fit); confirm in fit_predict_multivar_bychannel.
regMAR1 = 0.0001
dataTrgMAR1 = mba.fit_predict_multivar_bychannel(dataSrcAr1, dataTrg, regMAR1)

In [None]:
# Compare AR(1) and MAR(1) per-channel errors (log scale)
predictionsMAR1 = {'ar(1)': dataTrgAr1, 'mar(1)': dataTrgMAR1}
mba.plot_rmse_bychannel(dataTrg, predictionsMAR1, haveLog=True)

In [None]:
from mesostat.utils.pandas_helper import pd_append_row

In [None]:
# Results table: metadata columns followed by one column per channel label
labels = dataDB.get_channel_labels()
columnsMeta = ['mousename', 'session', 'method']
df = pd.DataFrame(columns=columnsMeta + labels)

In [None]:
def _rel_rms(dataPred, dataTrue, rmsTrue):
    """Per-channel RMS of the prediction error, divided by the target's own RMS."""
    return mba.rms(dataPred - dataTrue, axis=0) / rmsTrue


# Fit AR(1), HAR and MAR(1) on every session not yet in df and append one row
# per model holding its per-channel relative RMS error.
for mousename in dataDB.mice:
    for session in dataDB.get_sessions(mousename):
        if session in list(df['session']):
            continue  # already fitted: re-running the cell resumes where it stopped
        print(mousename, session)

        # Load the session and keep the first 48 channels
        data, trialStartIdxs, interTrialStartIdxs, fps, trialTypes = dataDB.get_data_raw(session)
        data = data[:, :48]

        # Exclude optogenetically manipulated trials: keep the other trial
        # starts and set the manipulated trials to NaN
        trialIdxsOptogen = mba.optogen_trial_idxs(data, trialStartIdxs, fps)
        trialStartIdxsSelected = trialStartIdxs[~trialIdxsOptogen]
        data = mba.set_trials_nan(data, trialIdxsOptogen, trialStartIdxs, fps)

        # Polynomial trend fit over the whole session
        nTimeStep = len(data)
        times = np.arange(nTimeStep) / fps
        dataFitted, dataDFF = mba.dff_poly(times, data, 3)

        # Targets: normalised samples of the selected trials, and their RMS as the error scale
        idxsTrg = mba.get_trial_timestep_indices(trialStartIdxsSelected, fps)
        dataTrg = dataDFF[idxsTrg]
        L2 = mba.rms(dataTrg, axis=0)

        print('AR(1)')
        dataSrcAr1 = mba.get_src_ar1(dataDFF, idxsTrg)
        dataTrgAr1 = mba.fit_predict_bychannel(dataSrcAr1, dataTrg)
        relRmsAr1 = _rel_rms(dataTrgAr1, dataTrg, L2)
        df = pd_append_row(df, [mousename, session, 'ar(1)'] + list(relRmsAr1))

        print('HAR')
        dataSrcHAR = mba.get_src_har(dataDFF, idxsTrg, [160, 3*160])
        dataTrgHAR = mba.fit_predict_bychannel(dataSrcHAR, dataTrg, alpha=0.01)
        relRmsHAR = _rel_rms(dataTrgHAR, dataTrg, L2)
        df = pd_append_row(df, [mousename, session, 'har'] + list(relRmsHAR))

        print('MAR(1)')
        dataTrgMAR1 = mba.fit_predict_multivar_bychannel(dataSrcAr1, dataTrg, 0.0001)
        relRmsMAR1 = _rel_rms(dataTrgMAR1, dataTrg, L2)
        df = pd_append_row(df, [mousename, session, 'mar(1)'] + list(relRmsMAR1))

In [None]:
# Results table split by trial type: metadata columns, then one column per channel label
labels = dataDB.get_channel_labels()
columnsMeta = ['mousename', 'session', 'trialType', 'method']
df = pd.DataFrame(columns=columnsMeta + labels)

In [None]:
# Minimum number of selected trials for a (session, trial type) pair to be fitted
MIN_TRIALS_FIT = 40

# Fit AR(1), HAR and MAR(1) per session and per trial type (None = all trial
# types) and append one row per model with its per-channel relative RMS error.
for mousename in dataDB.mice:
    for session in dataDB.get_sessions(mousename):
        # Sessions already present in df are skipped so an interrupted run can resume.
        # NOTE(review): a session interrupted part-way through its trial types is
        # also skipped and stays incomplete.
        if session in list(df['session']):
            continue

        data, trialStartIdxs, interTrialStartIdxs, fps, trialTypes = dataDB.get_data_raw(session)
        data = data[:, :48]

        # Find trials with optogenetic manipulation
        trialIdxsOptogen = mba.optogen_trial_idxs(data, trialStartIdxs, fps)

        # NaN-masking and the whole-session trend fit depend only on the session,
        # not on the trial type. They are computed once per session, lazily, the
        # first time a trial type has enough trials. This assumes set_trials_nan is
        # idempotent for repeated calls with the same arguments.
        dataDFF = None

        for trialType in [None, 'iGO', 'iNOGO', 'iMISS', 'iFA']:
            if trialType is None:
                trialIdxsType = np.ones(len(trialTypes), dtype=bool)  # all trial types
            else:
                trialIdxsType = trialTypes == trialType
            trialIdxsSelected = np.logical_and(trialIdxsType, ~trialIdxsOptogen)

            nTrialThis = np.sum(trialIdxsSelected)
            print(mousename, session, trialType, nTrialThis)
            if nTrialThis < MIN_TRIALS_FIT:
                print('-- Too few trials, skipping')
                continue

            trialStartIdxsSelected = trialStartIdxs[trialIdxsSelected]

            if dataDFF is None:
                # Set undesirable trials to NAN
                data = mba.set_trials_nan(data, trialIdxsOptogen, trialStartIdxs, fps)

                nTimeStep = len(data)
                times = np.arange(nTimeStep) / fps

                dataFitted, dataDFF = mba.dff_poly(times, data, 3)

            idxsTrg = mba.get_trial_timestep_indices(trialStartIdxsSelected, fps)
            dataTrg = dataDFF[idxsTrg]

            # RMS of the target per channel; errors below are reported relative to it
            L2 = mba.rms(dataTrg, axis=0)

            # AR(1)
            print('AR(1)')
            dataSrcAr1 = mba.get_src_ar1(dataDFF, idxsTrg)
            dataTrgAr1 = mba.fit_predict_bychannel(dataSrcAr1, dataTrg)
            relRmsAr1 = mba.rms(dataTrgAr1 - dataTrg, axis=0) / L2
            df = pd_append_row(df, [mousename, session, trialType, 'ar(1)'] + list(relRmsAr1))

            print('HAR')
            dataSrcHAR = mba.get_src_har(dataDFF, idxsTrg, [160, 3*160])
            dataTrgHAR = mba.fit_predict_bychannel(dataSrcHAR, dataTrg, alpha=0.01)
            relRmsHAR = mba.rms(dataTrgHAR - dataTrg, axis=0) / L2
            df = pd_append_row(df, [mousename, session, trialType, 'har'] + list(relRmsHAR))

            print('MAR(1)')
            dataTrgMAR1 = mba.fit_predict_multivar_bychannel(dataSrcAr1, dataTrg, 0.0001)
            relRmsMAR1 = mba.rms(dataTrgMAR1 - dataTrg, axis=0) / L2
            df = pd_append_row(df, [mousename, session, trialType, 'mar(1)'] + list(relRmsMAR1))

In [None]:
# Persist the fitted relative errors so the plots below can be redone without refitting
df.to_hdf('model_fitting_l2.h5', key='ar_har_mar')

In [None]:
# Show the results table
display(df)

In [None]:
# Reload previously saved fitting results (lets the plots run without refitting)
df = pd.read_hdf('model_fitting_l2.h5', key='ar_har_mar')

In [None]:
# Per-channel result columns = every column that is not metadata
colsMetaSet = {'mousename', 'session', 'trialType', 'method'}
colsData = [colName for colName in df.columns if colName not in colsMetaSet]

In [None]:
from mesostat.utils.pandas_helper import pd_query
import seaborn as sns

In [None]:
# Relative RMS error of each model per channel (mean +/- std across sessions):
# one figure per trial type, one panel per mouse.
colsMetaPlot = ['mousename', 'session', 'trialType', 'method']

# Rows fitted on all trials carry trialType=None. groupby() silently drops
# missing keys by default, so give them an explicit label to keep them.
dfPlot = df.assign(trialType=df['trialType'].fillna('all'))

# Fixed mouse -> panel mapping, so each mouse keeps the same panel in every
# figure even when it has no rows for some trial type (replaces hard-coded ncols=4)
mouseNames = sorted(dfPlot['mousename'].unique())
nMouse = len(mouseNames)

# Scalar keys (not one-element lists) so group keys are plain values, not tuples
for trialType, dfTT in dfPlot.groupby('trialType'):
    # squeeze=False keeps ax 2-D even when there is a single mouse
    fig, ax = plt.subplots(ncols=nMouse, figsize=(4*nMouse, 4), tight_layout=True, squeeze=False)
    fig.suptitle(trialType)

    for iMouse, mousename in enumerate(mouseNames):
        axMouse = ax[0, iMouse]
        dfMouse = dfTT[dfTT['mousename'] == mousename]
        for method, dfMethod in dfMouse.groupby('method'):
            # The results frame was created empty, so its columns may hold object
            # dtype; cast explicitly so the numpy reductions below work
            dataRez = dfMethod.drop(colsMetaPlot, axis=1).to_numpy(dtype=float).T  # (channel, session)
            dataMu = np.mean(dataRez, axis=1)
            dataStd = np.std(dataRez, axis=1)

            x = np.arange(dataRez.shape[0])
            axMouse.plot(x, dataMu, label=method)
            axMouse.fill_between(x, dataMu-dataStd, dataMu+dataStd, alpha=0.2)

        axMouse.set_title(mousename)
        axMouse.set_xlabel('channel')
        axMouse.set_ylabel('relative RMS error')
        if len(dfMouse) > 0:
            axMouse.legend()

    plt.show()
    
    
    

# for key1, dataKey in df.groupby(['session', 'trialType']):
#     print(key1)
#     for idx, row in dataKey.iterrows():
#         print('--', row['method'], np.round(np.mean(row[colsData]), 2))