In [None]:
import os, random, sys, warnings, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import dask
import dask.array as darr
import caiman as cm
import h5py
import tifffile as tff
import glob
from joblib import Parallel, delayed
from skimage.util import montage
import keras
from keras.callbacks import ModelCheckpoint


codeDir = r'\\dm11\koyamalab\code\python\code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
from apCode.machineLearning.unet import model
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
import apCode.geom as geom
import seaborn as sns
import importlib
from apCode import util as util
import apCode.ephys as ephys
from apCode import hdf
from apCode.imageAnalysis.spim import regress

# Font type 42 (TrueType) for the PDF and PostScript backends
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


# Turn on automatic module reloading when running under IPython. Outside of
# IPython `__IPYTHON__` is undefined, the NameError is caught and this is a no-op.
try:
    if __IPYTHON__:
        # `InteractiveShell.magic()` is deprecated; `run_line_magic` is the
        # supported way of invoking a line magic programmatically.
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

# Setting seed for reproducibility
seed = 143
# The seeding functions must be *called*: assigning to `random.seed` would
# replace the function with an int and leave the generator unseeded.
random.seed(seed)
# Seed NumPy's global generator as well, since `np.random` is used further down
np.random.seed(seed)

print(time.ctime())

### *Let's first try the simpler version of the code, where we read directly from a single fish path instead of reading from the paths stored in the excel file*

In [None]:
#%% Path to the directory holding the data of a single fish (used directly, no excel sheet involved)
dir_fish = r'Y:\Avinash\Head-fixed tail free\GCaMP imaging\2019-12-31\f1'


### *Read peristimulus $Ca^{2+}$ images, store in an HDF file and get the path to the HDF file*

In [None]:
# The returned value is kept as the path to the HDF file and reused by the cells below
%time hFilePath = hf.read_and_store_ca_imgs(dir_fish)

### *Read the raw $Ca^{2+}$ images from the HDF file, register and store in the same HDF file*

In [None]:
%%time 
# Patch size and patch overlap handed to the piecewise registration
# (presumably percentages of the image size -- TODO confirm in hf.register_piecewise_from_hdf)
patchPerc = (60, ) # Default (40, )
patchOverlapPerc = (80, ) # Default (70, ) 
# Only the first element of the returned value is kept; it replaces hFilePath
hFilePath = hf.register_piecewise_from_hdf(hFilePath, patchPerc=patchPerc, 
                                           patchOverlapPerc=patchOverlapPerc)[0]

## *Resume processing/analysis from here*

### *Start by seeing if the data in the HDF is accessible*

In [None]:
# glob returns matches in arbitrary (file-system dependent) order, so the last
# list element is not necessarily the newest file. Select the most recently
# modified match explicitly and fail with a clear message when nothing matches.
hFilePaths = glob.glob(os.path.join(dir_fish, 'procData*.h5'))
if not hFilePaths:
    raise FileNotFoundError(f'No procData*.h5 file found in {dir_fish}')
hFilePath = max(hFilePaths, key=os.path.getmtime)  # Load the latest file
print(hFilePath)
# Opening read-only and listing the keys is enough to verify the file is accessible
with h5py.File(hFilePath, mode='r') as hFile:
    print(hFile.keys())


### *Load a slice from the raw and registered stacks and play as movie to check registration looks ok!*

In [None]:
%%time 
iSlc = 14 # Slice index (0-29; useless top slice discarded during processing)
fr=10 # Movie frame rate

with h5py.File(hFilePath, mode='r') as hFile:
    slc_raw = np.array(hFile['ca_raw'][iSlc])
    slc_reg = np.array(hFile['ca_reg'][iSlc])
a = spt.zscore(slc_raw[:,::2, ::2])
b = spt.zscore(slc_reg[:, ::2, ::2]) 
mov = cm.movie(np.concatenate((a, b), axis=1), fr=fr)

