In [None]:
#%% Import relevant code
import os, sys, time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re as regex
import dask
import h5py
import pandas as pd
import keras
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras import models
from skimage.util import montage
import glob

#--- Import my code
codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.SignalProcessingTools as spt
from apCode.machineLearning.unet import model
from apCode.behavior import FreeSwimBehavior as fsb
from apCode import geom
import apCode.hdf as hdf
from apCode import util
from rsNeuronsProj import util as rsp

#--- Setting seed for reproducibility
# np.random.seed is a function and must be called. Assigning to it (`np.random.seed = seed`)
# silently replaced the function with an int and left the global RNG unseeded.
# Only NumPy's global RNG is seeded here; Keras/TensorFlow weight init and shuffling are not.
seed = 143
np.random.seed(seed)

# Embed TrueType (Type 42) fonts so text in saved PDF/PS figures remains editable
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


#--- Auto-reload modules
# Only possible under IPython/Jupyter. In a plain interpreter `__IPYTHON__` is undefined,
# and the resulting NameError is deliberately ignored.
try:
    if __IPYTHON__:
        # run_line_magic replaces the deprecated InteractiveShell.magic()
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

print(time.ctime())


## *Build and compile a U-net model to train on data, or load a pre-trained network*

In [None]:
dir_unet = r'Y:\Avinash\Ablations and Behavior'
# 'default' keeps RMSprop's built-in learning rate; set a float here to override it
learning_rate = 'default'

# Compare with == rather than `is`: identity of equal strings is an interpreter detail,
# and `is` with a literal raises a SyntaxWarning on Python >= 3.8.
if learning_rate == 'default':
    print('Default learning rate')
    rmsprop = keras.optimizers.RMSprop()
else:
    rmsprop = keras.optimizers.RMSprop(learning_rate=learning_rate)

# New (untrained) U-net for 896 x 896 single-channel images, built with the RMSprop
# optimizer above and loss=model.focal_loss
unet = model.get_unet(img_width=896, img_height=896, img_channels=1,
                      optimizer=rmsprop, loss=model.focal_loss)


# Alternatives (not run): resume from the last saved best-weights checkpoint ...
# file_weights = ft.findAndSortFilesInDir(dir_unet, ext='hdf', search_str='best_weights_fsb')[-1]
# unet.load_weights(os.path.join(dir_unet, file_weights))

# ... or load a complete pre-trained network instead of building a new one
# path_unet = glob.glob(os.path.join(dir_unet, 'trainedU*.h5'))[-1]
# print(path_unet)
# unet = mlearn.loadPreTrainedUnet(path_unet) # Load pre-trained 

## *Read images and masks for training*

In [None]:
%%time
dir_xls_train = r'Y:\Avinash\Ablations and Behavior'
file_xls_train = 'Paths_to_fish_training_images.xlsx'
sheet_name = 'Uncropped'
xls_train = pd.read_excel(os.path.join(dir_xls_train, file_xls_train), sheet_name=sheet_name)
xls_train = xls_train.loc[xls_train.exptType=='fsb']

imgDims = unet.input_shape[1:3]
imgs_train, masks_train = mlearn.read_training_images_and_masks(np.array(xls_train.pathToImages), 
                                                    np.array(xls_train.pathToMasks), imgDims=imgDims)
masks_train = (masks_train>0).astype(int)
print(f'Training on {imgs_train.shape[0]} of dimensions {imgs_train.shape[1:]}')

## *Evaluate pre-training performance*

In [None]:
# Baseline: score the network as built above on the un-augmented training images.
# A trailing singleton axis is appended to match the single-channel input (img_channels=1).
x_eval = imgs_train[..., None]
y_eval = masks_train[..., None]
metrics = unet.evaluate(x_eval, y_eval, batch_size=6, verbose=1)
# One row per metric: [name, value]
print(np.c_[unet.metrics_names, metrics])


## *Randomly select and plot images and masks for checking* 

