In [None]:
#%% Import relevant code
import os, sys, time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re as regex
import dask
import h5py
import pandas as pd
import keras
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras import models
from skimage.util import montage
import glob

#--- Import my code
codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.SignalProcessingTools as spt
from apCode.machineLearning.unet import model
from apCode.behavior import FreeSwimBehavior as fsb
from apCode import geom
import apCode.hdf as hdf
from apCode import util
from rsNeuronsProj import util as rsp
import apCode.behavior.headFixed as hf

#--- Set the seed for reproducibility
seed = 143
# np.random.seed is a function and has to be called; assigning to it would
# seed nothing and would replace the function with an int.
np.random.seed(seed)

# Font type 42 (TrueType) keeps text editable in exported PDF/PS figures
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


#--- Auto-reload modules (only when running inside IPython/Jupyter)
try:
    if __IPYTHON__:
        ipython = get_ipython()
        # run_line_magic is the supported replacement for the deprecated
        # InteractiveShell.magic
        ipython.run_line_magic('load_ext', 'autoreload')
        ipython.run_line_magic('autoreload', '2')
except NameError:
    # __IPYTHON__ is undefined in a plain Python interpreter: nothing to do
    pass

print(time.ctime())


## *Load two U-nets: one trained only on free-swim, diffuse-light data, and one trained on all data types, including head-fixed and collimated-light conditions*

In [None]:
# Folder holding the trained U-net files
dir_unet = r'Y:\Avinash\Ablations and Behavior'
learning_rate = 'default'

# glob.glob returns matches in arbitrary order, so sort before taking the last
# one. NOTE(review): this assumes the timestamp in the file names sorts
# chronologically, so that the last match is the most recent model -- confirm.
path_unet_fsb = sorted(glob.glob(os.path.join(dir_unet, 'trainedU_fsb_896*.h5')))[-1]
path_unet_all = sorted(glob.glob(os.path.join(dir_unet, 'trainedU_fsb_collimated_headFixed_*.h5')))[-1]

print(path_unet_fsb + '\n' + path_unet_all)
unet_fsb = mlearn.loadPreTrainedUnet(path_unet_fsb)  # Free-swim-only model
unet_all = mlearn.loadPreTrainedUnet(path_unet_all)  # Model trained on all conditions

## *Read dataframe with all relevant information (paths, etc)*

In [None]:
#%% Path to the pickled dataframe storing paths to data and other relevant info
dir_df = r'Y:\Avinash\Projects\RS recruitment\Ablations\session_20200422-00'
# glob.glob returns matches in arbitrary order; sort so that [-1] picks the
# lexicographically last file name instead of an arbitrary one.
path_df = sorted(glob.glob(os.path.join(dir_df, 'dataFrame_rsNeurons_ablations_svdClean_2020*.pkl')))[-1]

# NOTE: read_pickle can execute arbitrary code; only load trusted files
df = pd.read_pickle(path_df)

# Outputs of this session go into a fresh, time-stamped sub-folder
dir_save = os.path.join(dir_df, f'session_{util.timestamp()}')
os.makedirs(dir_save, exist_ok=True)

print(df.columns)

## *Evaluate pre-training performance*

In [None]:
ta_all = [np.array(ta_) for ta_ in df['tailAngles']]
ta_all = np.concatenate(ta_all, axis=1)
%time _, _, svd = hf.cleanTailAngles(ta_all, dt=1/500)

In [None]:
# Quick look at the sizes of the control and ablated subsets.
# NOTE(review): df_ctrl and df_abl are first assigned in the following cell, so
# on a fresh kernel this cell raises NameError unless that cell is run first.
df_ctrl.shape, df_abl.shape

In [None]:
iFish = (2, 2)   # (ctrl, abl): row position in each group used to look up a FishIdx
iTrl = (2, 4)    # (ctrl, abl): trial to plot for each fish
trlLen=750       # samples per trial
ablGrp = 'mHom'  # ['mHom', 'intermediateRS', 'ventralRS']