In [None]:
#%% Play the movie
# The movie holds the raw and the registered slice concatenated along axis 1 of the array
mov.play(magnification=2.5)

### *Optionally, save registered slices as tif files for more careful examination with ImageJ*

In [None]:
%%time
#%% Save each of the registered slices in a separate .tif file
dir_save = os.path.join(dir_fish, f'registered_slices_cpwr_{util.timestamp()}')
if not os.path.exists(dir_save):
    os.mkdir(dir_save) 

with h5py.File(hFilePath, mode='r') as hFile:
    for iSlc, slc in enumerate(hFile['ca_reg']):
        slc_int = slc.astype('uint16')
        tff.imsave(os.path.join(dir_save, r'slice_{:03d}.tif'.format(iSlc)), slc_int)
print(f'Saved at\n{dir_save}')

### *If ROIs have not already been drawn, create temporally averaged $Ca^{2+}$ image volume on which ROIs can be drawn*

In [None]:
%%time
ca_reg_avg = []
with h5py.File(hFilePath, mode='r') as hFile:
    for z in range(hFile['ca_reg'].shape[0]):
        ca_reg_avg.append(np.array(hFile['ca_reg'][z]).mean(axis=0))
ca_reg_avg = np.array(ca_reg_avg).astype('uint16')
tff.imsave(os.path.join(dir_fish, 'averageCaImgVol.tif'), data=ca_reg_avg)
print(f'Saved at\n{dir_fish}')

### *After ROIs have been drawn on the time-averaged slices, read the ROIs and extract timeseries*

In [None]:

# glob returns matches in arbitrary order, so pick the most recently modified
# match explicitly rather than relying on the last list element.
hFilePaths = glob.glob(dir_fish + '/procData*.h5')
roiPaths = glob.glob(dir_fish + '/RoiSet*.zip')
if not hFilePaths:
    raise FileNotFoundError(f'No procData*.h5 file found in {dir_fish}')
if not roiPaths:
    raise FileNotFoundError(f'No RoiSet*.zip file found in {dir_fish}')
hFilePath = max(hFilePaths, key=os.path.getmtime)
dir_rois = max(roiPaths, key=os.path.getmtime)  # NB: path to the RoiSet zip file, not a directory

# Only the shape of the registered stack is needed here, not its contents
with h5py.File(hFilePath, mode='r') as hFile:
    stackDims = hFile['ca_reg'].shape
imgDims = stackDims[-2:]  # last two dimensions of the stack
volDims = (stackDims[0], *imgDims)  # first dimension followed by the image dimensions

# Element [1] of the value returned by readImageJRois is what gets consolidated into masks
rois = mlearn.readImageJRois(dir_rois, imgDims, multiLevel=False)[1]
masks, roiNames = hf.consolidate_rois(rois, volDims)
masks.shape

### *A quick glance at z-projected ROIs in the dataset*

In [None]:
plt.style.use(('fivethirtyeight', 'seaborn-talk', 'seaborn-ticks'))
plt.figure(figsize=(20, 10))
# Black background with the same in-plane size as the masks
plt.imshow(np.zeros(masks.shape[-2:]), cmap='gray')
for iMask, mask in enumerate(masks):
    img = mask.max(axis=0)  # maximum projection of the mask along its first axis
    inds = np.where(img)
    if inds[0].size == 0:
        # An ROI without any set pixel cannot be drawn (and would raise an
        # IndexError on the labelled point below), so skip it with a warning.
        warnings.warn(f'{roiNames[iMask]}: ROI is empty, not plotted')
        continue
    # Cycle through the 20 colors; a bare integer beyond the table would be clipped to the last color
    clr = np.array(plt.cm.tab20(iMask % plt.cm.tab20.N))[None,]
    plt.scatter(inds[1], inds[0], c=clr, alpha=0.2)
    # Re-plot one point fully opaque so that the legend marker is clearly visible
    plt.scatter(inds[1][0], inds[0][0], c=clr, label=f'{roiNames[iMask]}')
