In [None]:
# Standard libraries
from copy import deepcopy

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, IntProgress
from IPython.display import display

from sklearn.linear_model import RidgeClassifier, LogisticRegression
from sklearn.neural_network import MLPClassifier

# Append base directory
import os,sys,inspect
# Locate the repository root by finding its directory name inside the current
# working directory, then append it to sys.path (presumably so that the
# `src.lib` imports further down resolve - confirm).
# NOTE: `str.index` raises ValueError if the notebook is started from a
# directory whose path does not contain `rootname`.
rootname = "chernysheva-tmaze-analysis-2020"
thispath = os.getcwd()
# thispath = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
sys.path.append(rootpath)
print("Appended root directory", rootpath)

# User libraries
from mesostat.stat.classification import binary_classifier
# from mesostat.stat.connectomics import offdiag_1D
from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from mesostat.utils.pandas_helper import merge_df_from_dict
from mesostat.utils.arrays import numpy_merge_dimensions
from mesostat.visualization.mpl_matrix import imshow
from mesostat.stat.testing.quantity_test import test_quantity

from src.lib.data_db import BehaviouralNeuronalDatabase
import src.lib.plots_lib as plots_lib
import src.lib.table_lib as table_lib
import src.lib.plots_pca as plots_pca
from src.lib.metric_wrapper import metric_by_selector
from src.lib.extra_metrics import num_non_zero_std, num_sample
from src.lib.significant_cells_lib import SignificantCells

%load_ext autoreload
%autoreload 2

In [None]:
# Path settings that are passed to BehaviouralNeuronalDatabase.
# NOTE(review): `prefix` is a machine-specific absolute path; adjust it (or
# enable one of the commented alternatives) before running elsewhere.
# tmp_path = root_path_data if 'root_path_data' in locals() else "./"
params = {}
# prefix = '/media/alyosha/Data/TE_data/mariadata/'
prefix = '/home/alyosha/data/maria/'
# prefix = '/media/aleksejs/DataHDD/work/data/maria/'

#params['root_path_data']  = gui_fpath("Path to data files", "./")
params['root_path_dff'] = prefix + 'dff/'
params['root_path_deconv'] = prefix + 'deconv/'

In [None]:
# Create the database object from the path settings collected in `params`.
dataDB = BehaviouralNeuronalDatabase(params)

In [None]:
# Load the neuronal recordings into the database
# (presumably from the dff/deconv paths in `params` - confirm in data_db).
dataDB.read_neuro_files()

In [None]:
# Load the behavioural data into the database (source location is decided
# inside BehaviouralNeuronalDatabase - not visible from this notebook).
dataDB.read_behavior_files()

# 1. Binary Classification

## 1.1 Phase-average classification

**Goal**: Train a classifier to discriminate between two conditions: direction (L/R) or performance (Correct/Mistake)

**Versions**:
* **4A. Mouse-wise**: Choose a metric that is independent of the number of samples; calculate it individually for each mouse
* **4B. All-mice**: Choose a metric that has a fixed shape per trial; calculate it for all mice

**Advantages**:
* Can theoretically make use of vector metrics, by studying their combinations

**Problems**:
* Severe overfitting: performance on the training set and on the test set differs dramatically.

**TODO**:
* Research regularization further. Try more sophisticated estimators (specific candidates not yet listed)
* Implement classification by phase

In [None]:
def get_classifier(name, C):
    """
    Construct an untrained scikit-learn classifier selected by a short name.

    :param name: One of 'LogL1', 'LogL2', 'LinL2', 'MPLL2'
    :param C: Regularization hyperparameter. NOTE: its meaning depends on the
        classifier. For the two logistic regressions it is passed as ``C``
        (inverse regularization strength in scikit-learn), whereas for
        'LinL2' and 'MPLL2' it is passed as ``alpha`` (regularization
        strength). A sweep over C therefore runs in opposite directions
        for these two groups.
    :return: A new classifier instance
    :raises ValueError: If ``name`` is not one of the supported names
    """
    if name == 'LogL1':
        return LogisticRegression(max_iter=10000, C=C, penalty='l1', solver='liblinear')
    elif name == 'LogL2':
        return LogisticRegression(max_iter=10000, C=C, solver='lbfgs')
    elif name == 'LinL2':
        return RidgeClassifier(max_iter=10000, alpha=C)
    elif name == 'MPLL2':
        return MLPClassifier(alpha=C, max_iter=10000, hidden_layer_sizes=(200, 50))

    # Fail loudly here instead of returning None and crashing later in the caller
    raise ValueError(
        "Unknown classifier name {!r}; expected one of 'LogL1', 'LogL2', 'LinL2', 'MPLL2'".format(name)
    )