In [None]:
# Spot-check one random training example: image on the left, its mask on the right
n_train = imgs_train.shape[0]
inds = np.random.permutation(np.arange(n_train))
ind = inds[0]
pair = (imgs_train[ind], masks_train[ind])
m = montage(pair, rescale_intensity=True, grid_shape=(1, 2))
plt.figure(figsize=(20, 20))
plt.imshow(m, cmap='viridis')


## *Define TensorBoard and ModelCheckpoint callbacks for later evaluation and for storing the best weights during training*

In [None]:
#%% TensorBoard callback for later evaluation (disabled; uncomment to enable)
# dir_log = os.path.join(dir_unet, f'unet_training_logs_headFixed_{util.timestamp()}')
# tensorboard = TensorBoard(log_dir=dir_log, histogram_freq=1, write_images=True, batch_size=6)
# keras_callbacks = [tensorboard]


#%% Checkpoint callback: write only the weights from the epoch with the highest validation Dice
fp = os.path.join(dir_unet, f'best_weights_fsb_collimated_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',               # higher Dice is better
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# To also log to TensorBoard, enable the block above and use keras_callbacks.append(checkpointer)
keras_callbacks = [checkpointer]

## *Augment images before training*

In [None]:
#%% Augment before training
upSample=5
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
masks_aug= (masks_aug>0).astype(int)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

## *Evaluate pre-training performance on augmented images*

In [None]:
# Same baseline metrics, now on the augmented (already formatted) training set
metrics = unet.evaluate(
    imgs_aug,
    masks_aug,
    batch_size=6,
    verbose=1,
)
metric_table = np.c_[unet.metrics_names, metrics]  # rows of [name, value]
print(metric_table)


In [None]:
# Spot-check one random augmented sample: image (left) and mask (right), titled with augs[ind]
inds = np.random.permutation(np.arange(imgs_aug.shape[0]))
ind = inds[0]
img_chk = imgs_aug[ind][..., 0]   # drop the channel axis for display
mask_chk = masks_aug[ind][..., 0]
m = montage((img_chk, mask_chk), rescale_intensity=True, grid_shape=(1, 2))
plt.figure(figsize=(20, 20))
plt.imshow(m, cmap='viridis')
plt.title(augs[ind], fontsize=30)


## *Train with callbacks*

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use this size
epochs = 150
validation_split = 0.1
checkPoint = True

his = unet.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
               validation_split=validation_split, callbacks=keras_callbacks, verbose=1)

# his = unet.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
#                validation_split=validation_split, verbose=1)


## *Load the best weights and save the model as an HDF5 file*

In [None]:
#%% Restore the best checkpointed weights (last matching file after sorting by name)
weight_files = ft.findAndSortFilesInDir(dir_unet, search_str='best_weights_fsb_collimated_headFixed', ext='hdf')
fn = weight_files[-1]
unet.load_weights(os.path.join(dir_unet, fn))

#%% Save the whole U-net; the filename records input height x width and a timestamp
h_in, w_in = unet.input_shape[1:3]
fn = f'trainedU_fsb_collimated_headFixed_{h_in}x{w_in}_{util.timestamp()}.h5'
unet.save(os.path.join(dir_unet, fn))
print(time.ctime())

## *Re-evaluate metrics after training (on the augmented training set, so not a held-out estimate)*

In [None]:
# Metrics with the restored best weights. This is the augmented set, which includes the
# training data, so it is not a held-out estimate.
metrics = unet.evaluate(imgs_aug, masks_aug,
                        batch_size=6, verbose=1)
metric_table = np.c_[unet.metrics_names, metrics]
print(metric_table)


## *Plot training metrics* 

In [None]:
# Per-epoch training/validation curves recorded by the last unet.fit call
his = unet.history.history
print(his.keys())
try:
    plt.style.use(('seaborn-poster', 'seaborn-white'))
except OSError:  # matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    plt.style.use(('seaborn-v0_8-poster', 'seaborn-v0_8-white'))

