In [None]:
#%% Import relevant code
import os, sys, time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re
import dask
import h5py
import pandas as pd
import keras
from keras.callbacks import ModelCheckpoint
from keras.utils import multi_gpu_model
from skimage.util import montage
import glob

#--- Expose my own packages (apCode, rsNeuronsProj) to the importer
codeDir = r'\\dm11\koyamalab/code/python/code'
# Appended at the end of sys.path, so already-installed packages are still found first
sys.path.extend([codeDir])
# import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.SignalProcessingTools as spt
from apCode.machineLearning.unet import model
from apCode.behavior import FreeSwimBehavior as fsb
# from apCode import geom
import apCode.hdf as hdf
from apCode import util
from rsNeuronsProj import util as rsp
import apCode.behavior.headFixed as hf

#--- Setting seed for reproducibility
seed = 143
# Call (do not assign) the legacy global seeder. `np.random.seed = seed` replaced
# the function with an int, which left the RNG unseeded and broke any later
# np.random.seed(...) call.
np.random.seed(seed)

# Font type 42 (TrueType) keeps text as text in exported PDF/PS figures
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


#--- Auto-reload modules
# `InteractiveShell.magic` is deprecated; `run_line_magic` is the supported API.
try:
    if __IPYTHON__:
        ipython_shell = get_ipython()
        ipython_shell.run_line_magic('load_ext', 'autoreload')
        ipython_shell.run_line_magic('autoreload', '2')
except NameError:
    # Plain Python interpreter (e.g. running the file as a script): no magics.
    pass

print(time.ctime())


## *Trying multi_gpu_model for the first time*

In [None]:
dir_unet = r'\\Koyama-S2\Data3\Avinash\U net'
# glob() order is file-system dependent; sort so that [-1] is the latest
# timestamped file (assumes the timestamps in the names sort chronologically).
paths_found = sorted(glob.glob(os.path.join(dir_unet, 'trainedU_headFixed*.h5')))
if not paths_found:
    raise FileNotFoundError(f'No trainedU_headFixed*.h5 file found in {dir_unet}')
path_unet = paths_found[-1]

# unet = mlearn.loadPreTrainedUnet(path_unet)
# unet = model.get_unet(img_height=256, img_width=256, img_channels=1)
# Fresh (untrained) multi-GPU U-net for 256x256 inputs
unet = model.get_unet_parallel(img_height=256, img_width=256)

### *Alternatively, load pre-trained model*

In [None]:
dir_unet = r'\\Koyama-S2\Data3\Avinash\U net'
# glob() order is file-system dependent; sort so that [-1] picks the latest
# timestamped file (assumes timestamps in the names sort chronologically).
path_unet = sorted(glob.glob(os.path.join(dir_unet, 'trainedU_multiGPU_headFixed*.h5')))[-1]
# NOTE: this pattern matches every weights file saved with this prefix
path_wts = sorted(glob.glob(os.path.join(dir_unet, 'best_weights_headFixed*.hdf')))[-1]
print(path_wts)
# unet = model.get_unet_parallel(img_height=256, img_width=256)
# unet.load_weights(path_wts)
unet = mlearn.loadPreTrainedUnet(path_unet)

In [None]:
dir_xls = r'\\Koyama-S2\Data3\Avinash\U net'
# Sort the glob result so that [-1] is deterministic (the last name in sort order)
path_xls = sorted(glob.glob(os.path.join(dir_xls, 'Paths_to*Minoru*.xlsx')))[-1]
xls_train = pd.read_excel(path_xls, sheet_name='Uncropped')
# Keep only the head-fixed training entries
xls_train = xls_train.loc[xls_train.exptType == 'headFixed']

def changePath(p):
    r"""Map a drive-letter path onto the equivalent UNC path on Koyama-S2.

    'Y:\Avinash\...' becomes '\\Koyama-S2\Data3\Avinash\...'. Anything without a
    drive letter (already UNC, too short, or not a string, e.g. NaN from an empty
    spreadsheet cell) is returned unchanged.
    """
    # The old root r'\\Koyama-S2\\Data3' (raw string) contained a double backslash
    # between server and share. Use the same single-separator root as the
    # 'Y:' -> UNC replacement elsewhere in this notebook.
    if isinstance(p, str) and len(p) > 1 and p[1] == ':':
        p = r'\\Koyama-S2\Data3' + p.split(':')[-1]
    return p
        
