In [None]:
# Standard libraries
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from ipywidgets import interactive, IntProgress
from IPython.display import display

# Append base directory
# The project root is located by searching the current working directory for
# the folder name `rootname`; that root is then appended to sys.path
# (presumably so the project-local imports that follow can be resolved).
import os,sys,inspect
rootname = "pub-2020-exploratory-analysis"
thispath = os.getcwd()
# thispath = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# Keep everything up to the first occurrence of `rootname` and re-attach it.
# str.index raises ValueError when the working directory does not contain
# `rootname`, i.e. when the notebook is started outside the project tree.
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
# NOTE(review): re-running this cell appends the same path to sys.path again.
sys.path.append(rootpath)
print("Appended root directory", rootpath)

from mesostat.utils.qt_helper import gui_fnames, gui_fpath
from lib.gallerosalas.data_fc_db_raw import DataFCDatabase
from mesostat.metric.metric import MetricCalculator
from mesostat.utils.signals.filter import zscore

%load_ext autoreload
%autoreload 2

In [None]:
# Configuration dictionary handed to DataFCDatabase; only the data root is set here.
params = {}
# params['root_path_data']  = gui_fpath("Path to data collection",  './')
# NOTE(review): machine-specific absolute path — must be edited (or the GUI
# picker above re-enabled) before running on another machine.
params['root_path_data'] = '/media/alyosha/Data/TE_data/yasirdata_raw/'
# params['root_path_data'] = '/home/alyosha/data/yasirdata_raw/'

In [None]:
# Build the database object from `params`; every later cell reads from `dataDB`.
dataDB = DataFCDatabase(params)

In [None]:
# Quick overview of what the database object exposes.
print('mice', dataDB.mice)
# NOTE(review): this is len() of the `sessions` container itself. get_summary
# indexes dataDB.sessions by mouse name, so this may be the number of mice
# rather than the number of sessions — confirm.
print('nSessions', len(dataDB.sessions))
print('datatypes', dataDB.get_data_types())
# The channel count is queried for one hard-coded mouse only.
print('nChannel', dataDB.get_nchannels('mou_5'))

In [None]:
# Run the database's shortest-distance computation.
# Presumably this populates dataDB.allenDist used by the next cell — TODO confirm.
dataDB.calc_shortest_distances()

In [None]:
# Show two images stored on the database side by side:
# dataDB.allenMap (left panel) and dataDB.allenDist (right panel).
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(dataDB.allenMap)
ax[1].imshow(dataDB.allenDist)
plt.show()

# Last expression of the cell, so its value is displayed as the cell output.
len(dataDB.allenCounts)

In [None]:
from mesostat.visualization.mpl_barplot import barplot_stacked

def get_trial_distribution(dataDB):
    """
    Plot, for every mouse, the total number of trials of each trial type.

    Trials are counted over all sessions of a mouse, collected into a
    DataFrame with columns 'mousename', 'trialType' and 'nTrial', and passed
    to ``barplot_stacked`` for drawing.

    :param dataDB: database object providing ``mice``, ``get_trial_type_names()``,
                   ``get_sessions(mousename)`` and ``get_trial_types(session, mousename)``
    :return: None; the figure is shown as a side effect
    """
    trialTypeNames = dataDB.get_trial_type_names()
    rezLst = []

    for mousename in sorted(dataDB.mice):
        # Fetch the labels of every session exactly once per mouse instead of
        # once per (trial type, session) pair. np.asarray guarantees that the
        # comparison below is element-wise even if a plain list is returned.
        labelsBySession = [
            np.asarray(dataDB.get_trial_types(session, mousename))
            for session in dataDB.get_sessions(mousename)
        ]

        for trialType in trialTypeNames:
            nTrial = sum(int(np.count_nonzero(labels == trialType)) for labels in labelsBySession)
            rezLst += [[mousename, trialType, nTrial]]

    df = pd.DataFrame(rezLst, columns=['mousename', 'trialType', 'nTrial'])
    df["nTrial"] = pd.to_numeric(df["nTrial"])

    fig, ax = plt.subplots()
    barplot_stacked(ax, df, 'mousename', 'nTrial')
    plt.show()
    
    
    