fig, (ax_dice, ax_loss) = plt.subplots(1, 2, figsize=(15, 6))
ax_dice.plot(his['val_dice_coef'], '.', label='validation set')
ax_dice.plot(his['dice_coef'], label='training set')
ax_dice.legend(fontsize=12)
ax_dice.set_xlabel('Epoch')
ax_dice.set_title('Dice coefficient', fontsize=14)

# `unet` was compiled with loss=model.focal_loss, so this is focal loss (not binary cross-entropy)
ax_loss.plot(his['val_loss'], '.', label='validation set')
ax_loss.plot(his['loss'], label='training set')
ax_loss.legend(fontsize=12)
ax_loss.set_xlabel('Epoch')
ax_loss.set_title('Focal loss', fontsize=14);

## *Look at outputs/activations of different layers*

In [None]:
#%% Look at activations of layers
ind_layer = 35  # index into unet.layers of the layer to inspect
ind_img = np.random.permutation(len(imgs_aug))[0]

img = imgs_aug[ind_img][None, ...]  # add a batch axis
layer = unet.layers[ind_layer]
# Sub-model mapping the U-net input to this single layer's output
activation_model = models.Model(inputs=unet.input, outputs=layer.output)
# Drop only the batch axis, then move channels first -> (channels, height, width).
# A bare np.squeeze would also drop a size-1 channel axis and break the 3-axis transpose.
# Assumes channels-last layer outputs, as the original transpose (2, 0, 1) did.
activations = np.moveaxis(activation_model.predict(img)[0], -1, 0)

# Put the input image (resized to the activation size) in the first tile
img_rs = volt.img.resize(np.squeeze(img), activations.shape[-2:])
panels = np.concatenate((img_rs[None, ...], activations), axis=0)

nCols = 8
nRows = (len(activations)//nCols) + 1  # == ceil(len(panels) / nCols)
m = montage(panels, rescale_intensity=True, grid_shape=(nRows, nCols))
# Figure height/width must equal nRows/nCols so square tiles stay square (ratio was inverted)
plt.figure(figsize=(20, 20*nRows/nCols))
plt.imshow(m, cmap='viridis')
plt.title(f'Example image (# {ind_img} top left) and outputs at layer {ind_layer}({layer.name})',\
          fontsize=20);

In [None]:
# filters, biases =[], []
# for layer in unet.layers:
#     if 'conv' in layer.name:
#         f, b = layer.get_weights()
#         filters.append(f)
# #         print(f.shape)
# f = np.transpose(np.squeeze(filters[0]), (2, 0, 1))
# m = montage(f, rescale_intensity=True)
# plt.figure(figsize=(20, 20))
# plt.imshow(m)

## *Read a contiguous set of unseen images for predicting with U-net*

In [None]:
# dir_imgs = r'Y:\Avinash\Head-fixed tail free\GCaMP imaging\2019-12-31\f2\001_h\behav\Autosave0_[00-11-1c-f1-75-10]_20191231_032423_AM'

dir_imgs = r'N:\Avinash\Ablations and Behavior\Ventral RS\20160523\Fish1_ctrl1\fastDir_08-18-16-182847\vib'
%time imgNames = ft.findAndSortFilesInDir(dir_imgs, ext='bmp')[:750]

imgs = volt.img.readImagesInDir(dir_imgs, imgNames=imgNames)
imgs_rs = volt.img.resize(imgs, unet.input_shape[1:3])

## *Generate probability maps*

In [None]:
# Per-pixel U-net output for every frame, with singleton axes removed.
# NOTE(review): training inputs went through mlearn.prepare_imgs_for_unet; here only a resize is
# applied. Confirm the intensity scaling matches.
raw_pred = unet.predict(imgs_rs[..., None], batch_size=6, verbose=1)
imgs_prob = raw_pred.squeeze()

## *Make a movie to demonstrate segmentation*

In [None]:
alpha = 0.2      # weight of the raw image in the merged channel; (1 - alpha) weights the probability map
merge_ch = 0     # RGB channel index into which the probability map is blended
fps = 50
cropSize = (256, 256)

from skimage.color import gray2rgb

# Per-frame fish location from the (negated) probability-weighted image, with gaps filled by
# nanInterp1d. A cropSize window around it is cut from both the images and the probability maps.
fp = fsb.track.findFish(-imgs_rs*imgs_prob, back_img=None)
fp_interp = spt.interp.nanInterp1d(fp)
imgs_rs_crop = volt.img.cropImgsAroundPoints(imgs_rs, fp_interp, cropSize=cropSize)
imgs_prob_crop = volt.img.cropImgsAroundPoints(imgs_prob, fp_interp, cropSize=cropSize)

# Frames to animate. The hard-coded end (750) is clipped to the frames available, so a shorter
# stack no longer raises an IndexError at imgs_rs_rgb[inds].
inds = np.arange(50, min(750, len(imgs_rs_crop)))
imgs_prob_255 = (imgs_prob_crop*255).astype(int)
# NOTE(review): in skimage, gray2rgb's `alpha` is a flag for adding an alpha layer
# (any truthy value); it does not set 50% opacity.
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs_crop])