def _split_into_trials(ta, trlLen):
    """Split a 2-D array along axis 1 into consecutive trials.

    Parameters
    ----------
    ta : array, (nPoints, nSamples)
    trlLen : int
        Number of samples in each trial.

    Returns
    -------
    array, (nTrials, nPoints, trlLen)

    Raises
    ------
    ValueError
        If nSamples is zero or not a whole multiple of trlLen.
    """
    nTrls, nLeftover = divmod(ta.shape[1], trlLen)
    if nTrls == 0 or nLeftover:
        raise ValueError(f'{ta.shape[1]} samples cannot be split into trials of {trlLen}')
    return np.array(np.hsplit(ta, nTrls))


plt.style.use(('seaborn-white', 'seaborn-ticks', 'fivethirtyeight', 'seaborn-talk'))

# Rows of the chosen ablation group, split by treatment
df_ctrl = df.loc[(df.AblationGroup==ablGrp) & (df.Treatment=='ctrl')]
df_abl = df.loc[(df.AblationGroup==ablGrp) & (df.Treatment=='abl')]

# All rows sharing the FishIdx found at the requested row position
ctrl = df_ctrl.loc[df_ctrl.FishIdx == df_ctrl.FishIdx.iloc[iFish[0]]]
abl = df_abl.loc[df_abl.FishIdx == df_abl.FishIdx.iloc[iFish[1]]]

ta_ctrl = np.array(ctrl['tailAngles'].iloc[0])
ta_abl = np.array(abl['tailAngles'].iloc[0])

# Only used by the commented-out sorting alternative below
r_ctrl = fsb.track.assessTracking(ta_ctrl)
r_abl = fsb.track.assessTracking(ta_abl)

# Use trlLen (not a repeated literal) and an integer number of sections
ta_ctrl = _split_into_trials(ta_ctrl, trlLen)
ta_abl = _split_into_trials(ta_abl, trlLen)

# inds_sort_ctrl = np.argsort(r_ctrl)
# inds_sort_abl = np.argsort(r_abl)
inds_sort_ctrl = np.arange(ta_ctrl.shape[0])
inds_sort_abl = np.arange(ta_abl.shape[0])

# glob.glob returns matches in arbitrary order; sort before taking the last
path_ctrl = sorted(glob.glob(ctrl.Path.iloc[0] +  '/procData*.h5'))[-1]
path_abl = sorted(glob.glob(abl.Path.iloc[0] +  '/procData*.h5'))[-1]

plt.figure(figsize=(20, 5))

# Last row of the selected trial's tail angles, for each condition
trl_ctrl = inds_sort_ctrl[iTrl[0]]
y_ctrl = ta_ctrl[trl_ctrl][-1]
# Time in ms: sample 50 maps to t = 0 and each sample spans 1/500 s
t = (np.arange(len(y_ctrl))-50)*(1/500)*1000
plt.plot(t, y_ctrl, label='Ctrl')

trl_abl = inds_sort_abl[iTrl[1]]
y_abl = ta_abl[trl_abl][-1]
plt.plot(t, y_abl, label='Abl')
plt.legend(loc='best', fontsize=16)
plt.xlim(-30, 500)
print(trl_abl)

In [None]:
dir_now = path_abl
trl_abl = trl_abl

dir_imgs_abl = rsp.remove_suffix_from_paths(os.path.split(dir_now)[0])[()]
trlInds = np.arange(trl_abl*trlLen, (trl_abl+1)*trlLen)
%time imgs_abl_raw = volt.dask_array_from_image_sequence(dir_imgs_abl)[trlInds].compute()

print('Prob images...')
%time imgs_fish, imgs_prob = fsb.fish_imgs_from_raw(imgs_abl_raw, unet_fsb, batch_size=6)
# img_back = fsb.track.computeBackground(dir_imgs_abl)
# imgs_fish, imgs_prob = fsb.fish_imgs_from_raw(imgs_abl_raw-img_back, unet_fsb, batch_size=6)


print('Tail angles...')
%time ml = fsb.track.midlines_from_binary_imgs(imgs_fish)[0]
%time ta = fsb.track.curvaturesAlongMidline(ml)
ta = np.cumsum(ta, axis=0)
ta = hf.cleanTailAngles(ta, svd=svd, nWaves=2, dt=1/500)[0]
y = ta[-1]


plt.figure(figsize=(20, 5))
t = (np.arange(len(y))-50)*(1/500)*1000
plt.plot(t, y)
plt.xlim(-30, 500)

print('Cropping...')
out = fsb.track.find_and_crop_imgs_around_fish(-imgs_abl_raw*imgs_fish, cropSize=(150, 150))
fishPos, imgs_crop = out['fishPos'], out['imgs_crop']