# Draw the per-mouse trial-type distribution for the loaded database.
get_trial_distribution(dataDB)

In [None]:
def get_summary(db=None):
    """
    Count the trials of every trial type for each (mouse, session) pair.

    :param db: database object providing ``mice``, ``sessions`` (indexable by
               mouse name) and ``get_trial_types(session, mousename)``.
               Defaults to the notebook-global ``dataDB`` so that the
               argument-free call ``get_summary()`` keeps working.
    :return: DataFrame with one row per session: one count column per trial
             type (sorted by name, 0 where a type does not occur in a session),
             followed by the columns 'mouse' and 'session'
    """
    if db is None:
        db = dataDB  # backward-compatible fallback to the notebook-global database

    countDicts = []
    miceLst = []
    sessionsLst = []

    for mousename in db.mice:
        # NOTE(review): other cells obtain sessions via get_sessions(mousename);
        # confirm that indexing `sessions` directly is equivalent.
        for session in db.sessions[mousename]:
            miceLst.append(mousename)
            sessionsLst.append(session)

            typeNames, typeCounts = np.unique(db.get_trial_types(session, mousename), return_counts=True)
            countDicts.append(dict(zip(typeNames, typeCounts)))

    # Union of all trial types seen in any session. Unlike np.hstack, the set
    # union also works when there are no sessions at all (empty result).
    unqTypes = sorted(set().union(*countDicts))

    # One row per session; trial types absent from a session count as 0
    rez = [[ncdict.get(t, 0) for t in unqTypes] for ncdict in countDicts]

    df = pd.DataFrame(rez, columns=unqTypes)
    df['mouse'] = miceLst
    df['session'] = sessionsLst
    return df

In [None]:
# Spot check: display the trial types of one hard-coded session of mouse 'mou_5'.
dataDB.get_trial_types('2017_03_06_session01', 'mou_5')

In [None]:
# Per-session trial-type counts for all mice (uses the notebook-global dataDB).
df = get_summary()

In [None]:
# Display the summary table as the cell output.
df

In [None]:
# One figure per mouse: rows are sessions, columns are data types. Each panel
# plots the average over 'Hit' trials of the session data against time.
for iMouse, mousename in enumerate(dataDB.mice):
    sessions = dataDB.get_sessions(mousename)
    nRows = len(sessions)
    nCols = len(dataDB.dataTypes)

    if nRows == 0 or nCols == 0:
        # plt.subplots rejects an empty grid; there is nothing to draw anyway
        print(mousename, 'has no sessions or data types, skipping')
        continue

    # squeeze=False keeps `ax` two-dimensional even when there is only one
    # session or one data type, so the ax[row, col] indexing below is always valid
    fig, ax = plt.subplots(nrows=nRows, ncols=nCols, figsize=(4*nCols, 4*nRows),
                           tight_layout=True, squeeze=False)

    for iSession, session in enumerate(sessions):
        print(mousename, session)

        # Label each row with its session name
        ax[iSession, 0].set_ylabel(session)

        for iDataType, datatype in enumerate(dataDB.dataTypes):
            # Label each column with its data type (top row only)
            if iSession == 0:
                ax[0, iDataType].set_title(datatype)

            dataSession = dataDB.get_neuro_data({'session' : session}, datatype=datatype, trialType='Hit')[0]
            # Unpacking requires 3-D data; presumably (trial, time, channel) — confirm
            nTrial, nTime, nChannel = dataSession.shape
            times = dataDB.get_times(nTime)

            # Average over the first axis and plot the result against time
            ax[iSession, iDataType].plot(times, np.mean(dataSession, axis=0))
            dataDB.label_plot_timestamps(ax[iSession, iDataType], mousename, session, linecolor='r', textcolor='pink')

    plt.show()