plt.grid(False)
leg = plt.legend(fontsize=24, ncol=2, framealpha=0)
# White legend text to stand out against the black background
for txt in leg.get_texts():
    plt.setp(txt, color='w')

### *Extract timeseries for ROIs*

In [None]:
%%time 
key='/ca_reg'

# Load the registered stack together with the trial bookkeeping arrays stored in the HDF file
with h5py.File(hFilePath, mode='r') as hFile:
    ca_reg = np.array(hFile[key])
    # The first entry is taken as the trial length (presumably all trials share it -- TODO confirm)
    trlLen = np.array(hFile['nImgsInTrl_ca'])[0]
    trlIdx = np.array(hFile['trlIdx_ca'])
    stimLoc = util.to_utf(np.array(hFile['stimLoc']))
    sessionIdx = np.array(hFile['sessionIdx'])
# Swap the first two axes (presumably so that time comes before slices -- TODO confirm)
ca_reg = ca_reg.swapaxes(0, 1)    
# Wrap in a dask array; the per-ROI reductions in the next cell call .compute() on it
arr = darr.from_array(ca_reg)


In [None]:
%%time
def func_now(imgs, mask):
    """Average `imgs` within `mask` for every entry along the first axis of `imgs`.

    `imgs` is multiplied by `mask` (broadcast along the first axis), summed
    over axes 1, 2 and 3, and divided by the sum of the mask elements equal
    to 1. Returns a 1D array with one value per entry of the first axis.
    """
    n_in_mask = mask[mask == 1].sum()
    masked = imgs * mask[None, ...]
    totals = np.apply_over_axes(np.sum, masked, [1, 2, 3])
    return totals.flatten() / n_in_mask

print('Extracting roi timeseries...')
roi_ts = []
for iMask, mask in enumerate(masks):
    print(f'{roiNames[iMask]}')
    # Normalization factor: sum of the mask elements that equal 1
    nPxls = mask[mask == 1].sum()
    # Weight the (dask) stack by the mask, then collapse the three trailing
    # axes one after the other; nothing is evaluated until .compute()
    ts = arr * mask[None, ...]
    for _axis in range(3):
        ts = ts.sum(axis=-1)
    ts = ts / nPxls
    roi_ts.append(ts.compute())
roi_ts = np.array(roi_ts)



### *Put data in a dataframe and save*

In [None]:
fn = 'dataframe_roi_ts.pkl'
nRois = len(roiNames)

# Split the continuous ROI timeseries into trials of `trlLen` samples
roi_ts_trl = roi_ts.reshape(nRois, -1, trlLen)
# The per-image labels are reduced to one value per trial (the first sample of each trial)
trlIdx_trl, stimLoc_trl, sessionIdx_trl = [
    labels.reshape(-1, trlLen)[:, 0] for labels in (trlIdx, stimLoc, sessionIdx)]
nTrls = len(stimLoc_trl)

# One row per (ROI, trial): ROI names are repeated per trial, trial labels tiled per ROI
df = pd.DataFrame({
    'roiName': np.repeat(roiNames, nTrls),
    'trlIdx': np.tile(trlIdx_trl, nRois),
    'sessionIdx': np.tile(sessionIdx_trl, nRois),
    'stimLoc': np.tile(stimLoc_trl, nRois),
    'ts': list(roi_ts_trl.reshape(-1, trlLen)),
})

df.to_pickle(os.path.join(dir_fish, fn))

### *Plot trial averaged $Ca^{2+}$ responses for head and tail stimulation trials for all ROIs*

In [None]:
nCols = 3
Fs_ca = 2  # Frame rate
nPre = 3  # Number of leading samples used as baseline; also sets where t = 0 falls
nRows = int(np.ceil(nRois/nCols))