# Point every spreadsheet path at the UNC share before reading from it
paths_imgs = [changePath(p) for p in np.array(xls_train.pathToImages)]
path_masks = [changePath(p) for p in np.array(xls_train.pathToMasks)]
imgDims = unet.input_shape[1:3]  # presumably (height, width) of the network input
imgs_train, masks_train = mlearn.read_training_images_and_masks(
    paths_imgs, path_masks, imgDims=imgDims)
masks_train = (masks_train > 0).astype(int)  # binarize the masks
print(f'Training on {imgs_train.shape[0]} imgs of dimensions {imgs_train.shape[1:]}')


In [None]:
# Score the network on the un-augmented training set (trailing channel axis added)
x_eval, y_eval = imgs_train[..., None], masks_train[..., None]
metrics = unet.evaluate(x_eval, y_eval, batch_size=64, verbose=1)
print(np.c_[unet.metrics_names, metrics])

In [None]:
#%% Checkpointer callback for storing best weights
# Only the weights with the highest validation Dice coefficient are written to disk
fp = os.path.join(dir_unet, f'best_weights_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
keras_callbacks = [checkpointer]
print(fp)

In [None]:
#%% Augment before training
upSample = 7 # This will expand the training set by this much
aug_set=('rn', 'sig', 'log', 'heq', 'rot', 'rs', 'inv')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
masks_aug= (masks_aug>0).astype(int)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
# Score on the augmented set, which was already prepared for the network
metrics = unet.evaluate(imgs_aug, masks_aug, verbose=1, batch_size=64)
print(np.c_[unet.metrics_names, metrics])


In [None]:
%%time
batch_size = 64 # Larger batch sizes are usually better, but reduce if you get an OOM error
epochs = 350 # Number of training epochs
initial_epoch = 249
validation_split = 0.11 # Fraction of images from the training set to be used for validation

his = unet.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
               validation_split=validation_split, callbacks=keras_callbacks, 
               verbose=0, initial_epoch=initial_epoch)


### *Load best weights and save U net*

In [None]:
inputDims = unet.input_shape[1:3]
if isinstance(unet.loss, str):
    lf = unet.loss
else:
    lf = unet.loss.__name__
fn = f'trainedU_multiGPU_headFixed_{inputDims[0]}x{inputDims[1]}_{lf}_{util.timestamp()}.h5'

path_wts = glob.glob(os.path.join(dir_unet, 'best_weights_headFixed*.hdf'))[-1]
print(path_wts)
unet.load_weights(path_wts)

%time unet.save(os.path.join(dir_unet, fn))
print(fn)



In [None]:
# Re-score the augmented set, this time with a smaller batch size
metrics = unet.evaluate(imgs_aug, masks_aug, verbose=1, batch_size=32)
print(np.c_[unet.metrics_names, metrics])


In [None]:
# Learning curves of the last fit. The x values are the true epoch indices
# (History.epoch), which start at `initial_epoch` when training was resumed.
fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(20, 5))
epochs_run = his.epoch

ax_loss.plot(epochs_run, his.history['loss'], label='training')
ax_loss.plot(epochs_run, his.history['val_loss'], '.', label='validation')
ax_loss.set(title='Loss vs epoch', xlabel='Epoch')
ax_loss.legend()

# This panel shows the Dice coefficient (it was previously titled "Accuracy")
ax_dice.plot(epochs_run, his.history['dice_coef'], label='training')
ax_dice.plot(epochs_run, his.history['val_dice_coef'], '.', label='validation')
ax_dice.set(title='Dice coefficient vs epoch', xlabel='Epoch')
ax_dice.legend();

### *Read xls with paths to data*

In [None]:
# Summary spreadsheet of the GCaMP volumetric imaging sessions
dir_xls = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\GCaMP imaging'
dir_group = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\GCaMP imaging\Group'

file_xls = 'GCaMP volumetric imaging summary_2020-05-09.xlsx'
path_xls_summary = os.path.join(dir_xls, file_xls)
xls = pd.read_excel(path_xls_summary, sheet_name='Sheet1')
print(xls.shape)
xls.head()

### *Go through all paths, check HDF and fill missing $Ca^{2+}$ or behavior variables*

In [None]:
patchPerc = (60, )
batch_size = 64