# Blend: alpha * raw + (1 - alpha) * probability (0-255 scale).
# NOTE(review): confirm imgs_rs is also on a 0-255 scale, otherwise the probability term dominates.
imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int) 
ani = volt.animate_images(imgs_rs_rgb[inds], fps=fps, fig_size=(15, 15))
ani

## *From segmented fish images to tail curvature timeseries*

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

## *Plot tail angles extracted from segmented fish*

In [None]:
#%% Plot tail angles

# DivergingNorm was renamed TwoSlopeNorm in matplotlib 3.2, and the old name was later removed
try:
    from matplotlib.colors import TwoSlopeNorm
except ImportError:  # older matplotlib
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm
# Diverging color scale centred on 0, saturating at -100 / +100
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Rows: points along the tail (head at row 0); columns: the frames selected by `inds`.
# vmin/vmax are carried by `norm`; also passing them to imshow is rejected by newer matplotlib.
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Last row (tail tip) of the cumulative curvature
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

# *Try Focal Loss: train a second U-net (`unet_fl`) with focal loss and `get_unet`'s default optimizer*

In [None]:
#%% Second U-net (896 x 896, 1 channel) with loss=model.focal_loss passed at compile time.
#   No optimizer is given, so get_unet's default is used.
unet_fl = model.get_unet(
    img_height=896,
    img_width=896,
    img_channels=1,
    loss=model.focal_loss,
)

