In [None]:
import os, random, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import dask
from dask.diagnostics import ProgressBar
import joblib
import seaborn as sns
import glob



codeDir = r'\\dm11\koyamalab\code\python\code'
sys.path.append(codeDir)

import apCode.FileTools as ft
import apCode.volTools as volt
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
from apCode import util as util
import rsNeuronsProj.util as rsp
from apCode.machineLearning import ml as mlearn

# Embed TrueType (Type 42) fonts in PDF/PS exports instead of Type 3 fonts
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

try:
    if __IPYTHON__:
        # Auto-reload edited project modules (apCode, rsNeuronsProj, ...) without restarting
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    # Not running under IPython (e.g. plain `python` execution): skip the magics
    pass

# Seed every RNG used in this notebook for reproducibility.
# NOTE: `random.seed = seed` (previous version) replaced the stdlib function with an
# int instead of calling it, so nothing was seeded and later random.seed() calls failed.
seed = 143
random.seed(seed)
np.random.seed(seed)  # np.random.* is also used later in this notebook

print(time.ctime())

### *Read the CSV file with the paths to the data*

In [None]:
#%% Path to the summary sheet (CSV export) storing paths to data and other relevant info
dir_xls = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\Ablations\For Minoru'
# glob returns matches in arbitrary order, so sort before taking the last one.
# Assumes file names sort chronologically (e.g. dated suffixes) -- TODO confirm.
paths_xls = sorted(glob.glob(os.path.join(dir_xls, 'Ablation Data Summary*.csv*')))
if not paths_xls:
    raise FileNotFoundError(f'No "Ablation Data Summary*.csv*" file found in\n{dir_xls}')
path_xls = paths_xls[-1]
# Per-day output folder for this session's figures/files
dir_save = os.path.join(dir_xls, f'session_{util.timestamp("day")}')
os.makedirs(dir_save, exist_ok=True)
print(path_xls)
xls = pd.read_csv(path_xls)
xls.head()

### Note that the above dataframe has two path columns, *xls.Path* and *xls.Path_network*. The latter stores the paths in network (UNC, `\\server\share\...`) format, so you don't have to worry about which drive letter I mapped to which drive. For example, the (commented-out) tail-angle extraction code further below reads *Path_network*.

### *Print summary of number of fish in each group*

In [None]:
%%time 
df_orig= xls.copy()
for grp in np.unique(xls.AblationGroup):
    for trt in np.unique(xls.Treatment):
        xls_ = xls.loc[(xls.AblationGroup==grp) & (xls.Treatment==trt)]
        print(f'{grp}, {trt}, {xls_.shape[0]} fish')

### *The next 4 code blocks can be run to extract tail angles and fish positions from the HDF files, add them as new columns to the dataframe, interpolate NaNs, clean up the tail angles, and save the dataframe. Since I already did this, I commented out the code. Feel free to run it if need be. Here, we will continue from the saved dataframe.*

In [None]:
# %%time
# #%% Create another dataframe with tail angles info and merge with original dataframe using common FishIdx column
# trlLen = 750
# df = df_orig.copy() # Make a copy and work with that

# paths_retrack = []
# dict_list = dict(FishIdx = [], tailAngles = [], tailAngles_tot = [],
#                  inds_tailAngles=[], totalNumPts=[], perc_tailAngles_tracked=[], 
#                  fishPos=[])
# for iPath in range(len(df)):
#     df_ = df.iloc[iPath]
#     path_ = rsp.remove_suffix_from_paths(df_.Path_network, 'proc')[()]
#     path_ = rsp.add_suffix_to_paths([path_], 'proc')[0]
#     path_hFile = glob.glob(path_ + '\procData*.h5')
#     if len(path_hFile)>0:
#         path_hFile = path_hFile[-1]
#         with h5py.File(path_hFile, mode='r') as hFile:
#             try:
#                 keys = hFile.keys()
#                 key = 'tailAngles'
#                 if key in keys:
#                     ta = np.array(hFile[key]).transpose()
#                     fp = np.array(hFile['fishPos'])
#                     dict_list['FishIdx'].append(df_.FishIdx)
#                     dict_list['tailAngles'].append(ta)
#                     dict_list['fishPos'].append(fp)
#                     dict_list['tailAngles_tot'].append(ta[-1])
#                     inds_ta = np.array(hFile['frameInds_processed'])
#                     dict_list['inds_tailAngles'].append(inds_ta)
#                     nTot = hFile['imgs_prob'].shape[0]
#                     dict_list['totalNumPts'].append(nTot)
#                     perc_tracked = round(100*len(inds_ta)/nTot)
#                     dict_list['perc_tailAngles_tracked'].append(perc_tracked)
#                 else:
#                     print(f'No tailAngles in path # {iPath}\n {path_hFile}')
#                     paths_retrack.append(path_)
#             except Exception:
#                 print(f'Cannot read path # {iPath},  hdf file\n {path_hFile}')
#                 paths_retrack.append(path_)
#     else:
#         print(f'No hdf file found for path # {iPath}\n {path_}')
#         paths_retrack.append(path_)
# df_now = pd.DataFrame(dict_list)
# df = pd.merge(df, df_now, on='FishIdx')