inds_fish = np.array(xls.FishIdx.dropna())
pathList = np.array([xls.loc[xls.FishIdx==ind].Path.iloc[0].replace('Y:','\\\\Koyama-S2\\Data3') for ind in inds_fish])
reTrackBehav = np.array(xls.reTrackBehav)
# inds_track = np.where(reTrackBehav==1)[0]

inds_track = np.arange(len(pathList))
pathList = pathList[inds_track]

print(f'{len(pathList)} paths in total')

for iPath, path_ in enumerate(pathList):
    track_behav = False
    track_ca = False
    reg_ca = False
    hfp = glob.glob(os.path.join(path_, 'procData*.h5'))
    if len(hfp)>0:
        hfp = hfp[-1]
        with h5py.File(hfp, mode='r') as hFile:
            if 'behav' not in hFile:
                track_behav = True
#             if 'ca_raw' not in hFile:
#                 track_ca = True
#             if 'ca_reg' not in hFile:
#                 reg_ca = True
            print(f'Path # {inds_track[iPath]+1}/{len(pathList)}')
#         if track_ca:
#             %time hfp_ = hf.read_and_store_ca_imgs(path_)
#         if reg_ca:
#             %time hfp_ = hf.register_piecewise_from_hdf(hfp_, patchPerc=patchPerc)[0]
        if track_behav:
            %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet, batch_size=batch_size)
    else:
        print(f'Path # {iPath+1}/{len(pathList)}')
#         %time hfp_ = hf.read_and_store_ca_imgs(path_)
#         %time hfp_ = hf.register_piecewise_from_hdf(hfp_)
        %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet, batch_size=batch_size)
    reTrackBehav[inds_track[iPath]] = 0
xls = xls.assign(reTrackBehav=reTrackBehav)

In [None]:
# Ensure every path has piecewise-registered Ca2+ images ('ca_reg') in its HDF
for path_ in pathList:
    # glob() order is file-system dependent; sort so that [-1] is deterministic
    hfps = sorted(glob.glob(os.path.join(path_, 'procData*.h5')))
    if hfps:
        hfp = hfps[-1]
        with h5py.File(hfp, mode='r') as hFile:
            reg_ca = 'ca_reg' not in hFile
        if reg_ca:
            print(hfp)
            hfp_ = hf.register_piecewise_from_hdf(hfp)
    else:
        # No processed HDF exists yet. Create it from the raw Ca2+ images, then
        # register. The old code passed the empty glob result `[]` straight to
        # register_piecewise_from_hdf.
        hfp_ = hf.read_and_store_ca_imgs(path_)
        hfp_ = hf.register_piecewise_from_hdf(hfp_)
    
            


In [None]:
path_ = pathList[18]
print(path_)
hfp = glob.glob(os.path.join(path_, 'procData*.h5'))[-1]
# %time hfp_ = hf.extractAndStoreBehaviorData_singleFish(path_, uNet=unet, batch_size=64)
with h5py.File(hfp, mode='r') as hFile:
    print(hFile['behav'].keys())
    ta = np.array(hFile['behav/tailAngles'])
    stimLoc = np.array(hFile['behav/stimLoc'])
    stimLoc = util.to_utf(stimLoc)    
nTrls = ta.shape[0]//50
ta_trl = np.vsplit(ta, nTrls)
ta = np.concatenate(ta_trl, axis=1)
trlLen = ta_trl[0].shape[1]
print(f'{nTrls} trls  of {trlLen} length')
%time ta_clean = hf.cleanTailAngles(ta, dt=1/500, nWaves=10)[0]

In [None]:
# Cleaned tail-tip angle across all concatenated trials, with stimulus onsets
# marked: tab10 colour 0 for trials listed by findStrInList('h', ...), colour 1 otherwise
plt.figure(figsize=(20, 5))
t = np.arange(ta.shape[1]) * (1/500)

# Onset is 500 samples into each trial
stimInds = (np.arange(nTrls) * trlLen + 500).astype(int)
stimTimes = t[stimInds]
inds_head = util.findStrInList('h', stimLoc)
plt.plot(t, ta_clean[-1], c='k', alpha=0.75)
for ind, st in enumerate(stimTimes):
    color = plt.cm.tab10(0) if ind in inds_head else plt.cm.tab10(1)
    plt.axvline(st, ls='--', c=color, alpha=0.8)
plt.xlim(0, t[-1])
print(path_)