In [None]:
# %time ml = fsb.track.midlines_from_binary_imgs(imgs_fish)[0]

In [None]:
# Time window (ms) of the movie, whether to write it to disk, and a suffix for
# the movie's file name
xl = (-20, 500)
save = False
sfx = 'largeSlowCounterBend'

# Time in ms: sample 50 maps to t = 0 and each sample spans 1/500 s
t = (np.arange(len(y))-50)*(1/500)*1000
# Indices of the samples that fall inside the xl window
inds = np.where((t>=xl[0]) & (t<=xl[1]))[0]

savePath = os.path.join(dir_save, f'Movie-{util.timestamp("second")}_{sfx}.avi')
# NOTE(review): presumably shows the cropped frames together with the trace y
# and writes savePath only when save is True -- confirm in apCode
%time ani = hf.see_behavior_with_labels(imgs_crop[inds], y[inds], savePath=savePath, save=save)
# %time ani = hf.see_behavior_with_labels(imgs_abl_raw[inds], y_abl[inds], savePath=savePath, save=save)
ani

### *Additional training, if need be*

In [None]:
# Time window (ms) from which frames are drawn, and how many frames to draw
tRange = (-100, 300)
nImgsToCopy=10

# NOTE(review): this time axis is offset by 40 samples, whereas the other time
# axes in this notebook use 50 -- confirm which one is intended
t = (np.arange(imgs_abl_raw.shape[0])-40)*(1/500)*1000
iRange = np.where((t>=tRange[0]) & (t<=tRange[-1]))[0]
# Keep only the first and last frame index of the window
iRange = (iRange[0], iRange[-1])
# inds = np.random.randint(*iRange, nImgsToCopy)
# Draw nImgsToCopy distinct frame indices; np.arange excludes iRange[1] itself
inds = np.random.choice(np.arange(*iRange), size=nImgsToCopy, replace=False)
# The folder containing the file at dir_now is passed on as savePath
dir_imgs_train = os.path.split(dir_now)[0]
os.makedirs(dir_imgs_train, exist_ok=True)
%time foo = fsb.copy_images_for_training(imgs_abl_raw[inds], savePath=dir_imgs_train, detect_motion_frames=False);



In [None]:
%%time
dir_xls_train = r'Y:\Avinash\Ablations and Behavior'
file_xls_train = 'Paths_to_fish_training_images.xlsx'
sheet_name = 'Uncropped'
xls_train = pd.read_excel(os.path.join(dir_xls_train, file_xls_train), sheet_name=sheet_name)
xls_train = xls_train.loc[xls_train.exptType=='fsb']

imgDims = unet_fsb.input_shape[1:3]
imgs_train, masks_train = mlearn.read_training_images_and_masks(np.array(xls_train.pathToImages), 
                                                    np.array(xls_train.pathToMasks), imgDims=imgDims)
masks_train = (masks_train>0).astype(int)
print(f'Training on {imgs_train.shape[0]} of dimensions {imgs_train.shape[1:]}')

metrics = unet_fsb.evaluate(imgs_train[..., None], masks_train[..., None], batch_size=6, verbose=1)
print(np.c_[unet_fsb.metrics_names, metrics])

### *Checkpointer*

In [None]:
#%% Checkpointer callback for storing best weights
# Time-stamped weights file inside the U-net folder
fp = os.path.join(dir_unet, f'best_weights_fsb_{util.timestamp()}.hdf')

# Save weights only, and only when val_dice_coef reaches a new maximum
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    verbose=1,
    save_best_only=True,
    mode='max',
    save_weights_only=True,
)

keras_callbacks = [checkpointer]

### _Augmentation_

In [None]:
#%% Augment before training
# upSample and aug_set are forwarded to mlearn.augmentImageData below
upSample=5
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

# NOTE(review): prepare_imgs_for_unet presumably brings the arrays into the
# shape the model expects -- confirm in apCode.machineLearning.ml
imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet_fsb)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet_fsb)
# Binarize the masks: any positive value becomes 1
masks_aug = (masks_aug>0).astype(int)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

# Score the current model on the augmented set
metrics = unet_fsb.evaluate(imgs_aug, masks_aug, batch_size=6, verbose=1)
print(np.c_[unet_fsb.metrics_names, metrics])


