In [None]:
#%% Import relevant code
import os, sys, time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re as regex
import dask
import h5py
import pandas as pd
import keras
from keras.callbacks import ModelCheckpoint
from skimage.util import montage
import glob

#--- Import my code
codeDir = r'\\dm11\koyamalab/code/python/code'
sys.path.append(codeDir)
# import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.SignalProcessingTools as spt
from apCode.machineLearning.unet import model
from apCode.behavior import FreeSwimBehavior as fsb
# from apCode import geom
import apCode.hdf as hdf
from apCode import util
from rsNeuronsProj import util as rsp
import apCode.behavior.headFixed as hf

#--- Set the seed for reproducibility
seed = 143
# Call (not assign) np.random.seed: `np.random.seed = seed` replaced the function with an int and
# left numpy's global RNG (used e.g. by np.random.shuffle further down) unseeded.
np.random.seed(seed)

# Type 42 (TrueType) fonts keep text editable in exported PDF/PS figures
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


#--- Auto-reload modules (only when running under IPython/Jupyter)
try:
    if __IPYTHON__:
        # run_line_magic replaces the deprecated InteractiveShell.magic()
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    # Plain Python interpreter: no IPython shell, so no autoreload
    pass

print(time.ctime())


## *Load a pre-trained U-net*

In [None]:
dir_unet = r'\\Koyama-S2\Data3\Avinash\U net'
# glob's result order is filesystem-dependent; sort so [-1] is the lexicographically last file
# (presumably the newest, since saved names end in a timestamp)
path_unet = sorted(glob.glob(os.path.join(dir_unet, 'trainedU_headFixed*.h5')))[-1]

unet = mlearn.loadPreTrainedUnet(path_unet)

# The compiled loss is stored either as a name (str) or as a function object
print(unet.loss if isinstance(unet.loss, str) else unet.loss.__name__)

In [None]:
dir_xls = r'\\Koyama-S2\Data3\Avinash\U net'
# Sorted so that [-1] does not depend on the filesystem's directory-listing order
path_xls = sorted(glob.glob(os.path.join(dir_xls, 'Paths_to*Minoru*.xlsx')))[-1]
xls_train = pd.read_excel(path_xls, sheet_name='Uncropped')
xls_train = xls_train.loc[xls_train.exptType == 'headFixed']  # keep head-fixed experiments only
print(path_xls)

# changePath = lambda p: r'\\Koyama-S2\\Data3' + p.split(':')[-1] if p[1]==r":"
def changePath(p):
    r"""Map a drive-letter path (e.g. ``Y:\Avinash\...``) onto the ``\\Koyama-S2\Data3`` share.

    Parameters
    ----------
    p : str
        Path as stored in the training spreadsheet.

    Returns
    -------
    str
        If the 2nd character of `p` is ':', the UNC root followed by everything after the
        last ':' in `p`. Otherwise `p` unchanged, including strings shorter than 2 characters.
    """
    # Single backslash between host and share, matching the 'Y:' -> '\\Koyama-S2\Data3' mapping
    # used for pathList. The previous raw literal r'\\Koyama-S2\\Data3' embedded a doubled separator.
    uncRoot = r'\\Koyama-S2\Data3'
    # Slice instead of p[1], so '' and 1-character strings don't raise IndexError
    if p[1:2] == ':':
        p = uncRoot + p.split(':')[-1]
    return p
        
# Remap drive-letter paths onto the share, then load image/mask pairs at the U-net's input size
imgDims = unet.input_shape[1:3]
paths_imgs = [changePath(p) for p in np.array(xls_train.pathToImages)]
path_masks = [changePath(p) for p in np.array(xls_train.pathToMasks)]
imgs_train, masks_train = mlearn.read_training_images_and_masks(paths_imgs, path_masks, imgDims=imgDims)
masks_train = (masks_train > 0).astype(int)  # binarize masks
print(f'Training on {imgs_train.shape[0]} imgs of dimensions {imgs_train.shape[1:]}')