### *Inspect visually confirmed noisy trials and re-track them as needed*

In [None]:
from caiman import movie
import joblib

dir_ = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging\Group\Figs\Trials with GMM labels'
fn = 'noisyTrlPaths.npy'
# `[()]` unwraps the 0-d object array into the stored (dict-like) object.
# allow_pickle unpickles: only load trusted files this way.
noisy = np.load(os.path.join(dir_, fn), allow_pickle=True)[()]

dir_ = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging\Group'
# Sort so that [-1] is deterministic (glob order depends on the file system)
path_gmm = sorted(glob.glob(os.path.join(dir_, 'gmm_headFixed_*.pkl')))[-1]
gmm_model = joblib.load(path_gmm)  # unpickles: trusted files only


In [None]:
ind = 36  # which noisy trial to inspect
trlDir, trl = noisy['trlDir'][ind], noisy['trlIdx_glob'][ind]
print(f'Trl # {trl} \n{trlDir}')
imgs = volt.img.readImagesInDir(trlDir)


In [None]:
# movie(imgs, fr=100).play(magnification=2)

In [None]:
# imgs_prob = fsb.prob_images_with_unet(imgs, unet, batch_size=64, verbose=1)
# Segment this trial's images with the U-net and extract tail angles. `out` is read
# below via the keys 'tailAngles', 'midlines' and 'inds_kept_midlines'.
%time out = hf.tailAnglesFromRawImagesUsingUnet(imgs, unet, batch_size=64, verbose=1, prob_thr=0.3)


In [None]:
ta = out['tailAngles']
ml = out['midlines']
inds_kept = out['inds_kept_midlines']
inds_lost = np.setdiff1d(range(ta.shape[1]), inds_kept)

imgs_mid = imgs.copy()
for ind, img in enumerate(imgs_mid):
    if ind in inds_kept:
        ml_= tuple(np.fliplr(ml[ind]).astype(int).T)
        imgs_mid[ind][ml_]=0
        
%time ta = hf.cleanTailAngles(ta, svd=gmm_model.svd, nWaves=3)[0]

In [None]:
# Tail-tip angle of the re-tracked trial; markers show frames without a midline
plt.figure(figsize=(20, 4))
t = np.arange(ta.shape[1]) / 500
tip = ta[-1]
plt.plot(t, tip)
plt.plot(t[inds_lost], tip[inds_lost], 'o')
plt.xlim(0, t[-1])
# y-limits at least +/-150, widened if the data exceed that
plt.ylim(np.minimum(-150, tip.min()), np.maximum(150, tip.max()))

In [None]:
figDir = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging\Group\Figs\Trials with GMM labels\noisy\after_codeAndTrack_fix_ml'
os.makedirs(figDir, exist_ok=True)
# The plot title previously carried a stray '.html' extension; title and file
# name are now built separately.
title = f'Trl-{trl} with GMM labels'
fig = gmm_model.plot_with_labels_interact(ta, x=t, title=title)
fig.write_html(os.path.join(figDir, f'Trl-{trl}_with GMM labels.html'))
print(f'Trl {trl} saved at \n{figDir}')

In [None]:
fr = 75  # frame rate for the QC movie
# mov = movie(imgs*(1-out['images_prob']).astype('float32'))
# Pass `fr`: it was defined but never used, so the movie kept caiman's default rate
mov = movie(imgs_mid, fr=fr)
# mov.play(magnification=3)

In [None]:
path_mov = os.path.join(figDir, f'Trl-{trl}_tail_prob_movie_trl.avi')
mov.save(path_mov)
print(trlDir)

In [None]:
saveDir = r'\\Koyama-S2\Data3\Avinash\Head-fixed tail free\GCaMP imaging\2019-12-18\f1'
nImgsToCopy = 10
tRange = (1.1, 1.25)  # time window (s) used by the commented-out selection below
# inds = np.where((t>=tRange[0]) & (t<=tRange[1]))[0]
frames_1based = [716, 717, 719, 723, 742, 743, 748, 881]  # hand-picked frames
inds = np.array(frames_1based) - 1  # -> 0-based indices into imgs

foo = fsb.copy_images_for_training(imgs[inds], savePath=saveDir, nImgsToCopy=nImgsToCopy,
                                   detect_motion_frames=False)