### *Run cell below if a new model is to be instantiated*

In [None]:
# Either the string 'default' or a numeric learning rate for Adam
learning_rate = 'default'

# Compare by value: `is` tests object identity, which is not guaranteed for
# equal strings and triggers a SyntaxWarning on Python >= 3.8.
if learning_rate == 'default':
    print('Default learning rate')
    # The Adam class is used instead of the lowercase alias, which is not
    # available in every Keras version. (RMSprop is the alternative that was
    # tried here.)
    optimizer = keras.optimizers.Adam()
else:
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

# Fresh 896 x 896, single-channel U-net compiled with focal loss
unet_fsb = model.get_unet(img_width=896, img_height=896, img_channels=1, optimizer=optimizer,
                          loss=model.focal_loss)
# Start from previously checkpointed weights rather than from scratch
file_weights = os.path.join(dir_unet, 'best_weights_fsb_20200404-17.hdf')

unet_fsb.load_weights(file_weights)

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use this size
epochs = 150
validation_split = 0.1
checkPoint = True

his = unet_fsb.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
                   validation_split=validation_split, callbacks=keras_callbacks, verbose=1)




### *Load the best weights from the checkpointed file and save the U net*

In [None]:
#%% Load the best weights and save
# Take the last checkpoint file returned by the search.
# NOTE(review): search_str hard-codes the year 2020, so checkpoints stamped with
# a different year are presumably not matched -- confirm
fn = ft.findAndSortFilesInDir(dir_unet, search_str='best_weights_fsb_2020', ext='hdf')[-1]
print(fn)
unet_fsb.load_weights(os.path.join(dir_unet, fn))

#%% Save the U-net
# dir_unet = r'Y:\Avinash\Ablations and Behavior'
# The file name records input_shape[1] x input_shape[2] and a timestamp
fn = f'trainedU_fsb_{unet_fsb.input_shape[1]}x{unet_fsb.input_shape[2]}_{util.timestamp("minute")}.h5'
unet_fsb.save(os.path.join(dir_unet, fn))
print(time.ctime())

In [None]:
# Per-epoch metrics stored on the model's history attribute
his = unet_fsb.history.history
print(his.keys())
plt.style.use(('seaborn-white', 'seaborn-ticks', 'seaborn-talk', 'fivethirtyeight'))
plt.figure(figsize=(15, 6))

# Left panel: Dice coefficient (validation as dots, training as a line)
ax_dice = plt.subplot(121)
ax_dice.plot(his['val_dice_coef'], '.', label='validation set')
ax_dice.plot(his['dice_coef'], label='training set')
ax_dice.legend(fontsize=12)
ax_dice.set_title('Dice coefficient', fontsize=14)

# Right panel: loss (validation as dots, training as a line)
ax_loss = plt.subplot(122)
ax_loss.plot(his['val_loss'], '.', label='validation set')
ax_loss.plot(his['loss'], label='training set')
ax_loss.legend(fontsize=12)
ax_loss.set_title('Binary cross-entropy loss', fontsize=14)

### *Evaluate metrics again*

In [None]:
# Score the (re)trained model on the augmented set and print each metric name
# next to its value
metrics = unet_fsb.evaluate(imgs_aug, masks_aug, batch_size=6, verbose=1)
print(np.c_[unet_fsb.metrics_names, metrics])

## *Randomly select and plot images and masks for checking* 

In [None]:
# Pick one training example at random (first entry of a random permutation)
inds = np.random.permutation(np.arange(imgs_train.shape[0]))
ind=inds[0]
# Show the image and its mask side by side (1 x 2 grid)
m = montage((imgs_train[ind], masks_train[ind]), rescale_intensity=True, grid_shape=(1,2))
plt.figure(figsize=(20, 20))
plt.imshow(m, cmap='viridis')


## *Read a contiguous set of unseen images for predicting with U-net*

In [None]:
# Fish, trial and trial length (in frames) to read.
# NOTE: iFish and iTrl are plain ints here, unlike the (ctrl, abl) tuples of the
# same names assigned in an earlier cell.
iFish = 2
iTrl = 7
trlLen = 750