In [None]:
# Metrics of the pre-trained U-net on the un-augmented training set.
# A trailing channel axis is appended to images and masks before evaluation.
X_eval, y_eval = imgs_train[..., None], masks_train[..., None]
metrics = unet.evaluate(X_eval, y_eval, batch_size=32, verbose=1)
print(np.c_[unet.metrics_names, metrics])

In [None]:
#%% Checkpointer callback for storing best weights
# Saves only the weights whenever val_dice_coef reaches a new maximum
fp = os.path.join(dir_unet, f'best_weights_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
keras_callbacks = [checkpointer]

In [None]:
#%% Augment before training
upSample = 5 # This will expand the training set by this much
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
masks_aug= (masks_aug>0).astype(int)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
# Metrics of the current U-net weights on the augmented set
eval_kwargs = dict(batch_size=32, verbose=1)
metrics = unet.evaluate(imgs_aug, masks_aug, **eval_kwargs)
print(np.c_[unet.metrics_names, metrics])


In [None]:
# NOTE: training is disabled (every line below is commented out). The "Load best weights" cell
# instead loads the last best_weights_headFixed*.hdf file that glob finds in dir_unet.
# %%time
# batch_size = 32 # Larger batch sizes are usually better, but reduce if you get an OOM error
# epochs = 150 # Number of training epochs
# validation_split = 0.1 # Fraction of images from the training set to be used for validation

# his = unet.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
#                validation_split=validation_split, callbacks=keras_callbacks, 
#                verbose=0)


### *Load best weights and save U net*

In [None]:
inputDims = unet.input_shape[1:3]
if isinstance(unet.loss, str):
    lf = unet.loss
else:
    lf = unet.loss.__name__
fn = f'trainedU_headFixed_{inputDims[0]}x{inputDims[1]}_{lf}_{util.timestamp()}.h5'

path_wts = glob.glob(os.path.join(dir_unet, 'best_weights_headFixed*.hdf'))[-1]
unet.load_weights(path_wts)

%time unet.save(os.path.join(dir_unet, fn))
print(fn)



In [None]:
# Re-evaluate on the augmented set after loading the best checkpoint weights
eval_kwargs = dict(batch_size=32, verbose=1)
metrics = unet.evaluate(imgs_aug, masks_aug, **eval_kwargs)
print(np.c_[unet.metrics_names, metrics])


### *Read xls with paths to data*

In [None]:
# Summary spreadsheet listing the imaging folder of each fish
dir_xls = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\GCaMP imaging'
dir_group = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\GCaMP imaging\Group'
file_xls = 'GCaMP volumetric imaging summary.xlsx'

path_summary = os.path.join(dir_xls, file_xls)
xls = pd.read_excel(path_summary, sheet_name='Sheet1')
xls.head()

In [None]:
# One folder per fish: the first Path listed for each FishIdx, with the Y: drive mapped to the UNC share
inds_fish = np.array(xls.FishIdx.dropna())
pathList = []
for ind in inds_fish:
    path_fish = xls.loc[xls.FishIdx == ind, 'Path'].iloc[0]
    pathList.append(path_fish.replace('Y:', '\\\\Koyama-S2\\Data3'))
pathList = np.array(pathList)


### *Create an SVD object for cleaning tail angles using tail angles from multiple fish, then save this model for future use. Later, when more tail-angle data becomes available, create a more comprehensive SVD object.*

In [None]:
# Pool tail angles from the first few fish into one (rows x time) array and tag every
# time point with the position (iPath) of its fish in this loop.
nRowsPerTrl = 50  # behav/tailAngles stacks trials along rows in blocks of this many rows
pathInds = range(4)
dic = dict(fishIdx=[], ta=[])
for iPath, path_ in enumerate(pathList[pathInds]):
    # Sorted so [-1] does not depend on the filesystem's listing order
    hfp = sorted(glob.glob(os.path.join(path_, 'procData*.h5')))[-1]
    print(f'{iPath+1}/{len(pathInds)} \n{hfp}')
    with h5py.File(hfp, mode='r') as hFile:
        ta_ = np.array(hFile['behav/tailAngles'])
    # Lay the row-stacked trial blocks end-to-end along time (axis 1)
    ta_ = np.concatenate(np.vsplit(ta_, ta_.shape[0]//nRowsPerTrl), axis=1)
    dic['fishIdx'].append(np.repeat(iPath, ta_.shape[1]))
    dic['ta'].append(ta_)
dic['fishIdx'] = np.concatenate(dic['fishIdx'], axis=0)
dic['ta'] = np.concatenate(dic['ta'], axis=1)

In [None]:
# Clean the pooled tail angles (dt presumably in seconds, i.e. 500 frames/s) and keep the
# returned `svd` object, which this section is meant to build.
%time ta_clean, _, svd = hf.cleanTailAngles(dic['ta'], dt=1/500, nWaves=5)

In [None]:
# Interactive help for the apCode.behavior.headFixed module (development aid)
hf?

In [None]:
%matplotlib auto
plt.figure(figsize=(20, 5))

t = np.arange(ta_clean.shape[1])*(1/500)
plt.plot(t, dic['ta'][-1], lw=1, alpha=0.5)
plt.plot(t, ta_clean[-1], lw=1)
plt.plot(t, -100*dic['fishIdx'], lw=1, ls='--')
plt.xlim(0, t[-1])
# plt.xlim(0, 100)

In [None]:
# Display the third fish's folder path with forward slashes instead of backslashes
pathList[2].replace("\\", "/")

In [None]:
# Inspect one fish: tail angles for all trials plus the probability-map images of one trial
iPath = 0
iTrl = 0
path_ = pathList[iPath]
# Sorted so [-1] does not depend on the filesystem's listing order
hfp = sorted(glob.glob(os.path.join(path_, 'procData*.h5')))[-1]
print(hfp)
with h5py.File(hfp, mode='r') as hFile:
    print(hFile['behav'].keys())
    nTrls = hFile['behav']['tailAngles'].shape[0]//50  # trials are stacked in blocks of 50 rows
    print(hFile['behav/tailAngles'].shape)
    trlLen = hFile['behav/images_prob'].shape[0]//nTrls
    print(f'{nTrls} trls of length {trlLen}')
    trlInds = np.arange(trlLen*iTrl, trlLen*(iTrl+1))
    # Contiguous slice covering the same frames as trlInds. A slice avoids h5py's slow
    # fancy-index path for long index lists.
    imgs = np.array(hFile['behav/images_prob'][trlLen*iTrl:trlLen*(iTrl+1)])
    ta = np.array(hFile['behav/tailAngles'])
# Lay the trial blocks end-to-end in time (reuses nTrls instead of recomputing shape[0]//50)
ta_trl = np.array(np.vsplit(ta, nTrls))
ta = np.concatenate(ta_trl, axis=1)


In [None]:
# Keep only the cleaned tail angles; discard the other outputs of cleanTailAngles
ta_clean, *_ = hf.cleanTailAngles(ta)

In [None]:
plt.style.use(('seaborn-white', 'seaborn-talk', 'seaborn-ticks'))
# Cleaned last-row tail angle over a short window (time axis at 1/500 per frame)
t = np.arange(ta.shape[1])*(1/500)
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(t, ta_clean[-1])
ax.set_xlim(110, 116)
ax.set_ylim(-50, 100)

In [None]:

# Animate the first 250 frames of the probability-map images loaded for the selected trial
ani = volt.animate_images(imgs[:250])
ani

In [None]:
# Copy a few behaviour frames from one session folder into that fish's folder for U-net training
savePath = r'\\Koyama-S2\Data3\Avinash\Head-fixed tail free\GCaMP imaging\2019-11-06\session2\f1_alx-gal4_xa316_uas-gamp6s'
path_imgs = savePath + r'\002_t\behav\Autosave0_[00-11-1c-f1-75-10]_20191107_123723_AM'
foo = fsb.copy_images_for_training(path_imgs, savePath=savePath, nImgsToCopy=3)

In [None]:
for iPath, path_ in enumerate(pathList):
    hfp = glob.glob(os.path.join(path_, 'procData*.h5'))[-1]
    print(f'Path # {iPath+1}/{len(pathList)}')
    %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet, hFilePath=hfp)
   

In [None]:
%%time
iTrl=0
hfp = glob.glob(os.path.join(pathList[1], 'procData*.h5'))[-1]
with h5py.File(hfp, mode='r') as hFile:
    nTrls = hFile['behav/tailAngles'].shape[0]//50
    ta = np.array(hFile['behav/tailAngles'])
    trlLen = ta.shape[1]
    trlInds = np.arange(iTrl*trlLen, (iTrl+1)*trlLen)
    imgs = hFile['behav/images_prob'][trlInds]
ta = np.array(np.vsplit(ta, nTrls))
ta = np.concatenate(ta, axis=1)
print(ta.shape)
ta_clean = hf.cleanTailAngles(ta, dt=1/500)[0]

In [None]:
plt.style.use(('seaborn-talk', 'seaborn-white', 'seaborn-ticks'))
# Raw then cleaned last-row tail angle for the second fish, 60-80 on the time axis
t = np.arange(ta.shape[1])*(1/500)
fig, ax = plt.subplots(figsize=(20, 5))
for trace in (ta[-1], ta_clean[-1]):
    ax.plot(t, trace)
ax.set_xlim(60, 80)

In [None]:
for iPath, path_ in enumerate(pathList):
    track=False
    hfp = glob.glob(os.path.join(path_, 'procData*.h5'))
    if len(hfp)>0:
        hfp = hfp[-1]
        with h5py.File(hfp, mode='r') as hFile:
            if not 'behav' in hFile:
                track=True
                print(f'Path # {iPath+1}/{len(pathList)}')
        if track:
            %time hfp_ = hf.read_and_store_ca_imgs(path_)
            %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet)
    else:
        print(f'Path # {iPath+1}/{len(pathList)}')
        %time hfp_ = hf.read_and_store_ca_imgs(path_)
        %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet)