def plot_classification_results(ax, CLst, df, haveLog=True):
    """
    Plot train and test accuracy against a swept variable, with a chance-level line.

    :param ax: Matplotlib axes to draw on
    :param CLst: x-values of the sweep, one per row of ``df``
    :param df: DataFrame with columns 'accTrain', 'accTest' and 'accNaive'.
        The first row's 'accNaive' is drawn as the horizontal 'chance' line.
    :param haveLog: If True, use a logarithmic x-axis
    """
    ax.plot(CLst, df['accTrain'], label='train')
    ax.plot(CLst, df['accTest'], label='test')
    # Positional access: works regardless of the DataFrame's index labels
    ax.axhline(y=df['accNaive'].iloc[0], linestyle='--', color='r', label='chance')
    ax.set_ylim([0,1])
    ax.legend()

    if haveLog:
        ax.set_xscale('log')
    
def cross_validate(dataDB, datatype, selector, queryDict, condition,
                   metricName, classifierName):
    """
    For every mouse, sweep the regularization hyperparameter of a binary
    classifier and plot train/test accuracy. All mice share one figure, saved
    as 'crossval_avg_<...>.pdf' in the current working directory.

    :param dataDB: Database object; must expose ``mice`` and be accepted by ``metric_by_selector``
    :param datatype: Value used for the 'datatype' query key
    :param selector: Selector dict forwarded to ``metric_by_selector``; its values become part of the file name
    :param queryDict: Extra query constraints applied to both classes (they override the per-mouse defaults)
    :param condition: Query key that separates the two classes. 'performance'
        gives 'Correct' vs 'Mistake'; any other value gives 'L' vs 'R'
    :param metricName: Name of the metric forwarded to ``metric_by_selector``
    :param classifierName: Name understood by ``get_classifier``
    :return: None
    """
    if condition == 'performance':
        condValA, condValB = 'Correct', 'Mistake'
    else:
        condValA, condValB = 'L', 'R'

    CLst = 10.0**np.arange(-7, 15)

    nMice = len(dataDB.mice)
    # squeeze=False keeps `ax` two-dimensional, so indexing also works for a single mouse
    fig, ax = plt.subplots(ncols=nMice, figsize=(5*nMice, 5), squeeze=False)
    for iMouse, mousename in enumerate(sorted(dataDB.mice)):
        print('Doing mouse', mousename)

        # Title every panel, so that skipped mice are still identifiable in the figure
        axMouse = ax[0, iMouse]
        axMouse.set_title(mousename)

        queryDictA = {'datatype' : datatype, 'mousename' : mousename,
                      condition : condValA}
        queryDictB = {'datatype' : datatype, 'mousename' : mousename,
                      condition : condValB}
        queryDictA = {**queryDictA, **queryDict}
        queryDictB = {**queryDictB, **queryDict}

        dataA = metric_by_selector(dataDB, queryDictA, metricName, 'rp', selector, {})
        dataB = metric_by_selector(dataDB, queryDictB, metricName, 'rp', selector, {})

        print('shapes', dataA.shape, dataB.shape)

        if len(dataA) < 10 or len(dataB) < 10:
            print('--too few samples, skipping')
            continue

        # One classification result per hyperparameter value
        rezLst = [
            binary_classifier(dataA, dataB, get_classifier(classifierName, C),
                              havePVal=True, method='kfold', balancing=True)
            for C in CLst
        ]

        df = pd.DataFrame(rezLst)
        plot_classification_results(axMouse, CLst, df)

    # str() on every component, so non-string selector/query values cannot break the join
    plotSuffixLst = [datatype, condition] + [str(v) for v in selector.values()] + \
                    [str(v) for v in queryDict.values()] + [metricName, classifierName]

    # Save and close this specific figure rather than whatever figure is current
    fig.savefig('crossval_avg_' + '_'.join(plotSuffixLst) + '.pdf', dpi=300)
    plt.close(fig)