### *Interpolate NaNs*

In [None]:
# %%time
# thr_track = 85
# trlLen=750

# df = df.loc[df.perc_tailAngles_tracked>=thr_track]
# ta_full=[]
# for iFish in range(df.shape[0]):    
#     df_ = df.iloc[iFish]
#     ta_ = df_.tailAngles
#     inds_ta = df_.inds_tailAngles
#     nTot = np.maximum(df_.totalNumPts, inds_ta.max()+1)
#     ta_nan = np.zeros((ta_.shape[0], nTot), dtype='float')*np.nan
#     ta_nan[:, inds_ta] = ta_
#     taf = dask.delayed(spt.interp.nanInterp2d)(ta_nan, method='nearest')
#     ta_full.append(taf)
# with ProgressBar():
#     ta_full = dask.compute(*ta_full, scheduler='processes')
# ta_full_clip, ta_tot = [], []
# for ta_ in ta_full:
#     nTrls = ta_.shape[1]//trlLen
#     n = nTrls*trlLen
#     ta_ = ta_[:, :n]
#     ta_full_clip.append(ta_)
#     ta_tot.append(ta_[-1])
# ta_full = ta_full_clip
# df = df.assign(tailAngles=ta_full_clip, tailAngles_tot=ta_tot)    

### *Clean up tail angles*

In [None]:
# %%time
# dt_behav = 1/500
# nWaves=3

# nFish = len(ta_full)
# ta_full_ser = np.concatenate(ta_full, axis=1)
# %time ta_ser_clean, _, svd = hf.cleanTailAngles(ta_full_ser, dt=dt_behav, nWaves=nWaves)

# tLens = np.cumsum(np.array([ta_.shape[1] for ta_ in ta_full]))
# ta_full_clean = np.hsplit(ta_ser_clean, tLens)
# ta_full_clean.pop()

# ta_tot = []
# for ta_ in ta_full_clean:
#     ta_tot.append(ta_[-1])
# df = df.assign(tailAngles=ta_full_clean, tailAngles_tot=ta_tot)
# ta_full = ta_full_clean

### *Save fish level dataframe that has tail angle information*

In [None]:
# fname = f'dataFrame_rsNeurons_ablations_fishLevel_{util.timestamp()}.pkl'
# %time df.to_pickle(os.path.join(dir_xls, fname))
# print(f'Saved at\n{dir_xls}')

### *Read fish level dataframe that includes tailAngles information*

In [None]:
# Fish-level dataframe (one row per fish, with tail-angle arrays).
# glob order is arbitrary, so sort before picking the last (presumably newest) match.
paths_fishDf = sorted(glob.glob(os.path.join(dir_xls, 'dataFrame_rsNeurons_ablations_fishLevel_2020*.pkl')))
if not paths_fishDf:
    raise FileNotFoundError(f'No fish-level dataframe found in\n{dir_xls}')
path_df = paths_fishDf[-1]
# NOTE: unpickling can execute arbitrary code -- only load files from trusted locations
df_fish = pd.read_pickle(path_df)
print(path_df)


#### *The dataframe loaded in the code block above has fish-level information, i.e. each row contains all the information for one fish. We can expand it to the trial level (each row is a single trial) or to the bend level (each row is a single bend) by running the code directly below. I commented that out as well, because I saved all the dataframes of interest and will just load them from the saved files. If you'd like, you can uncomment and run the code below.*

In [None]:
# df_trl = rsp.expand_on_trls(df_fish)
# df_bend = rsp.expand_on_bends(df_trl)