fh, ax = plt.subplots(nRows, nCols, figsize=(20, 20*nRows//nCols),
                      sharex=True, sharey=False)
ax = ax.flatten()
# Time axis shifted by the baseline length (previously a hard-coded 3)
t = (np.arange(trlLen)-nPre)*(1/Fs_ca)

# (legend label, value in df.stimLoc, color) for each stimulus condition
stimConditions = (('Head', 'h', plt.cm.tab10(0)),
                  ('Tail', 't', plt.cm.tab10(1)))

for iRoi, rn in enumerate(roiNames):
    df_sub = df.loc[df.roiName == rn]
    for label, loc, clr in stimConditions:
        ts = np.array([np.array(_) for _ in df_sub.loc[df_sub.stimLoc == loc].ts])
        # Subtract each trial's mean over the first nPre samples
        ts = ts - ts[:, :nPre].mean(axis=1)[:, None]
        # Each condition gets its own bootstrap object sized by its own number
        # of trials (the tail trials were previously resampled with the head
        # trials' object and trial count).
        boot = util.BootstrapStat(combSize=ts.shape[0], nCombs=1000, replace=True)
        mu = ts.mean(axis=0)
        # Band half-width: 2 x std across the bootstrap output
        ci = 2*np.std(boot.fit_transform(ts)[0], axis=0)
        ax[iRoi].fill_between(t, mu+ci, mu-ci, color=clr, alpha=0.5, label=label)
    if iRoi == 0:
        # The legend is shown on the first panel only
        ax[iRoi].legend()
    ax[iRoi].set_title(rn, fontsize=14)
    ax[iRoi].set_xlim(t[0], t[-1])
# Axis labels are placed on the last ROI's panel only
ax[iRoi].set_xlabel('Time (s)')
ax[iRoi].set_ylabel('Raw intensity')
fh.suptitle('Trial averaged Ca2+ responses for head and tail stimulation trials', fontsize=18)
plt.subplots_adjust(top=0.95)
    
    

## *Onto behavior tracking*

### *For the sake of showing how it's done, I will actually create a U-net model afresh and train on some data before segmenting fish in this dataset*

In [None]:
# NOTE: this reuses the name `imgDims`, which earlier held the dimensions of the Ca2+ images
imgDims = (256, 256) # Size of images to be trained on. Will rescale input images to this size before training
loss = model.focal_loss # I've recently found that this loss function works better than the custom one
optimizer='adam' # ('rmsprop' or 'adam')

# Now load a naive U-net model object (pre-compiled)
# Single-channel input of size imgDims
unet = model.get_unet(img_height=imgDims[0], img_width=imgDims[1], img_channels=1, 
                      loss=loss, optimizer=optimizer)

### *If you want to use a pre-trained net, run the cell below instead of the one above*

In [None]:
# glob returns matches in arbitrary order, so pick the most recently modified
# trained U-net explicitly and fail clearly when none has been saved yet.
paths_unet = glob.glob(dir_fish + '/trainedU_headFixed_*.h5')
if not paths_unet:
    raise FileNotFoundError(f'No trainedU_headFixed_*.h5 file found in {dir_fish}')
path_unet = max(paths_unet, key=os.path.getmtime)
unet = mlearn.loadPreTrainedUnet(path_unet)

### *I have an excel sheet with paths to training images and corresponding masks. It has training data for different imaging conditions so we will read images from only the paths corresponding to the head fixed condition*

In [None]:
dir_xls = r'\\Koyama-S2\Data3\Avinash\Ablations and Behavior'
# glob returns matches in arbitrary order; take the most recently modified sheet explicitly
paths_xls = glob.glob(os.path.join(dir_xls, 'Paths_to*.xlsx'))
if not paths_xls:
    raise FileNotFoundError(f'No Paths_to*.xlsx file found in {dir_xls}')
path_xls = max(paths_xls, key=os.path.getmtime)
xls_train = pd.read_excel(path_xls, sheet_name='Uncropped')
# Keep only the rows belonging to the head-fixed condition
xls_train = xls_train.loc[xls_train.exptType=='headFixed']


def changePath(p):
    # Replace everything up to and including the last ':' of the stored path
    # (e.g. a drive letter) with the network share prefix.
    return r'\\Koyama-S2\Data3' + p.split(':')[-1]


path_imgs = list(map(changePath, np.array(xls_train.pathToImages)))
path_masks = list(map(changePath, np.array(xls_train.pathToMasks)))
# Read the training images at the input size of the U-net
imgDims = unet.input_shape[1:3]
imgs_train, masks_train = mlearn.read_training_images_and_masks(np.array(path_imgs),
                                                    np.array(path_masks), imgDims=imgDims)
# Binarize the masks: every positive value becomes 1
masks_train = (masks_train>0).astype(int)
print(f'Training on {imgs_train.shape[0]} imgs of dimensions {imgs_train.shape[1:]}')

### *Evaluate performance as a first check. If naive model, then of course score will be low, and if not then a score > 0.85 is usually good*

In [None]:
# A trailing axis is appended to images and masks (presumably the channel axis the U-net expects -- TODO confirm)
metrics = unet.evaluate(imgs_train[..., None], masks_train[..., None], batch_size=32, verbose=1)
# Show each metric name next to its value
print(np.c_[unet.metrics_names, metrics])


### *Run the code cell below a few times to randomly select and plot images and masks for checking purposes* 

In [None]:
# Draw the index of one random training example
ind= np.random.choice(range(imgs_train.shape[0]), size=1)[0]
# Image and mask side by side (1 x 2 grid), with intensities rescaled by montage
m = montage((imgs_train[ind], masks_train[ind]), rescale_intensity=True, grid_shape=(1,2))
plt.figure(figsize=(20, 20))
plt.imshow(m, cmap='viridis')


### *We will create a checkpointer callback that monitors performance during the training epochs and automatically store the weights of the best model to a file from which we can load weights for future use*

In [None]:
#%% Checkpointer callback for storing best weights
fp = os.path.join(dir_fish, f'best_weights_headFixed_{util.timestamp()}.hdf')
# Only the weights (not the full model) of the best epoch are written, where
# "best" means the highest value of the monitored quantity.
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

keras_callbacks = [checkpointer]

### *Augment the training image set to include more diverse and challenging training images that can result in more robust training*

In [None]:
#%% Augment before training
upSample = 5 # This will expand the training set by this much
# Codes of the augmentations to draw from (interpreted by mlearn.augmentImageData)
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
# `augs` is used in the next cell to title each augmented image
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

# Bring images and masks into the form expected by the U-net
imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
# Re-binarize the masks after preparation: every positive value becomes 1
masks_aug= (masks_aug>0).astype(int)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

### *Run the cell below a few times to check the augmented images*

In [None]:
# Draw the index from the *augmented* set: sampling from the length of
# imgs_train would only ever show the first len(imgs_train) augmented images.
ind = np.random.choice(range(imgs_aug.shape[0]), size=1)[0]
# Channel 0 of the augmented image and of its mask, side by side (1 x 2 grid)
m = montage((imgs_aug[ind][..., 0], masks_aug[ind][..., 0]), rescale_intensity=True, grid_shape=(1,2))
plt.figure(figsize=(20, 20))
plt.imshow(m, cmap='viridis')
# Title with the augmentation label stored for this image
plt.title(augs[ind].upper(), fontsize=30);


### *Now train! Make sure you have tensorflow-gpu installed in your environment. If not, then training can be rather slow. One way to know that the process is running on GPU is to look at CPU usage, which should not exceed 10%*

In [None]:
%%time
batch_size = 32 # Larger batch sizes are usually better, but reduce if you get an OOM error
epochs = 250 # Number of training epochs
validation_split = 0.1 # Fraction of images from the training set to be used for validation
# NOTE(review): with initial_epoch = 190 and epochs = 250 Keras only runs epochs 190..249;
# set this to 0 when training the freshly created (naive) U-net -- confirm the intended value
initial_epoch = 190 # 0, if training a naive model. Can be used to retrain from a previous epoch onwards

# verbose=0 silences the per-epoch progress output of fit; the checkpointer has its own verbose setting
his = unet.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
               validation_split=validation_split, callbacks=keras_callbacks, verbose=0,
               initial_epoch=initial_epoch)