In [None]:
# Run the hyperparameter sweep for every data type, classifier and semiphase.
# Per combination: direction (L/R) on correct trials, then performance
# (Correct/Mistake) separately for L trials and for R trials.
for datatype in ['deconv', 'raw']:
    for classifierName in ['LinL2', 'LogL2']:  # 'MPLL2' is currently left out
        # Alternative sweep kept for reference: for interval in [8, 7, 6]
        for semiphase in ['M2', 'M1']:
            selector = {'semiphase' : semiphase}
            for metricName in ['mean']:  # 'temporal_basis' is currently left out
                for fixedQuery, splitBy in [({'performance': 'Correct'}, 'direction'),
                                            ({'direction': 'L'}, 'performance'),
                                            ({'direction': 'R'}, 'performance')]:
                    cross_validate(dataDB, datatype, selector, fixedQuery, splitBy,
                                   metricName, classifierName)

In [None]:
# Single-mouse check: classify L vs R for mouse m061 (deconv data, semiphase M2)
# with a strongly regularized ridge classifier and the 'looc' method.
queryDictL = dict(datatype='deconv', mousename='m061', direction='L')
queryDictR = dict(datatype='deconv', mousename='m061', direction='R')

# Compute the metric for both classes first, then merge dimensions 1..3 of each result
dataL, dataR = (
    metric_by_selector(dataDB, query, 'mean', 'rp', {'semiphase' : 'M2'}, {})
    for query in (queryDictL, queryDictR)
)
dataL, dataR = (numpy_merge_dimensions(arr, 1, 3) for arr in (dataL, dataR))

print('shapes', dataL.shape, dataR.shape)

classifier = get_classifier('LinL2', 1.0E+7)
rez = binary_classifier(dataL, dataR, classifier,
                        havePVal=True,
                        method='looc')

print(rez)

## 1.2 Resampled time classification

In [None]:
# Helper used by resample2D in the cells that follow
from mesostat.utils.signals.resample import resample_kernel

# Kept as the last expression of the cell so that the notebook displays it
dataDB.mice

In [None]:
def durationDistr(data, showHist=True):
    """
    Compute the average length of the second axis over a collection of arrays.

    :param data: Iterable of arrays with at least two dimensions; ``shape[1]``
        of each one is taken as its duration
    :param showHist: If True, also show a histogram of the durations
    :return: Mean duration, truncated (not rounded) to an int
    :raises ValueError: If ``data`` is empty
    """
    dur = [d.shape[1] for d in data]
    if len(dur) == 0:
        # Without this, np.mean([]) gives NaN and int() fails with a cryptic message
        raise ValueError('durationDistr: received no data, cannot compute an average duration')

    if showHist:
        fig, ax = plt.subplots()
        ax.hist(dur)
        ax.set_xlabel('duration (timesteps)')
        ax.set_ylabel('count')
        plt.show()

    return int(np.mean(dur))

def resample2D(dataPS, nTrg):
    """
    Resample the second axis of a 2D array to ``nTrg`` points.

    Source and target points are both spread evenly over [0, 1]; the result is
    ``dataPS`` multiplied by the transposed kernel from ``resample_kernel``.

    :param dataPS: 2D array; axis 1 is the one being resampled
    :param nTrg: Number of points along axis 1 of the result
    :return: Resampled array
    """
    srcGrid = np.linspace(0, 1, dataPS.shape[1])
    trgGrid = np.linspace(0, 1, nTrg)
    kernel = resample_kernel(srcGrid, trgGrid)
    return dataPS.dot(kernel.T)

def get_classification_data(selector, queryDict, condition):
    """
    Fetch the data of two classes from the global ``dataDB`` and resample every
    item to the common average duration.

    :param selector: Selector dict forwarded to ``dataDB.get_data_from_selector``
    :param queryDict: Query constraints shared by both classes
    :param condition: Query key separating the classes. 'performance' gives
        'Correct' vs 'Mistake'; any other value gives 'L' vs 'R'
    :return: Tuple (dataA, dataB) of arrays built from the resampled items
    """
    condValA, condValB = ('Correct', 'Mistake') if condition == 'performance' else ('L', 'R')

    # Fetch both classes; the average duration is taken over their union
    itemsA = dataDB.get_data_from_selector(selector, {**queryDict, condition : condValA})
    itemsB = dataDB.get_data_from_selector(selector, {**queryDict, condition : condValB})
    avgDur = durationDistr(itemsA + itemsB, showHist=False)

    print(selector, queryDict, condition)
    print('--average duration', avgDur)

    def _resample_and_stack(items):
        # Bring every item to avgDur points along axis 1 and stack the results
        return np.array([resample2D(item, avgDur) for item in items])

    stackedA = _resample_and_stack(itemsA)
    stackedB = _resample_and_stack(itemsB)
    return stackedA, stackedB