In [None]:
# Spot-check one random augmented image next to its mask
ind = np.random.choice(imgs_aug.shape[0], size=1)[0]
pair = (np.squeeze(imgs_aug[ind]), np.squeeze(masks_aug[ind]))
m = montage(pair, grid_shape=(1, 2), rescale_intensity=True)
plt.figure(figsize=(10, 10))
plt.imshow(m)
plt.title(f'Img # {ind}, aug= {augs[ind]}')

In [None]:
# fsb.tail_angles_from_raw_imgs_using_unet?

In [None]:
import apCode.FileTools as ft

iTrl = 3
imgDir = r'\\Koyama-S2\Data3\Avinash\Head-fixed tail free\GCaMP imaging\2020-01-22_nefma\f1\002_t\behav'
savePath = r'\\Koyama-S2\Data3\Avinash\Head-fixed tail free\GCaMP imaging\2020-01-22_nefma\f1'
# Full path of every sub-directory of imgDir; pick trial iTrl
subDirs = list(map(lambda name: os.path.join(imgDir, name), ft.subDirsInDir(imgDir)))
sd = subDirs[iTrl]

foo = fsb.copy_images_for_training(sd, nImgsToCopy=2, savePath=savePath)

In [None]:
path_imgs = r'\\Koyama-S2\Data3\Avinash\Head-fixed tail free\GCaMP imaging\2020-01-19_nefma\f1\001_h\behav\Autosave0_[00-11-1c-f1-75-10]_20200119_075226_PM'



# %time imgs = volt.img.readImagesInDir(path_imgs)

# %time imgs_fish, imgs_prob = fsb.fish_imgs_from_raw(imgs, unet)

# %time ml = hf.midlinesFromImages(imgs_fish*0)[0]
%time ml, inds_kept = fsb.track.midlines_from_binary_imgs(imgs_fish)


In [None]:
%time kappas = fsb.track.curvaturesAlongMidline(ml)
ta = np.cumsum(kappas, axis=0)

In [None]:
from apCode.SignalProcessingTools import interp

In [None]:
ta_interp = np.ones((ta.shape[0], imgs_fish.shape[0]))*np.nan
ta_interp[:, inds_kept] = ta
%time ta_interp = spt.interp.nanInterp2d(ta_interp, method='nearest')

In [None]:
%time imgs_prob = unet.predict(fsb.prepareForUnet_1ch(imgs, sz=uShape), batch_size=64)
imgs_prob = np.squeeze(imgs_prob)
%time imgs_prob = volt.img.resize(imgs_prob, imgs.shape[1:], preserve_dtype=True, preserve_range=True)

In [None]:
# Fish images from the U-net probability maps computed above
%time imgs_fish = hf.fishImgsForMidline(imgs_prob, filtSize=2.5, otsuMult=1)

In [None]:
# Same segmentation via the one-step helper, for comparison
%time imgs_fish2, imgs_prob2 = fsb.fish_imgs_from_raw(imgs, unet, batch_size=64)


In [None]:
# Animate frames 500-1999 of the first segmentation...
ani = volt.animate_images(imgs_fish[500:2000])
ani

In [None]:
# ...and the same frames of the second one
ani2 = volt.animate_images(imgs_fish2[500:2000])
ani2

In [None]:
%time midlines = midlinesFromImages(imgs_fish)[0]
%time midlines_interp = geom.interpolateCurvesND(midlines, mode='2D', N=50)
if verbose:
    print('Curve smoothening...')
midlines_interp = np.asarray(compute(*[delayed(geom.smoothen_curve)(_, smooth_fixed=smooth)\ 
                                       for _ in midlines_interp], scheduler='processes'))
midlines_interp = geom.equalizeCurveLens(midlines_interp)

In [None]:
%time dic_ta = hf.tailAngles_from_hdf_concatenated_by_trials(pathList)
ta = np.concatenate(dic_ta['tailAngles'], axis = 1)


In [None]:
# Clean the pooled tail angles; keep the cleaned array and the returned `svd`
# (note: this rebinds IPython's `_` output-history name)
%time ta_clean, _, svd = hf.cleanTailAngles(ta)

In [None]:
# plt.figure(figsize=(20, 5))
# t = np.arange(ta_clean.shape[1])*(1/500)
# plt.plot(t, ta[-1])
# plt.xlim(20, 50)
# plt.ylim(20, 25)

## *Read dataframe with all relevant information (paths, etc)*