### *Plot training metrics*

In [None]:
his = unet.history.history
print(his.keys())
plt.style.use(('seaborn-poster', 'seaborn-white'))
plt.figure(figsize=(15, 6))

# (validation key, training key, panel title) for the left and the right panel
panels = (('val_dice_coef', 'dice_coef', 'Dice coefficient'),
          ('val_loss', 'loss', 'Loss'))
for iPanel, (key_val, key_train, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, iPanel)
    # Validation values as dots, training values as a line
    plt.plot(his[key_val], '.', label='validation set')
    plt.plot(his[key_train], label='training set')
    plt.legend(fontsize=12)
    plt.title(title, fontsize=14)

### *Load the best weights from the saved file and then save the U-net*

In [None]:
#%% Load the best weights and save
# glob returns matches in arbitrary order, so pick the most recently modified
# weights file explicitly and fail clearly when none exists.
paths_weights = glob.glob(os.path.join(dir_fish, 'best_weights_headFixed*.hdf'))
if not paths_weights:
    raise FileNotFoundError(f'No best_weights_headFixed*.hdf file found in {dir_fish}')
path_weights = max(paths_weights, key=os.path.getmtime)
unet.load_weights(path_weights)

#%% Save the U-net
# The file name encodes the input height and width of the network plus a timestamp
fn = f'trainedU_headFixed_{unet.input_shape[1]}x{unet.input_shape[2]}_{util.timestamp()}.h5'
unet.save(os.path.join(dir_fish, fn))
print(time.ctime())