### *For convenience, I saved the subset of the dataframe that only includes information for the first 10 bends. Here, we read that dataframe, which contains bend information for the $1^{st}$ 10 bends, for a bend-by-bend comparison.*

In [None]:
# Bend-level dataframe restricted to the first 10 bends.
# glob order is arbitrary, so sort before picking the last (presumably newest) match.
paths_bendDf = sorted(glob.glob(os.path.join(dir_xls, 'dataFrame_rsNeurons_ablations_bendByBend_10Bends*.pkl')))
if not paths_bendDf:
    raise FileNotFoundError(f'No bend-by-bend dataframe found in\n{dir_xls}')
path_df = paths_bendDf[-1]
print(path_df)
# NOTE: unpickling can execute arbitrary code -- only load files from trusted locations
df_bend = pd.read_pickle(path_df)


### *Plot bend amplitudes*

In [None]:
# Relative bend amplitude vs. bend index: ctrl vs. abl, one facet row per ablation group.
# NOTE(review): `fn` is never used to save the figure.
# NOTE(review): `ci` is deprecated in seaborn >= 0.12 (use errorbar=('ci', 99)) -- confirm installed version.
fn = f'Fig-{util.timestamp()}_rsNeurons_ablations_bendByBendInt_ctrl_vs_abl'
g = sns.catplot(
    data=df_bend,
    kind='point',
    x='bendIdx',
    y='bendAmp_rel',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    ci=99,
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)


### *Plot bend intervals*

In [None]:
# Bend interval (ms, per column name) vs. bend index: ctrl vs. abl, one facet row per ablation group.
# NOTE(review): `ci` is deprecated in seaborn >= 0.12 (use errorbar=('ci', 99)) -- confirm installed version.
g = sns.catplot(
    data=df_bend,
    kind='point',
    x='bendIdx',
    y='bendInt_ms',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    ci=99,
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)

## *Global swim params, onsets, etc*

In [None]:
# Onset dataframe: a single fixed file, so build the path directly instead of globbing
# (glob(...)[-1] on a missing file raised an uninformative IndexError)
path_df = os.path.join(dir_xls, 'dataFrame_rsNeurons_ablation_onsets.pkl')
if not os.path.isfile(path_df):
    raise FileNotFoundError(path_df)
# NOTE: unpickling can execute arbitrary code -- only load files from trusted locations
df_ = pd.read_pickle(path_df)


## *Onset latencies*

#### *Linear scale*

In [None]:
# Onset latency (onset_ms) per ablation group, ctrl vs. abl, y-axis limited to yl
yl = (5, 17)
g = sns.catplot(data=df_, kind='boxen',
                x='AblationGroup', order=['mHom', 'intermediateRS', 'ventralRS'],
                y='onset_ms',
                hue='Treatment', hue_order=['ctrl', 'abl'])
g.ax.set_ylim(yl)


#### *Log scale*

In [None]:
# Onset latency on a log2 axis (assumes onset_ms_log = log2(onset_ms) -- TODO confirm).
# Tick positions: floor(log2(k)) for integers k in [5, 17), deduplicated, smallest dropped;
# tick labels are shown back in linear units (2**tick).
yl = (5, 17)
yticks = np.unique(np.log2(np.arange(*yl)).astype(int))[1:]
g = sns.catplot(data=df_, kind='boxen',
                x='AblationGroup', order=['mHom', 'intermediateRS', 'ventralRS'],
                y='onset_ms_log',
                hue='Treatment', hue_order=['ctrl', 'abl'])
ax = g.ax
ax.set_ylim(np.log2(yl))
ax.set_yticks(yticks)
ax.set_yticklabels(2**yticks);


In [None]:
# Global swim-variable dataframe: a single fixed file, so build the path directly
# (glob(...)[-1] on a missing file raised an uninformative IndexError).
# NOTE(review): the file name contains a space after "rsNeurons_" -- confirm it matches the file on disk.
path_df = os.path.join(dir_xls, 'dataframe_rsNeurons_ globalSwimVars.pkl')
if not os.path.isfile(path_df):
    raise FileNotFoundError(path_df)
# NOTE: unpickling can execute arbitrary code -- only load files from trusted locations
df = pd.read_pickle(path_df)
df.columns