In [None]:
# Interactive help for hf.register_piecewise_from_hdf (development aid)
hf.register_piecewise_from_hdf?

In [None]:
%time dic_ta = hf.tailAngles_from_hdf_concatenated_by_trials(pathList)
ta = np.concatenate(dic_ta['tailAngles'], axis = 1)


In [None]:
# Re-fit the tail-angle cleaner (and its svd object) on tail angles pooled from every fish in pathList
%time ta_clean, _, svd = hf.cleanTailAngles(ta)

In [None]:
# plt.figure(figsize=(20, 5))
# t = np.arange(ta_clean.shape[1])*(1/500)
# plt.plot(t, ta[-1])
# plt.xlim(20, 50)
# plt.ylim(20, 25)

## *Read dataframe with all relevant information (paths, etc)*

In [None]:
#%% Load the pickled dataframe holding paths to data and other relevant info
# The 'Y:' drive is written as the UNC share it maps to elsewhere in this notebook
# (pathList replaces 'Y:' with \\Koyama-S2\Data3), so this cell does not need a mapped drive.
dir_df = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\Ablations\session_20200422-00'
# Sorted so [-1] does not depend on the filesystem's listing order
path_df = sorted(glob.glob(os.path.join(dir_df, 'dataFrame_rsNeurons_ablations_svdClean_2020*.pkl')))[-1]