### *Evaluate performance on both training and validation set*

In [None]:
# imgs_aug / masks_aug were already passed through prepare_imgs_for_unet, so no extra axis is added here
metrics = unet.evaluate(imgs_aug, masks_aug, batch_size=32, verbose=1)
# Show each metric name next to its value
print(np.c_[unet.metrics_names, metrics])


### *Walk through the directory tree and get all the subdirectories with behavior images*

In [None]:
%%time
#%% Get the paths to all the behavior directories
roots, dirs, files = zip(*[out for out in os.walk(dir_fish)])
inds = util.findStrInList('Autosave', roots)
dirs_behav = np.array(roots)[inds]

### *Check segmentation on a consecutive set of images*

In [None]:
iTrl = 10  # index into dirs_behav of the behavior directory to check
frameRange = (200, 1000)  # expanded with range(*frameRange): start inclusive, stop exclusive

# Select the .bmp file names of the chosen frames; indexing with a range
# presumably relies on findAndSortFilesInDir returning an array -- TODO confirm
imgNames = ft.findAndSortFilesInDir(dirs_behav[iTrl], ext='bmp')[range(*frameRange)]

imgs = volt.img.readImagesInDir(dirs_behav[iTrl], imgNames=imgNames)
# Resize to the input size of the U-net
imgs_rs = volt.img.resize(imgs, unet.input_shape[1:3])

### *Predict on loaded images*

In [None]:
# A trailing axis is appended for prediction; np.squeeze removes all singleton axes from the output
imgs_prob = np.squeeze(unet.predict(imgs_rs[..., None], batch_size=32, verbose=1))

### *Make a movie* 

In [None]:
alpha = 0.2  # weight of the image in the blended channel; the prediction gets 1-alpha
merge_ch = 0  # index of the color channel into which the prediction is blended
fps = 50