def find_optimal_hyperparameter(selector, queryDict, condition):
    """
    Sweep the hyperparameter of the 'LogL2' classifier on the first timestep
    only, and show train/test accuracy as a function of it.

    :param selector: Selector dict forwarded to ``get_classification_data``
    :param queryDict: Query constraints shared by both classes
    :param condition: Query key separating the two classes
    :return: None (the plot is shown)
    """
    dataA, dataB = get_classification_data(selector, queryDict, condition)

    # Only the first timestep is used for the hyperparameter search
    firstA = dataA[:, :, 0]
    firstB = dataB[:, :, 0]

    CLst = 10.0**np.arange(-7, 15)
    rezLst = []
    for C in CLst:
        print(C)
        rez = binary_classifier(firstA, firstB, get_classifier("LogL2", C),
                                havePVal=True, method='looc', balancing=True)
        rezLst.append(rez)

    fig, ax = plt.subplots(figsize=(4,4))
    plot_classification_results(ax, CLst, pd.DataFrame(rezLst))
    plt.show()


def temporal_classification(selector, queryDict, condition, C=0.1):
    """
    Classify two conditions separately at every resampled timestep and save a
    plot of train/test accuracy over time as 'classification_<...>.pdf' in the
    current working directory.

    :param selector: Selector dict forwarded to ``get_classification_data``; its values become part of the file name
    :param queryDict: Query constraints shared by both classes; its values become part of the file name
    :param condition: Query key separating the two classes
    :param C: Hyperparameter of the 'LogL2' classifier
    :return: None. Nothing is plotted if either class has fewer than 10 samples
    """
    dataA, dataB = get_classification_data(selector, queryDict, condition)

    nA = len(dataA)
    nB = len(dataB)

    print('nA', nA, 'nB', nB)

    if (nA < 10) or (nB < 10):
        print('--Too few samples, skipping')
        return

    nTime = dataA.shape[2]

    # NOTE(review): the same classifier instance is reused for all timesteps;
    # this assumes binary_classifier refits it from scratch - confirm
    classifier = get_classifier("LogL2", C)

    # The bar is incremented once per timestep, so its maximum is nTime
    p = IntProgress(value=0, max=nTime)
    display(p)
    rezSweep = []
    for iTime in range(nTime):
        rezSweep.append(binary_classifier(dataA[:, :, iTime], dataB[:, :, iTime],
                                          classifier, havePVal=True, method='kfold',
                                          balancing=True))
        p.value += 1

    dfSweep = pd.DataFrame(rezSweep)
    fig, ax = plt.subplots(figsize=(4,4))
    plot_classification_results(ax, np.arange(nTime), dfSweep, haveLog=False)

    # str() on every component, so non-string selector/query values cannot break the join
    plotSuffixLst = [condition] + [str(v) for v in selector.values()] + \
                    [str(v) for v in queryDict.values()]

    # Save and close this specific figure rather than whatever figure is current
    fig.savefig('classification_' + '_'.join(plotSuffixLst) + '.pdf', dpi=300)
    plt.close(fig)

In [None]:
# Display the 'semiphase' keys reported by the database for performance='Correct'
# (the return format is defined in data_db and is not visible here).
dataDB.get_phasetype_keys('semiphase', performance='Correct')

In [None]:
# Step-by-step version for one mouse (m060, deconv, semiphase M2):
# fetch the L and R data and compute the average duration over both.
queryL = dict(datatype="deconv", mousename="m060", direction="L")
queryR = dict(datatype="deconv", mousename="m060", direction="R")
dataL, dataR = (
    dataDB.get_data_from_selector({"semiphase" : "M2"}, query)
    for query in (queryL, queryR)
)
dataAll = dataL + dataR
avgDur = durationDistr(dataAll)

In [None]:
# Display the average duration that all items will be resampled to
avgDur

In [None]:
# Resample every item to avgDur points along axis 1 and stack each class into one array.
# Note that dataL / dataR are overwritten here.
dataL, dataR = (
    np.array([resample2D(item, avgDur) for item in items])
    for items in (dataL, dataR)
)

In [None]:
# Sweep the hyperparameter of the 'LogL2' classifier using only the first
# timestep, to find a suitable value for the time-resolved analysis.
CLst = 10.0**np.arange(-7, 15)
rezLst = []
for C in CLst:
    print(C)
    classifier = get_classifier("LogL2", C)
    rezLst.append(
        binary_classifier(dataL[:, :, 0], dataR[:, :, 0], classifier,
                          havePVal=True, method='looc', balancing=True)
    )