df = pd.read_pickle(path_df)  # NOTE: unpickling can execute code; only load trusted files
# A new timestamped output folder is created on every run of this cell
dir_save = os.path.join(dir_df, f'session_{util.timestamp()}')
os.makedirs(dir_save, exist_ok=True)

print(df.columns)

## *Evaluate pre-training performance*

In [None]:
ta_all = [np.array(ta_) for ta_ in df['tailAngles']]
ta_all = np.concatenate(ta_all, axis=1)
%time _, _, svd = hf.cleanTailAngles(ta_all, dt=1/500)

In [None]:
# TODO(review): df_ctrl and df_abl are not defined anywhere in this notebook, so this cell only runs
# on hidden kernel state (presumably control vs. ablated subsets of `df` — confirm and add that cell).
df_ctrl.shape, df_abl.shape

## *Generate probability maps*

In [None]:
# TODO(review): `unet_fsb` and `imgs_rs` are not defined anywhere in this notebook. This cell, and the
# movie/tracking cells below that use imgs_rs/imgs_prob, depend on hidden kernel state.
imgs_prob = np.squeeze(unet_fsb.predict(imgs_rs[..., None], batch_size=6, verbose=1))

## *Make a movie to demonstrate segmentation*

In [None]:
alpha = 0.2         # weight of the raw image in the blended channel; the probability map gets (1 - alpha)
merge_ch = 0        # RGB channel into which the probability map is blended (0 = red)
fps = 50
cropSize = (256, 256)
iRange = (20, 300)  # frames [start, stop) used for the movie
save = True