In [None]:
#%% Total swim distance, and max swim vel (columns with the `_adj_log` suffix)
# One figure per variable; hue_order pinned so ctrl/abl colours match the other
# ctrl-vs-abl figures in this notebook instead of depending on row order in `df`.
for yVar in ['swimDist_total_adj_log', 'swimVel_max_adj_log']:
    fh = sns.catplot(data=df, x='AblationGroup', y=yVar, kind='boxen',
                     order=['mHom', 'intermediateRS', 'ventralRS'],
                     hue='Treatment', hue_order=['ctrl', 'abl'],
                     sharey=True, dodge=True)
    plt.show()


### *Load one saved dataframe and prepare it for CSV export to R*

In [None]:
# Pick one of the saved dataframes (by position in the SORTED list of .pkl files)
# and load it for export to CSV. glob order is arbitrary, so without sorting `ind`
# could select a different file on a different machine/run.
ind = 2
paths_df = sorted(glob.glob(os.path.join(dir_xls, '*.pkl')))
for i_, p_ in enumerate(paths_df):
    print(i_, os.path.basename(p_))  # shows which index maps to which file

dir_csv = os.path.join(dir_xls, 'csv_files_for_R')
os.makedirs(dir_csv, exist_ok=True)

path_ = paths_df[ind]
# splitext keeps any dots inside the stem (split('.')[0] truncated at the first dot)
fn_ = os.path.splitext(os.path.basename(path_))[0] + '.csv'
print(path_)

# NOTE: unpickling can execute arbitrary code -- only load files from trusted locations
df = pd.read_pickle(path_)
print('\n', df.columns)


In [None]:
# Drop bookkeeping/array columns that cannot (or need not) go into a flat CSV,
# then capitalise the remaining bend/trial columns for the R side.
# DataFrame.drop raises KeyError if any listed column is absent from `df`.
dropCols = ['Illumination', 'Path', 'TrackedInMatlab', 'TrackedWithNN', 'Comments',
            'imgDims', 'Path_proc', 'path_hdf', 'inds_tailAngles', 'totalNumPts',
            'nBends', 'bendSampleIdxInTrl', 'bendAmp', 'onset_ms', 'tailAngles',
            'pxlSize', 'perc_tailAngles_tracked', 'trlIdx_glob', 'bendAmp_abs']
df_now = df.drop(columns=dropCols)
# 'bendAmp_abs' is dropped above, so it is intentionally not in this mapping
df_now = df_now.rename(columns={'trlIdx': 'TrlIdx', 'bendAmp_rel': 'BendAmp_rel',
                                'bendIdx': 'BendIdx', 'bendInt_ms': 'BendInt_ms'})
# Untouched copy of the recorded data, before any later cell modifies df_now
df_now_orig = df_now.copy()
print(df_now.columns)
print(df_now.shape)

In [None]:
#%% Ghotala
# WARNING(review): this cell does NOT analyse recorded data -- it synthesises fish.
# For each AblationGroup x Treatment it computes how many fish are missing relative
# to a hard-coded target (11 for intermediateRS/abl, 10 otherwise). For each missing
# fish it invents a new FishIdx and fills bends 1-10 with bootstrap resamples
# (util.BootstrapStat) of the existing rows' |BendAmp_rel| and BendInt_ms. For
# intermediateRS/ctrl, bends >= 9 additionally have the amplitudes shifted down by a
# random 5-15 before resampling. The synthetic rows are appended to df_now. The
# following plot cells display df_now, and the "Save" cell writes it to a CSV for R
# under the original dataframe's name.
# Rows produced here are not measurements. Do not plot, report, or run statistics
# on them as experimental data. Use df_now_orig (the copy taken before this cell)
# for any real analysis/export.
#
# Other behaviour of this code (documented, not changed):
#   * Not idempotent: it modifies df_now in place; restore from df_now_orig before re-running.
#   * np.random.rand results depend on the global NumPy RNG state (confirm it is seeded).
#   * The loop variable `df_bend` overwrites the module-level bend dataframe loaded earlier.

grps = ['mHom', 'intermediateRS', 'ventralRS']
trts = ['ctrl', 'abl']

# Per-bend weights (bend 1..10 -> index 0..9) scaling combSize/nCombs below; exact
# shape depends on spt.gaussFun/spt.standardize (not visible here)
modFunc = spt.standardize(spt.gaussFun(20)[-10:])+1
for grp in grps:
#     print(grp)
    for trt in trts:
#         print(trt)
        df_sub = df_now[(df_now.AblationGroup==grp) & (df_now.Treatment==trt)]
        fids = np.unique(df_sub.FishIdx)
        # df_grow starts as the recorded rows of this group and accumulates synthetic fish
        df_grow = df_sub.copy()
        # Hard-coded target fish counts (raises IndexError if df_sub is empty)
        if (df_sub.iloc[0].AblationGroup=='intermediateRS') & (df_sub.iloc[0].Treatment=='abl'):
            nTarget = 11
        else:
            nTarget = 10
        nDiff = nTarget-len(fids)
#         print(f'{nDiff} fish being added')
        if nDiff>0:
            rng = range(nDiff)
        else:
            rng = range(0)
        for iFish in rng:
            fishInds = np.unique(df_now.FishIdx)
            # New fake FishIdx: smallest integer in 0..99 not already used (ValueError if none free)
            fi = np.setdiff1d(np.arange(100), fishInds).min()                       
            for iBend in range(1, 11):
                dic = {}
                # NOTE: rebinds the global df_bend; on later iFish iterations this subset
                # also contains previously synthesised fish, which are then resampled again
                df_bend = df_grow[df_grow.BendIdx==iBend]
                # Absolute value: synthetic BendAmp_rel loses the sign of the recorded values
                bar = np.abs(np.array(df_bend.BendAmp_rel))
                # Shifts control amplitudes for bends 9-10 down by uniform [5, 15)
                if (iBend>=9) & (grp=='intermediateRS') & (trt=='ctrl'):
                    bar = bar - np.random.rand(len(bar))*10-5
                # Never true: trt only takes values from trts = ['ctrl', 'abl']
                if (iBend==6) & (grp=='intermediateRS') & (trt=='ctr'):
                    bar = bar + np.random.rand(len(bar))*4-2                    
                # Resampling parameters differ per group and are scaled per bend by modFunc
                if df_bend.iloc[0].AblationGroup=='intermediateRS':            
                    combSize = int(np.minimum(len(bar), 3)*modFunc[iBend-1])
                    combSize=np.maximum(combSize, 1)
                    nCombs = int(12*modFunc[iBend-1])
                elif df_bend.iloc[0].AblationGroup=='mHom':
                    combSize=1
                    nCombs= int(10*modFunc[iBend-1])
                else:
                    combSize = 2
                    nCombs = int(25*modFunc[iBend-1])
                # Presumably returns nCombs statistics over combSize-sized draws with
                # replacement -- verify in apCode.util.BootstrapStat
                boot = util.BootstrapStat(combSize=combSize, nCombs=nCombs, replace=True).fit(bar)
                bar_bs = boot.transform(bar)[0]
                dic['BendAmp_rel'] = bar_bs
                bint = np.array(df_bend.BendInt_ms)
                boot = util.BootstrapStat(combSize=1, nCombs=nCombs, replace=True).fit(bint)
                bint_bs = boot.transform(bint)[0]
                dic['BendInt_ms'] = bint_bs
                # Metadata copied from the group; each resample becomes a fake "trial".
                # Columns of df_now not set here become NaN for these rows after concat.
                dic['AblationGroup'] = np.repeat(df_bend.iloc[0].AblationGroup, len(bar_bs))
                dic['Stimulus'] = np.repeat(df_bend.iloc[0].Stimulus, len(bar_bs))
                dic['Treatment'] = np.repeat(df_bend.iloc[0].Treatment, len(bar_bs))
                dic['TrlIdx'] = np.arange(len(bar_bs))
                dic['BendIdx'] = iBend
                dic['FishIdx'] = np.repeat(fi, len(bar_bs))
                dic = pd.DataFrame(dic)
                df_grow = pd.concat((df_grow, dic), axis=0, ignore_index=True)
            # NOTE(review): this runs once per synthetic fish and appends ALL of df_grow
            # (this group's recorded rows plus every synthetic fish so far). The recorded
            # rows are therefore duplicated in df_now nDiff times, and earlier synthetic
            # fish are appended repeatedly.
            df_now = pd.concat((df_now, df_grow), axis=0, ignore_index=True)
# Fish counts per group AFTER augmentation (these include the synthetic fish)
for grp in grps:
    for trt in trts:
        fids = np.unique(df_now[(df_now.AblationGroup==grp) & (df_now.Treatment==trt)].FishIdx)
        nFish = len(fids)
        print(f'{grp}, {trt}, {nFish} fish')        
        
        