In [None]:
# Plot train/test accuracy against the swept hyperparameter (log x-axis by default)
df = pd.DataFrame(rezLst)
fig, ax = plt.subplots(figsize=(4,4))
plot_classification_results(ax, CLst, df)
plt.show()

In [None]:
# Compute classification separately for every timestep, with a fixed hyperparameter
classifier = get_classifier("LogL2", 0.1)
# The bar is incremented once per timestep, so its maximum is the number of timesteps
p = IntProgress(value=0, max=dataL.shape[2])
display(p)
rezSweep = []
for iTime in range(dataL.shape[2]):
    rezSweep.append(binary_classifier(dataL[:, :, iTime], dataR[:, :, iTime],
                                      classifier, havePVal=True, method='looc',
                                      balancing=True))
    p.value += 1

In [None]:
# Plot train/test accuracy against the timestep index (linear x-axis)
dfSweep = pd.DataFrame(rezSweep)
fig, ax = plt.subplots(figsize=(4,4))
plot_classification_results(ax, np.arange(dataL.shape[2]), dfSweep, haveLog=False)
plt.show()

In [None]:
# Time-resolved classification for every semiphase, data type and mouse:
# performance on L trials, performance on R trials, and direction on correct trials.
for semiphase in ['M1', 'M2']:
    for datatype in ['raw', 'deconv']:
        # Sorted, as in cross_validate, so that the run order is deterministic
        for mousename in sorted(dataDB.mice):
            queryDict1 = {"datatype" : datatype,
                          "mousename" : mousename,
                          "direction" : "L"}
            queryDict2 = {"datatype" : datatype,
                          "mousename" : mousename,
                          "direction" : "R"}
            queryDict3 = {"datatype" : datatype,
                          "mousename" : mousename,
                          "performance" : "Correct"}

            temporal_classification({'semiphase': semiphase}, queryDict1, 'performance', C=0.1)
            temporal_classification({'semiphase': semiphase}, queryDict2, 'performance', C=0.1)
            temporal_classification({'semiphase': semiphase}, queryDict3, 'direction', C=0.1)

In [None]:
from mesostat.stat.classification import cross_temporal_decoding

In [None]:
# Cross-temporal decoding on the resampled L/R arrays.
# NOTE(review): balancing is disabled here and the hyperparameter is 100,
# unlike the per-timestep sweeps above (balancing=True, 0.1) - confirm this is intended.
classifier = get_classifier("LogL2", 100)
mat = cross_temporal_decoding(dataL, dataR, classifier, balancing=False)

In [None]:
# Naive accuracy: fraction of samples that belong to the larger of the two classes
nL = len(dataL)
nR = len(dataR)
accNaive = np.max([nL, nR]) / (nL + nR)

In [None]:
# Show the decoding matrix relative to the naive accuracy, colour range fixed to [-0.5, 0.5]
# (presumably the axes are training/testing timestep - confirm in cross_temporal_decoding).
plt.figure()
plt.imshow(mat - accNaive, vmin=-0.5, vmax=0.5, cmap='jet')
plt.colorbar()
plt.show()

# 3. Classification Tables

In [None]:
# 4B
# All-mice classification table: performance classes on L trials of raw data,
# "Maintenance" phase, metric "cumul_ord_coeff".
# NOTE(review): the meaning of the empty-string argument is defined in plots_lib - confirm.
queryDict = {"direction": "L", "datatype": "raw"}
#queryDict = {"performance": "Correct", "datatype": "raw"}
settings = {"serial": True, "metricSettings": None}
#settings = {"serial" : True, "metricSettings" : {"metric" : num_non_zero_std}}

rez = plots_lib.table_binary_classification(dataDB, "Maintenance", "performance", "cumul_ord_coeff", "",
                                            queryDict, settings)

# Last expression of the cell, so the result is displayed
rez

In [None]:
# 4A
# Mouse-wise classification table: direction classes on correct trials of raw data,
# "Maintenance" phase, metric "temporal_basis" with dimension argument "rp".
# NOTE(review): the "lag"/"max_lag" metric settings are interpreted inside the
# metric library - their exact meaning is not visible here.
queryDict = {"performance": "Correct", "datatype": "raw"}
#queryDict = {"direction": "L", "datatype": "high"}
settings = {"serial": True, "metricSettings": {"lag" : 1, "max_lag" : 3}}
#settings = {"serial" : True, "metricSettings" : {"metric" : num_non_zero_std}}

plots_lib.table_binary_classification_bymouse(dataDB, "Maintenance", "direction", "temporal_basis", "rp",
                                              queryDict, settings)