from skimage.color import gray2rgb

# Fish position per frame, from the raw images weighted by the U-net probability map.
# nanInterp1d presumably fills frames where no position was found (NaN) by interpolation.
fp = fsb.track.findFish(-imgs_rs*imgs_prob, back_img=None)
fp_interp = spt.interp.nanInterp1d(fp)

# Crop raw and probability images around the fish in the selected frames
inds = np.arange(*iRange)
imgs_rs_crop = volt.img.cropImgsAroundPoints(imgs_rs[inds], fp_interp[inds], cropSize=cropSize)
imgs_prob_crop = volt.img.cropImgsAroundPoints(imgs_prob[inds], fp_interp[inds], cropSize=cropSize)

# Blend the probability map (scaled to 0-255) into one channel of an RGB(A) copy of the raw images.
# NOTE(review): in skimage versions that still accept it, gray2rgb's `alpha` is a flag that appends
#   an opaque alpha layer; the value 0.5 is NOT an opacity. Newer skimage releases removed the argument.
# NOTE(review): the blend assumes imgs_rs is on the same 0-255 scale as imgs_prob_255 — confirm.
imgs_prob_255 = (imgs_prob_crop*255).astype(int)
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs_crop])
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int)

# TODO(review): `dir_imgs` is not defined in this notebook, and `iTrl` is inherited from an earlier cell.
# This also reassigns `dir_save`, which previously pointed at the session folder of the dataframe.
dir_save = os.path.join(dir_imgs, 'proc')
# makedirs(exist_ok=True): also creates missing parents and avoids the exists()/mkdir() race
os.makedirs(dir_save, exist_ok=True)
fname = f'Tracking movie_trl[{iTrl}]_inTrlFrames[{iRange[0]}-{iRange[1]}]_imgDims[{cropSize[0]}x{cropSize[1]}]_{util.timestamp("minute")}.avi'
savePath = os.path.join(dir_save, fname)

ani = volt.animate_images(imgs_rs_rgb, fps=fps, fig_size=(15, 15), save=save, savePath=savePath)
print(f'Movie saved at\n{dir_save}\nas\n{fname}')
ani

### *Copy these images for training if performance not great*

In [None]:
# Copy 10 randomly chosen frames from the movie range into dir_save for further U-net training.
# Positions before 37 within `inds` are excluded (TODO: document why 37).
inds_mov = np.arange(37, len(inds))
np.random.shuffle(inds_mov)  # in-place shuffle using numpy's global RNG
# inds[inds_mov] selects the same frames, in the same order, as imgs_rs[inds][inds_mov],
# without first copying every frame in the range
foo = fsb.copy_images_for_training(imgs_rs[inds[inds_mov]], savePath=dir_save, nImgsToCopy=10)

In [None]:
# Interactive help for fsb.copy_images_for_training (development aid)
fsb.copy_images_for_training?

## *From segmented fish images to tail curvature timeseries*

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

## *Plot tail angles extracted from segmented fish*