In [None]:
#%% Checkpoint callback for the focal-loss U-net: keep only the best-val_dice_coef weights
fp = os.path.join(dir_unet, f'best_weights_headFixed_{util.timestamp()}.hdf')
checkpointer = ModelCheckpoint(
    filepath=fp,
    monitor='val_dice_coef',
    mode='max',               # higher Dice is better
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
keras_callbacks = [checkpointer]

In [None]:
#%% Augment before training
upSample=4
aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot', 'rs')
# aug_set=('rn', 'sig', 'log', 'inv', 'heq', 'rot')
%time imgs_aug, masks_aug, augs = mlearn.augmentImageData(imgs_train, masks_train,\
                                                          upsample=upSample, aug_set=aug_set)

imgs_aug = mlearn.prepare_imgs_for_unet(imgs_aug, unet)
masks_aug = mlearn.prepare_imgs_for_unet(masks_aug, unet)
print(f'Augmentation: {len(imgs_train)} --> {len(imgs_aug)}')

In [None]:
%%time
batch_size = 6 # For 1024 x 1024 images I can't help but use batch_size=6
epochs = 25
validation_split = 0.1
checkPoint = True

his = unet_fl.fit(imgs_aug, masks_aug, epochs=epochs, batch_size=batch_size,\
                   validation_split=validation_split, callbacks=keras_callbacks, verbose=1)


In [None]:
# Per-epoch training/validation curves recorded by the last unet_fl.fit call
his = unet_fl.history.history
print(his.keys())
try:
    plt.style.use(('seaborn-poster', 'fivethirtyeight', 'seaborn-white'))
except OSError:  # matplotlib >= 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    plt.style.use(('seaborn-v0_8-poster', 'fivethirtyeight', 'seaborn-v0_8-white'))

fig, (ax_dice, ax_loss) = plt.subplots(1, 2, figsize=(15, 6))
ax_dice.plot(his['val_dice_coef'], '.', label='validation set')
ax_dice.plot(his['dice_coef'], label='training set')
ax_dice.legend(fontsize=12)
ax_dice.set_xlabel('Epoch')
ax_dice.set_title('Dice coefficient', fontsize=14)

ax_loss.plot(his['val_loss'], '.', label='validation set')
ax_loss.plot(his['loss'], label='training set')
ax_loss.legend(fontsize=12)
ax_loss.set_xlabel('Epoch')
# Raw string: '\g' is an invalid escape sequence in a normal string literal
ax_loss.set_title(r'Focal loss ($\gamma = 2$, unbalanced)', fontsize=14);

In [None]:
# Per-pixel output of the focal-loss U-net for every frame, with singleton axes removed.
# batch_size=6 matches the other 896 x 896 predictions; Keras' default predict batch (32)
# may not fit in GPU memory at this image size.
imgs_prob = np.squeeze(unet_fl.predict(imgs_rs[..., None], batch_size=6, verbose=1))

In [None]:
from skimage.color import gray2rgb  # imported here too so this cell does not rely on the earlier movie cell

alpha = 0.2      # weight of the raw image in the merged channel; (1 - alpha) weights the probability map
merge_ch = 0     # RGB channel index into which the probability map is blended
fps = 50
# Requested frames 450-2999, clipped to the frames actually loaded. Indexing past the end of the
# stack (e.g. only 750 frames read) raised an IndexError here and in the tail-angle plot.
inds = np.arange(450, min(3000, len(imgs_rs)))
imgs_prob_255 = (imgs_prob*255).astype(int)
imgs_rs_rgb = np.array([gray2rgb(_, alpha=0.5) for _ in imgs_rs])

imgs_rs_rgb[..., merge_ch] = (alpha*imgs_rs_rgb[..., merge_ch] + (1-alpha)*imgs_prob_255).astype(int) 
ani = volt.animate_images(imgs_rs_rgb[inds], fps=fps, fig_size=(15, 15))
ani

In [None]:
%time imgs_fish = fsb.fish_imgs_from_raw(imgs_rs, unet_fl)[0]
%time midlines, inds_kept_midlines = fsb.track.midlines_from_binary_imgs(imgs_fish)
kappas = fsb.track.curvaturesAlongMidline(midlines, n=50)
tailAngles = np.cumsum(kappas, axis=0)
ta = hf.cleanTailAngles(tailAngles)[0]

In [None]:
#%% Plot tail angles

# DivergingNorm was renamed TwoSlopeNorm in matplotlib 3.2, and the old name was later removed
try:
    from matplotlib.colors import TwoSlopeNorm
except ImportError:  # older matplotlib
    from matplotlib.colors import DivergingNorm as TwoSlopeNorm
# Diverging color scale centred on 0, saturating at -100 / +100
norm = TwoSlopeNorm(0, vmin=-100, vmax=100)
fh, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True)

# Rows: points along the tail (head at row 0); columns: the frames selected by `inds`.
# vmin/vmax are carried by `norm`; also passing them to imshow is rejected by newer matplotlib.
ax[0].imshow(ta[:, inds], aspect='auto', norm=norm, cmap='coolwarm')
ax[0].set_yticks([0, 24, 49])
ax[0].set_yticklabels(['Head', 'Middle', 'Tail'])
ax[0].set_xticks([])
ax[0].set_title('Cumulative curvature along the tail')

# Last row (tail tip) of the cumulative curvature
ax[1].plot(ta[-1][inds])
ax[1].set_xlim(0, len(inds))
ax[1].set_xticks([0, len(inds)//2, len(inds)])
ax[1].set_xlabel('Image frame #')
ax[1].set_ylabel('Tail bend amplitude ($^o$)')
ax[1].set_title('Tail curvature timeseries');

## *Free swim behavior*