fishInds = np.unique(df.FishIdx)
# Path stored in the first dataframe row of the selected fish
dir_imgs= rsp.remove_suffix_from_paths(df.loc[df.FishIdx==fishInds[iFish]].iloc[0].Path)[()]

# Frame indices of trial iTrl
trlInds = np.arange(iTrl*trlLen, (iTrl+1)*trlLen)
# NOTE(review): indexing with the trlInds array assumes findAndSortFilesInDir
# returns a numpy array rather than a list -- confirm
%time imgPaths = [os.path.join(dir_imgs, _) for _ in ft.findAndSortFilesInDir(dir_imgs, ext='bmp')[trlInds]]


imgs = volt.img.readImagesInDir(imgPaths=imgPaths)
# Resize to dimensions 1 and 2 of the U-net's input shape
imgs_rs = volt.img.resize(imgs, unet_fsb.input_shape[1:3])

## *Generate probability maps*

In [None]:
# Run the U-net on the resized frames (a trailing axis is added for the model)
# and squeeze singleton axes out of the prediction
imgs_prob = np.squeeze(unet_fsb.predict(imgs_rs[..., None], batch_size=6, verbose=1))

## *Make a movie to demonstrate segmentation*

In [None]:
alpha = 0.2            # weight of the raw image in the merged channel
merge_ch = 0           # colour channel into which the probability map is blended
fps = 50
cropSize = (256, 256)
iRange = (20, 300)     # in-trial frames shown in the movie (end excluded)
save=True

from skimage.color import gray2rgb

# NOTE: fp and dir_save are also assigned in earlier cells with other meanings
# (checkpoint file path, session folder); both are overwritten here.
fp = fsb.track.findFish(-imgs_rs*imgs_prob, back_img=None)
# NOTE(review): presumably fills NaN positions by interpolation -- confirm
fp_interp = spt.interp.nanInterp1d(fp)

inds = np.arange(*iRange)
imgs_rs_crop = volt.img.cropImgsAroundPoints(imgs_rs[inds], fp_interp[inds], cropSize=cropSize)
imgs_prob_crop = volt.img.cropImgsAroundPoints(imgs_prob[inds], fp_interp[inds], cropSize=cropSize)


imgs_prob_255 = (imgs_prob_crop*255).astype(int)
imgs_rs_rgb = np.array([gray2rgb(img, alpha=0.5) for img in imgs_rs_crop])

# Weighted blend of the raw image and the scaled probability map in one channel
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int)

dir_save = os.path.join(dir_imgs, 'proc')
# makedirs(exist_ok=True) needs no separate existence check and matches how
# folders are created elsewhere in this notebook
os.makedirs(dir_save, exist_ok=True)
fname = f'Tracking movie_trl[{iTrl}]_inTrlFrames[{iRange[0]}-{iRange[1]}]_imgDims[{cropSize[0]}x{cropSize[1]}]_{util.timestamp("minute")}.avi'
savePath = os.path.join(dir_save, fname)

ani =volt.animate_images(imgs_rs_rgb, fps=fps, fig_size=(15, 15), save=save, savePath=savePath)
# Only announce a saved movie when saving was actually requested
if save:
    print(f'Movie saved at\n{dir_save}\nas\n{fname}')
ani

### *Copy these images for training if performance is not great*

In [None]:
# Not the last line of the cell, so this shape is evaluated but not displayed
imgs_rs[inds].shape
# Positions 37 onwards within inds, shuffled in place
inds_mov = np.arange(37, len(inds))
np.random.shuffle(inds_mov)
# inds_mov = inds_mov[:10]
# savePath = os.path.join(dir_save, 'images_train_896x')
# dir_save is the 'proc' folder set in the movie cell
foo = fsb.copy_images_for_training(imgs_rs[inds][inds_mov], savePath=dir_save, nImgsToCopy=10)

In [None]:
# Interactive IPython help lookup (shows the docstring); has no effect on the
# analysis and is not valid syntax outside IPython
fsb.copy_images_for_training?

## *From segmented fish images to tail curvature timeseries*

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

## *Plot tail angles extracted from segmented fish*

In [None]:
#%% Plot tail angles

# DivergingNorm was renamed TwoSlopeNorm in Matplotlib 3.2 and the old name was
# later removed; import whichever this installation provides.
try:
    from matplotlib.colors import TwoSlopeNorm as DivergingNorm
except ImportError:
    from matplotlib.colors import DivergingNorm