In [None]:
#%% Pickled dataframe with paths to data and other relevant info
dir_df = r'Y:\Avinash\Projects\RS recruitment\Ablations\session_20200422-00'
# Sort so that [-1] is deterministic (glob order depends on the file system)
path_df = sorted(glob.glob(os.path.join(dir_df, 'dataFrame_rsNeurons_ablations_svdClean_2020*.pkl')))[-1]

df = pd.read_pickle(path_df)  # unpickles: only load trusted files
dir_save = os.path.join(dir_df, f'session_{util.timestamp()}')
os.makedirs(dir_save, exist_ok=True)

print(df.columns)

## *Evaluate pre-training performance*

In [None]:
ta_all = [np.array(ta_) for ta_ in df['tailAngles']]
ta_all = np.concatenate(ta_all, axis=1)
%time _, _, svd = hf.cleanTailAngles(ta_all, dt=1/500)

In [None]:
# NOTE(review): df_ctrl and df_abl are not defined anywhere in this notebook;
# presumably control/ablation subsets of `df` from a removed cell. Define them first.
df_ctrl.shape, df_abl.shape

## *Generate probability maps*

In [None]:
# NOTE(review): `unet_fsb` and `imgs_rs` are not defined in this notebook (hidden
# kernel state); create them before running this cell.
# Per-frame probability maps; the small batch presumably keeps large frames within GPU memory.
imgs_prob = np.squeeze(unet_fsb.predict(imgs_rs[..., None], batch_size=6, verbose=1))

## *Make a movie to demonstrate segmentation*

In [None]:
# Overlay the U-net probability (channel `merge_ch`) on fish-centred crops and save a movie
alpha = 0.2      # weight of the raw image within the merged channel
merge_ch = 0     # RGB channel receiving the overlay (0 = red)
fps = 50
cropSize = (256, 256)
iRange = (20, 300)  # frame range to render
save = True

from skimage.color import gray2rgb

# Per-frame positions from findFish (NaNs interpolated), used as crop centres
fp = fsb.track.findFish(-imgs_rs*imgs_prob, back_img=None)
fp_interp = spt.interp.nanInterp1d(fp)

inds = np.arange(*iRange)
centres = fp_interp[inds]
imgs_rs_crop = volt.img.cropImgsAroundPoints(imgs_rs[inds], centres, cropSize=cropSize)
imgs_prob_crop = volt.img.cropImgsAroundPoints(imgs_prob[inds], centres, cropSize=cropSize)

imgs_prob_255 = (imgs_prob_crop*255).astype(int)
imgs_rs_rgb = np.array(list(map(lambda img: gray2rgb(img, alpha=0.5), imgs_rs_crop)))
merged = alpha*imgs_rs_rgb[..., merge_ch] + (1 - alpha)*imgs_prob_255
imgs_rs_rgb[..., merge_ch] = merged.astype(int)

# NOTE(review): `dir_imgs` is not defined in this notebook (hidden kernel state)
dir_save = os.path.join(dir_imgs, 'proc')
if not os.path.exists(dir_save):
    os.mkdir(dir_save)
fname = f'Tracking movie_trl[{iTrl}]_inTrlFrames[{iRange[0]}-{iRange[1]}]_imgDims[{cropSize[0]}x{cropSize[1]}]_{util.timestamp("minute")}.avi'
savePath = os.path.join(dir_save, fname)

ani = volt.animate_images(imgs_rs_rgb, fps=fps, fig_size=(15, 15), save=save, savePath=savePath)
print(f'Movie saved at\n{dir_save}\nas\n{fname}')
ani

### *Copy these images for training if performance is not great*

In [None]:
# Shuffle the rendered frames from position 37 onward and pass them to
# copy_images_for_training with nImgsToCopy=10, saving into dir_save.
# (Removed the leftover no-op `imgs_rs[inds].shape`: it was not the last line,
# so its value was never displayed.)
inds_mov = np.arange(37, len(inds))
np.random.shuffle(inds_mov)
foo = fsb.copy_images_for_training(imgs_rs[inds][inds_mov], savePath=dir_save, nImgsToCopy=10)

In [None]:
# Leftover interactive help lookup (IPython `?` syntax); remove before sharing
fsb.copy_images_for_training?

## *From segmented fish images to tail curvature timeseries*

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

## *Plot tail angles extracted from segmented fish*