from skimage.color import gray2rgb

# Scale the U-net output by 255 (presumably probabilities in [0, 1] -- TODO confirm)
imgs_prob_255 = (imgs_prob*255).astype(int)
# NOTE(review): the `alpha` argument of gray2rgb has been deprecated in scikit-image
# and may not exist in newer releases -- confirm against the installed version
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs])

# Weighted blend of the image and the prediction, written into channel `merge_ch`
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int) 
ani =volt.animate_images(imgs_rs_rgb, fps=fps, fig_size=(15, 15))
ani

### *If movie looks bad, then select a few images where tracking failed and re-train the U-net, or else segment fish in all behavior images, compute tail angles, and write to the saved HDF file*

In [None]:
# The returned value replaces hFilePath; the U-net in memory is passed in for the segmentation
%time hFilePath = hf.extractAndStoreBehaviorData_singleFish(dir_fish, uNet=unet)


## *Resume from here*

### *Read tail angles from HDF file, clean with svd and wavelet, and trialize*

In [None]:
hFilePath = glob.glob(os.path.join(dir_fish, 'procData*.h5'))[-1]
path_df = glob.glob(os.path.join(dir_fish, 'dataframe_roi_ts*.pkl'))[-1]
df = pd.read_pickle(path_df)
with h5py.File(hFilePath, mode='r') as hFile:
    stimLoc_behav = util.to_utf(np.array(hFile['behav/stimLoc']))
    ta = np.array(hFile['behav/tailAngles'])
nTrls = ta.shape[0]//50
ta_ser = np.concatenate(np.vsplit(ta, nTrls), axis=1)
%time ta_clean, _, svd = hf.cleanTailAngles(ta_ser, dt=1/500)
ta_trl = np.array(np.hsplit(ta_clean, nTrls))
df_orig = df.copy()

### *Match $Ca^{2+}$ and behavior trials, put them all in one dataframe and save*

In [None]:
df = df_orig.copy()
# Each behavior trial label is parsed as: first three characters -> session
# number, last character -> stimulus location
sessionIdx, stimLoc = zip(*[(int(trl[:3]), trl[-1]) for trl in stimLoc_behav])
sessionIdx, stimLoc = np.array(sessionIdx), np.array(stimLoc)
# Shift 1-based session numbers to 0-based
if sessionIdx.min()==1:
    sessionIdx = sessionIdx-1
# Running trial number within each session, kept in the original trial order.
# (Concatenating per-session ranges over the sorted session ids only lines up
# with the trials when the sessions are stored contiguously and in order.)
trlIdx = pd.Series(sessionIdx).groupby(sessionIdx).cumcount().values
df_ = dict(sessionIdx=sessionIdx, stimLoc=stimLoc, trlIdx=trlIdx, tailAngles=list(ta_trl))
df_ = pd.DataFrame(df_)
nRows_ca = len(df)
# Inner join: rows without a partner in the other dataframe are dropped
df = pd.merge(df, df_, on = ['stimLoc', 'sessionIdx', 'trlIdx'])
if len(df) != nRows_ca:
    warnings.warn(f'Ca2+ dataframe had {nRows_ca} rows but the merged dataframe has '
                  f'{len(df)}; some Ca2+ and behavior trials did not match one-to-one')
df.to_pickle(os.path.join(dir_fish, f'dataframe_roi_ta_{util.timestamp()}.pkl'))
print(f'Saved dataframe at\n{dir_fish}')

In [None]:
iTrl=44  # index of the trial to plot
plt.figure(figsize=(20, 5))
# Last row of this trial's tail-angle array (presumably the most distal tail point -- TODO confirm)
plt.plot(ta_trl[iTrl][-1])
plt.xlim(300, 1500)
# Dashed red line at sample 500 (presumably stimulus onset -- TODO confirm)
plt.axvline(500, ls='--', c='r')