In [None]:
#%% Plot tail angles
# DivergingNorm was renamed TwoSlopeNorm (matplotlib 3.2) and later removed; support both.
try:
    from matplotlib.colors import TwoSlopeNorm
except ImportError:
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm

# Diverging colour scale centred on 0. The limits live in the norm, so they are not repeated in
# imshow (passing norm together with vmin/vmax is deprecated and errors in newer matplotlib).
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Top: tail angle at each of the 50 points along the body (rows) over the movie frames (columns)
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Bottom: the last point's angle as a timeseries
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

# *Try Focal Loss*

In [None]:
#%% Instantiate U-net with focal loss specified during compilation
imgDims_fl = (896, 896)  # (height, width) of the new network's input
unet_fl = model.get_unet(img_height=imgDims_fl[0], img_width=imgDims_fl[1], img_channels=1,
                         loss=model.focal_loss)

In [None]:
#%% Checkpointer callback for storing best weights
# Saves only the weights whenever val_dice_coef reaches a new maximum
fp = os.path.join(dir_unet, f'best_weights_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
keras_callbacks = [checkpointer]

In [None]:
#%% Augment before training
upSample=4
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
# aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use batch_size=6
epochs = 25
validation_split = 0.1
checkPoint = True

his = unet_fl.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
                   validation_split=validation_split, callbacks=keras_callbacks, verbose=1)


In [None]:
# Learning curves recorded by the most recent unet_fl.fit()
his = unet_fl.history.history
print(his.keys())
plt.figure(figsize=(15, 6))
plt.style.use(('seaborn-poster', 'fivethirtyeight', 'seaborn-white'))
plt.subplot(121)
plt.plot(his['val_dice_coef'], '.', label='validation set')
plt.plot(his['dice_coef'], label='training set')
plt.legend(fontsize=12)
plt.title('Dice coefficient', fontsize=14)

plt.subplot(122)
plt.plot(his['val_loss'], '.', label='validation set')
plt.plot(his['loss'], label='training set')
plt.legend(fontsize=12)
# Raw string: '\g' is an invalid escape sequence in a normal string literal.
# Only the math is inside $...$, so 'unbalanced' renders as normal text.
plt.title(r'Focal loss ($\gamma = 2$, unbalanced)', fontsize=14);

In [None]:
# Probability maps from the focal-loss U-net. batch_size=6 matches the batch size training needed
# for these large inputs; Keras' predict() default of 32 may not fit in memory.
imgs_prob = np.squeeze(unet_fl.predict(imgs_rs[..., None], batch_size=6))

In [None]:
# Imported here so this cell does not rely on the earlier movie cell having been run
from skimage.color import gray2rgb

alpha = 0.2     # weight of the raw image in the blended channel; the probability map gets (1 - alpha)
merge_ch = 0    # RGB channel into which the probability map is blended (0 = red)
fps = 50
inds = np.arange(450, 3000)  # frames to animate

# Only the animated frames are converted and blended; previously every frame of imgs_rs was.
# NOTE(review): gray2rgb's `alpha` is a flag (appends an opaque alpha layer) in skimage versions
#   that accept it, not an opacity; newer skimage releases removed the argument.
imgs_prob_255 = (imgs_prob[inds]*255).astype(int)
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs[inds]])
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int)

ani = volt.animate_images(imgs_rs_rgb, fps=fps, fig_size=(15, 15))
ani

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet_fl)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

In [None]:
#%% Plot tail angles
# DivergingNorm was renamed TwoSlopeNorm (matplotlib 3.2) and later removed; support both.
try:
    from matplotlib.colors import TwoSlopeNorm
except ImportError:
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm

# Diverging colour scale centred on 0. The limits live in the norm, so they are not repeated in
# imshow (passing norm together with vmin/vmax is deprecated and errors in newer matplotlib).
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Top: tail angle at each of the 50 points along the body (rows) over the movie frames (columns)
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Bottom: the last point's angle as a timeseries
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

## *Free swim behavior*