In [None]:
#%% Plot tail angles

try:
    from matplotlib.colors import TwoSlopeNorm  # matplotlib >= 3.2
except ImportError:
    # matplotlib 3.1 only has the old name, which was later removed
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm

# Colour scale centred on 0 deg and saturating at +/-100 deg. The limits go only
# into the norm: passing vmin/vmax to imshow together with `norm` is deprecated
# (matplotlib 3.3) and raises an error in later versions.
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Top: cumulative curvature at each point along the tail (rows) over frames
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Bottom: the last row, i.e. the tail tip
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail tip curvature timeseries');

# *Try Focal Loss*

In [None]:
#%% U-net for 896x896 single-channel images, compiled with the focal loss
unet_fl = model.get_unet(img_height=896, img_width=896, img_channels=1, loss=model.focal_loss)

In [None]:
#%% Checkpointer callback for storing best weights of the focal-loss U-net
# Use a distinct file prefix. Other cells glob 'best_weights_headFixed*.hdf' to
# reload the main 256x256 U-net, and must not pick up these 896x896 weights.
fp = os.path.join(dir_unet, f'best_weights_focalLoss_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(filepath=fp, monitor='val_dice_coef', verbose=1,
                               save_best_only=True, mode='max', save_weights_only=True)

keras_callbacks = [checkpointer]

In [None]:
#%% Augment before training
upSample=4
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
# aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use batch_size=6
epochs = 25
validation_split = 0.1
checkPoint = True

his = unet_fl.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
                   validation_split=validation_split, callbacks=keras_callbacks, verbose=1)


In [None]:
his = unet_fl.history.history  # metric name -> per-epoch values
print(his.keys())

# matplotlib >= 3.6 renamed the seaborn styles to 'seaborn-v0_8-*' (the old names
# were later removed); use whichever spelling this installation provides.
styles = []
for name in ('seaborn-poster', 'fivethirtyeight', 'seaborn-white'):
    renamed = name.replace('seaborn', 'seaborn-v0_8', 1)
    if name in plt.style.available:
        styles.append(name)
    elif renamed in plt.style.available:
        styles.append(renamed)

# Scope the styles to this figure instead of changing the style for the rest of the notebook
with plt.style.context(styles):
    plt.figure(figsize=(15, 6))
    plt.subplot(121)
    plt.plot(his['val_dice_coef'], '.', label='validation set')
    plt.plot(his['dice_coef'], label='training set')
    plt.legend(fontsize=12)
    plt.title('Dice coefficient', fontsize=14)

    plt.subplot(122)
    plt.plot(his['val_loss'], '.', label='validation set')
    plt.plot(his['loss'], label='training set')
    plt.legend(fontsize=12)
    # Raw string: '\g' is an invalid escape in a normal string literal
    plt.title(r'Focal loss ($\gamma = 2$, unbalanced)', fontsize=14);

In [None]:
# Predict in small batches, as in training: Keras' default batch size (32) is
# likely to exhaust GPU memory with 896x896 inputs
imgs_prob = np.squeeze(unet_fl.predict(imgs_rs[..., None], batch_size=6))

In [None]:
# Imported here too, so this cell does not depend on the earlier movie cell
from skimage.color import gray2rgb

# Overlay the focal-loss U-net probability on the raw frames (channel merge_ch)
alpha = 0.2      # weight of the raw image within the merged channel
merge_ch = 0     # RGB channel receiving the overlay (0 = red)
fps = 50
inds = np.arange(450, 3000)  # frames to animate
imgs_prob_255 = (imgs_prob*255).astype(int)
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs])

imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int)
ani = volt.animate_images(imgs_rs_rgb[inds], fps=fps, fig_size=(15, 15))
ani

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet_fl)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

In [None]:
#%% Plot tail angles

try:
    from matplotlib.colors import TwoSlopeNorm  # matplotlib >= 3.2
except ImportError:
    # matplotlib 3.1 only has the old name, which was later removed
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm

# Colour scale centred on 0 deg and saturating at +/-100 deg. The limits go only
# into the norm: passing vmin/vmax to imshow together with `norm` is deprecated
# (matplotlib 3.3) and raises an error in later versions.
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Top: cumulative curvature at each point along the tail (rows) over frames
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Bottom: the last row, i.e. the tail tip
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail tip curvature timeseries');

## *Free swim behavior*