norm = DivergingNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2,1, figsize=(20,10), sharex=True)

# The colour limits live in `norm`; recent Matplotlib rejects vmin/vmax when a
# norm is passed as well, so they are not repeated here.
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
# Tick positions match the n=50 points requested from curvaturesAlongMidline
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Last row of ta, restricted to the frames in inds
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

# *Try Focal Loss*

In [None]:
#%% Instantiate U-net with focal loss specified during compilation
# 896 x 896, single-channel input; no optimizer is passed, so get_unet's own
# default applies
unet_fl = model.get_unet(img_width=896, img_height=896, img_channels=1, loss=model.focal_loss)

In [None]:
#%% Checkpointer callback for storing best weights
# Time-stamped weights file inside the U-net folder
fp = os.path.join(dir_unet, f'best_weights_headFixed_{util.timestamp()}.hdf')

# Save weights only, and only when val_dice_coef reaches a new maximum
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    verbose=1,
    save_best_only=True,
    mode='max',
    save_weights_only=True,
)

keras_callbacks = [checkpointer]

In [None]:
#%% Augment before training
upSample=4
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
# aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use batch_size=6
epochs = 25
validation_split = 0.1
checkPoint = True

his = unet_fl.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
                   validation_split=validation_split, callbacks=keras_callbacks, verbose=1)


In [None]:
# Per-epoch metrics stored on the model's history attribute
his = unet_fl.history.history
print(his.keys())
# Apply the style before creating the figure so it affects the figure as well
# as the axes (same order as the other training-curve cell)
plt.style.use(('seaborn-poster','fivethirtyeight', 'seaborn-white'))
plt.figure(figsize=(15, 6))
plt.subplot(121)
plt.plot(his['val_dice_coef'],'.', label='validation set')
plt.plot(his['dice_coef'], label='training set')
plt.legend(fontsize=12)
plt.title('Dice coefficient', fontsize=14)

plt.subplot(122)
plt.plot(his['val_loss'],'.', label ='validation set')
plt.plot(his['loss'], label = 'training set')
plt.legend(fontsize=12)
# Raw string: '\g' is an invalid escape sequence in a normal string literal
plt.title(r'Focal loss ($\gamma = 2, unbalanced$)', fontsize=14);

In [None]:
# Predict in batches of 6, like the other predict/evaluate calls in this
# notebook; without batch_size Keras uses its default of 32
imgs_prob = np.squeeze(unet_fl.predict(imgs_rs[..., None], batch_size=6, verbose=1))

In [None]:
# Blend weight of the raw image, target colour channel and movie frame rate
alpha = 0.2
merge_ch = 0
fps = 50
# NOTE(review): frames 450-2999 are requested, but the image-reading cell of
# this notebook loads a single trial of trlLen = 750 frames -- confirm that
# imgs_rs holds enough frames when this cell is run
inds = np.arange(450, 3000)
imgs_prob_255 = (imgs_prob*255).astype(int)
# gray2rgb is imported (from skimage.color) in the earlier movie cell
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs])

# Weighted blend of the raw image and the scaled probability map in one channel
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int) 
ani =volt.animate_images(imgs_rs_rgb[inds], fps=fps, fig_size=(15, 15))
ani

In [None]:
# Same steps as the earlier tail-curvature cell, using the focal-loss model
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet_fl)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
# Curvatures along each midline (n=50), cumulatively summed along axis 0
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
# NOTE(review): called without svd/dt here, unlike the earlier cleanTailAngles
# calls in this notebook -- confirm the defaults are appropriate
ta = hf.cleanTailAngles(tailAngles)[0]

In [None]:
#%% Plot tail angles

# DivergingNorm was renamed TwoSlopeNorm in Matplotlib 3.2 and the old name was
# later removed; import whichever this installation provides.
try:
    from matplotlib.colors import TwoSlopeNorm as DivergingNorm
except ImportError:
    from matplotlib.colors import DivergingNorm
norm = DivergingNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2,1, figsize=(20,10), sharex=True)

# The colour limits live in `norm`; recent Matplotlib rejects vmin/vmax when a
# norm is passed as well, so they are not repeated here.
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
# Tick positions match the n=50 points requested from curvaturesAlongMidline
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Last row of ta, restricted to the frames in inds
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

## *Free swim behavior*