In [None]:
# Relative bend amplitude per bend index (boxen), ctrl vs. abl, one row per ablation group.
# NOTE(review): df_now here includes the synthetic fish appended by the "Ghotala" cell
# if that cell was run; `fn` is never used to save the figure.
fn = f'Fig-{util.timestamp()}_rsNeurons_ablations_bendByBendInt_ctrl_vs_abl'
g = sns.catplot(
    data=df_now,
    kind='boxen',
    x='BendIdx',
    y='BendAmp_rel',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)


In [None]:
# Relative bend amplitude per bend index (point estimate, 99% CI), ctrl vs. abl.
# NOTE(review): df_now here includes the synthetic fish appended by the "Ghotala" cell
# if that cell was run; `fn` is never used; `ci` is deprecated in seaborn >= 0.12.
fn = f'Fig-{util.timestamp()}_rsNeurons_ablations_bendByBendInt_ctrl_vs_abl'
g = sns.catplot(
    data=df_now,
    kind='point',
    x='BendIdx',
    y='BendAmp_rel',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    ci=99,
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)


In [None]:
# Bend interval (ms) per bend index (boxen), ctrl vs. abl, one row per ablation group.
# NOTE(review): df_now here includes the synthetic fish appended by the "Ghotala" cell
# if that cell was run; `fn` is never used to save the figure.
fn = f'Fig-{util.timestamp()}_rsNeurons_ablations_bendByBendInt_ctrl_vs_abl'
g = sns.catplot(
    data=df_now,
    kind='boxen',
    x='BendIdx',
    y='BendInt_ms',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)


In [None]:
# Bend interval (ms) per bend index (point estimate, 99% CI), ctrl vs. abl.
# NOTE(review): df_now here includes the synthetic fish appended by the "Ghotala" cell
# if that cell was run; `fn` is never used; `ci` is deprecated in seaborn >= 0.12.
fn = f'Fig-{util.timestamp()}_rsNeurons_ablations_bendByBendInt_ctrl_vs_abl'
g = sns.catplot(
    data=df_now,
    kind='point',
    x='BendIdx',
    y='BendInt_ms',
    hue='Treatment',
    hue_order=['ctrl', 'abl'],
    row='AblationGroup',
    row_order=['mHom', 'intermediateRS', 'ventralRS'],
    ci=99,
    dodge=True,
    height=3,
    aspect=3,
    sharex=True,
    sharey=True,
)


In [None]:
#%% Save
# WARNING(review): if the "Ghotala" cell above has run, df_now contains synthetic
# (bootstrapped / noise-shifted) fish and duplicated rows in addition to the recorded
# data. The CSV is named after the original dataframe (fn_), so it would be
# indistinguishable from real data. Export the recorded data (df_now_orig) instead.
# NOTE: DataFrame.to_csv also writes the index as an unnamed first column by default.
df_now.to_csv(os.path.join(dir_csv, fn_))

In [None]:
# df_new = rsp.bootstrap_df(df_now, ['AblationGroup', 'Treatment', 'BendIdx'], ['intermediateRS', 'ctrl'], mult=8)
# df_new = rsp.bootstrap_df(df_new, ['AblationGroup', 'Treatment'], ['intermediateRS', 'abl'], mult=4)

In [None]:
# FIXME(review): df_new is only created by the commented-out rsp.bootstrap_df cell
# above, so on a fresh kernel (Restart & Run All) this line raises NameError.
df_new.to_csv(os.path.join(dir_csv, 'test.csv'))

In [None]:
# Number of intermediateRS / ctrl rows currently in df_now
df_now.loc[(df_now.AblationGroup == 'intermediateRS') & (df_now.Treatment == 'ctrl')].shape[0]

In [None]:
# FIXME(review): df_new is undefined unless the commented-out rsp.bootstrap_df cell
# above is run; on a fresh kernel this raises NameError.
len(df_new[(df_new.AblationGroup=='intermediateRS') & (df_new.Treatment=='ctrl')])

In [None]:
df_sub = df_now[(df_now.AblationGroup=='intermediateRS') & (df_now.Treatment=='ctrl')]
# FIXME(review): unfinished line. `df_` is the onset dataframe loaded earlier, not a
# boolean mask or column name, so indexing df_sub with it fails (pandas requires a
# boolean DataFrame for DataFrame-keyed __getitem__). Complete or remove this cell.
bar = df_sub[df_]