# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

#import os, sys

import csbdeep
from csbdeep import train
from csbdeep import nets
#from csbdeep import utils

# Load data

# Copy the first 500 entries of the arrays `X` and `Y` from the full training file
# into a small local file. `np.load` on an .npz archive returns a lazily reading
# NpzFile that keeps the archive open; the `with` block closes it again.
with np.load('/home/uschmidt/research/csbdeep/csbdeep_experiments/experiments/tribolium/training/large_cx_0_1_2_cy_3_nz_16/train/data_label.npz') as f:
    np.savez('tribolium.npz',X=f['X'][:500],Y=f['Y'][:500])

In [None]:
# List the .npz files in the working directory. The explicit `!` shell escape does not
# depend on IPython's automagic, which is only applied to single-line cells.
!ls -lh *.npz

In [None]:
# Load the training arrays (X: inputs, Y: targets) plus optional validation data.
# Exactly one of the dataset paths below must be left uncommented.
# NOTE(review): the active path points to a `test` folder of a different experiment,
# and the `tribolium.npz` file written earlier is not used here — confirm this is intended.
(X,Y),data_val = train.load_data(
    #'/home/uschmidt/research/csbdeep/csbdeep_experiments/experiments/tribolium/training/large_cx_0_1_2_cy_3_nz_16/train/data_label.npz',
    #
    #'/home/uschmidt/research/csbdeep/csbdeep_experiments/experiments/tribolium/training/large_cx_0_1_2_cy_3_nz_16/test/data_label.npz',
    '/home/uschmidt/research/csbdeep/csbdeep_experiments/experiments/andi_tubulin/training/width_6_pert_vary/test/data_label.npz',
    #'/home/uschmidt/research/csbdeep/csbdeep_experiments/experiments/isonet_retina/training/subsample_10.20_augment/test/data_label.npz',
    validation_split=0.10,  # presumably the fraction of images held out as `data_val` — confirm
    n_images=None,  # presumably "use all images" — confirm against `train.load_data`
)

In [None]:
# Derive basic facts from the loaded arrays: first axis = images, last axis = channels,
# everything in between is treated as the image size.
n_train = len(X)
if data_val is None:
    n_val = 0
else:
    n_val = len(data_val[0])
image_size = X.shape[1:-1]
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[-1], Y.shape[-1]

# Report the numbers.
print('# train images:\t', n_train)
print('# val images:\t', n_val)
print('%dD image size:\t' % n_dim, image_size)
print('Channels in:\t', n_channel_in)
print('Channels out:\t', n_channel_out)

In [None]:
from csbdeep.plot_utils import plot_some
# Preview five examples (first row: input, second row: ground truth).
# The validation data is shown when it exists, the training data otherwise.
plt.figure(figsize=(12,5))
if data_val is None:
    _X, _Y = X, Y
else:
    _X, _Y = data_val
plot_some(_X[:5], _Y[:5], pmax=99.5);

# Model

We will now define a neural network based on the deep learning library [Keras](https://keras.io).

## Input shape

**Option 1**: define neural network for a specific image size (that of the training images).

In [None]:
# Fixed input shape: the shape of one training image (all axes of X except the first).
# NOTE: running the "Option 2" cell afterwards overwrites this value.
input_shape = X.shape[1:]
input_shape

**Option 2** *(better)*: model can be applied to other (compatible) image sizes after training.

In [None]:
# Flexible input shape: every image axis is left unspecified (None), only the
# number of input channels is fixed.
input_shape = tuple(None for _ in image_size) + (n_channel_in,)
input_shape

## Build model

- based on u-net: ref
- resnet: ref
- see our supplement for details. table X gives the details for the models in our paper

In [None]:
# Switch for probabilistic prediction: passed to the model as `prob_out` and used
# later to choose the loss function.
probabilistic = True

In [None]:
# Configure the network from the data-derived settings, then instantiate it for `input_shape`.
model = nets.common_model(
    n_dim         = len(image_size),
    n_channel_out = n_channel_out,
    prob_out      = probabilistic,
    # only enabled when input and output have the same number of channels;
    # NOTE(review): the leading `True` looks like a manual on/off switch — confirm
    residual      = True and (n_channel_in == n_channel_out),
    # U-Net parameters:
    n_depth       = 2,
    kern_size     = 5 if n_dim==2 else 3,  # kernel size 5 for 2D data, 3 otherwise
    n_first       = 32,
)(input_shape)

In [None]:
# Show the documentation of `nets.common_model` (IPython `?` help syntax).
nets.common_model?

Alternative (equivalent) way to define model by shorthand name:

model = nets.common_model_by_name('resunet3p_1_3_16_1out')(input_shape)  # NOTE(review): confirm that this name encodes the same settings as the explicit `common_model` call

In [None]:
# Show the documentation of `nets.common_model_by_name` (IPython `?` help syntax).
nets.common_model_by_name?

In [None]:
# Print an overview of the model that was just built.
model.summary()

Use `csbdeep.nets.net_model` to build models with more flexibility, or build your own Keras model.

# Training

See [Keras Documentation](https://keras.io) for more information.

## Loss function
If you have chosen a network architecture for probabilistic prediction (`prob_out = True`), you need to use the probabilistic **`laplace`** loss. Otherwise, you can use the standard **`mse`** (*mean squared error*) or **`mae`** (*mean absolute error*) losses.

In [None]:
# Pick the loss that matches the kind of model output.
loss = 'laplace' if probabilistic else 'mse'

# Sanity check: the model must produce one output channel per target channel,
# or two per target channel when it predicts probabilistically.
if model.output_shape[-1] != (2 if probabilistic else 1)*Y.shape[-1]:
    raise ValueError(
        'model has %d output channel(s), but %d are expected for %d target channel(s) with probabilistic=%s.'
        % (model.output_shape[-1], (2 if probabilistic else 1)*Y.shape[-1], Y.shape[-1], probabilistic)
    )

## Optimizer
The optimization algorithm to minimize the loss function during training. We have always used [Adam](https://keras.io/optimizers/#adam) with a learning rate of `0.0004`.

In [None]:
from keras.optimizers import Adam
# Adam optimizer with a learning rate of 0.0004.
# NOTE(review): newer Keras releases name this argument `learning_rate` — confirm
# against the installed version.
optimizer = Adam(lr=0.0004)

## Prepare model for training and choose callbacks

The function `prepare_model` will [compile](https://keras.io/models/model/#compile) the model and return a list of [callbacks](https://keras.io/callbacks/) to be used during training.

In [None]:
# Prepare the model with the chosen optimizer and loss and keep the returned callbacks.
# NOTE: a later cell calls `prepare_model` again (with loss weighting) and replaces this list.
callbacks = train.prepare_model(model,optimizer,loss)

Furthermore, `prepare_model` offers the option to weigh *"foreground"* (fg) and *"background"* (bg) pixels differently in the loss functions (at the beginning of training). We found that this often leads to improved results, because there are typically many more bg than fg pixels.
To take advantage of this, you need to provide the argument `loss_bg_thresh` that defines the threshold between foreground and background pixels, and also provide the labeled data via argument `Y`.

The decay parameter `loss_bg_decay` specifies how long the weighting should be active during training. We typically used a value of `0.06`, which means that the weighting is effectively disabled after 10-15 epochs:

In [None]:
# Preview how the foreground/background loss weighting evolves over the epochs.
loss_bg_thresh = 0.4
loss_bg_decay = 0.02

n_epochs = 30

# alpha is the strength of the weighting: it starts at 1 and is divided by
# (1 + loss_bg_decay * epoch) after every epoch.
# NOTE(review): assumed to mirror the schedule used inside `train.prepare_model` — confirm.
epochs = np.arange(1+n_epochs)
alphas = np.zeros(1+n_epochs)
alphas[0] = 1.0
# Run up to and including index n_epochs: the arrays hold 1+n_epochs entries, and
# stopping one step early would leave the last alpha at its initial value 0,
# which shows up as an artificial drop at the end of both plots.
for i in range(1,n_epochs+1):
    alphas[i] = alphas[i-1] / (1. + loss_bg_decay * i)

# Sketch of an alternative schedule based on a half-life (currently unused):
#half_life = 3
#gamma = np.log(2) / half_life
#new = np.exp(- gamma * (epochs))

# Fraction of target values above the threshold, i.e. counted as foreground.
freq = np.mean(Y > loss_bg_thresh)
# Weights that shrink as the respective class gets more frequent.
# NOTE(review): the 0.1 offset presumably keeps the weights bounded — confirm.
w1 = 0.5 / (0.1 + (1 - freq))
w2 = 0.5 / (0.1 +      freq)

# Blend between the class weight (alpha = 1) and a uniform weight of 1 (alpha = 0).
k1 = [(a * w1 + (1 - a)) for a in alphas]
k2 = [(a * w2 + (1 - a)) for a in alphas]

plt.figure(figsize=(13,4))
plt.subplot(121)
plt.plot(epochs,alphas,'.-')
plt.xlabel('epoch')
plt.title('alpha')
plt.subplot(122)
plt.plot(epochs,k1,'.-',label='bg')
plt.plot(epochs,k2,'.-',label='fg')
plt.xlabel('epoch')
plt.title('loss weights')
plt.legend();

# FIXME: replace `loss_bg_decay` with `loss_bg_halflife`

In [None]:
# Prepare the model again, this time with foreground/background loss weighting.
# Threshold and decay are taken from the variables defined for the weighting preview,
# so that the plotted schedule matches what training actually uses (the threshold was
# hard-coded to 0.3 here although the preview was computed with `loss_bg_thresh`).
callbacks = train.prepare_model(model,optimizer,loss, loss_bg_thresh=loss_bg_thresh,loss_bg_decay=loss_bg_decay,Y=Y)

### [ModelCheckpoint](https://keras.io/callbacks/#modelcheckpoint) callback to save model during training

In [None]:
from keras.callbacks import ModelCheckpoint
# Save the model to 'my_model_best.h5' during training (`save_best_only=True`).
callbacks += [ModelCheckpoint('my_model_best.h5', save_best_only=True)]

### [ReduceLROnPlateau](https://keras.io/callbacks/#reducelronplateau) callback to automatically lower the learning rate

In [None]:
from keras.callbacks import ReduceLROnPlateau
# Multiply the learning rate by `factor` once the monitored quantity has stopped
# improving for `patience` epochs; `verbose=1` reports each reduction.
callbacks.append(
    ReduceLROnPlateau(
        factor=0.5,
        patience=10,
        verbose=1,
    )
)

### Monitor training progress with [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) callback

In [None]:
from csbdeep.tf import MyTensorBoard
# Log training progress to './logs'; `prob_out` is switched on exactly when the
# probabilistic 'laplace' loss is in use.
callbacks.append(
    MyTensorBoard(
        log_dir='./logs',
        n_images=3,
        write_images=True,
        prob_out=(loss == 'laplace'),
    )
)

Start TensorBoard e.g. with **`tensorboard --logdir=./logs --reload-interval=2`** and connect to [http://localhost:6006/](http://localhost:6006/) with your browser.

### Start training with  [model.fit](https://keras.io/models/model/#fit)

Important parameters to choose are the number of `epochs` and the `batch_size`.

from keras_tqdm import TQDMNotebookCallback
from keras.callbacks import LambdaCallback
from IPython import display

def plot_callback(func,p=20):
    """Create a Keras callback that redraws a plot while training is running.

    `func` is called without arguments to do the drawing. This happens at the end of
    the first epoch and afterwards whenever the 1-based epoch number is a multiple
    of `p`; the cell output is then replaced by the current figure. When training
    ends, the current figure is cleared.
    """
    def _is_plot_epoch(epoch):
        # `epoch` counts from 0, hence the +1 for the periodic check
        return epoch == 0 or (epoch+1) % p == 0

    def _redraw(epoch,logs):
        if not _is_plot_epoch(epoch):
            return
        plt.clf()
        func()
        display.clear_output(wait=True)
        display.display(plt.gcf())

    def _reset(*args):
        plt.clf()

    return LambdaCallback(on_epoch_end=_redraw,on_train_end=_reset)

from csbdeep.plot_utils import plot_foo    
def bar():
    """Plot predictions for the first five preview images in a new figure."""
    plt.figure(figsize=(12,7))
    n_show = 5
    prediction = model.predict(_X[:n_show])
    if probabilistic:
        # keep only the first half of the output channels
        # NOTE(review): presumably the predicted mean — confirm
        n_keep = prediction.shape[-1] // 2
        prediction = prediction[..., :n_keep]
    plot_foo(_X[:n_show], _Y[:n_show], prediction, pmax=99.99);

# Add the progress bar and the live preview to the callbacks collected so far.
# The list is deliberately not reset to `[]` here: doing so would silently drop the
# callbacks returned by `train.prepare_model` as well as the checkpoint,
# learning-rate and TensorBoard callbacks added earlier.
callbacks.append( TQDMNotebookCallback() )
callbacks.append( plot_callback(bar,5) )

In [None]:
# Train the model; the returned object is inspected and plotted after training.
# NOTE(review): `epochs=100` is independent of the `n_epochs` value used for the
# loss-weighting preview — confirm that this is intended.
history = model.fit(X,Y, validation_data=data_val,
                    epochs=100,
                    batch_size=16,
                    shuffle=True,
                    verbose=1,
                    callbacks=callbacks)

### Plot training history (available in TensorBoard even during training)

In [None]:
# Names of all quantities recorded during training, in alphabetical order.
sorted(history.history)

In [None]:
from csbdeep.plot_utils import plot_history
plt.figure(figsize=(16,5))
# One group of curves for the loss and one for the error metrics; the validation
# counterparts are only included when validation data was used.
if data_val is not None:
    plot_history(history, ['loss','val_loss'], ['mse','val_mse','mae','val_mae']);
else:
    plot_history(history, ['loss'], ['mse','mae']);

### Show example results

In [None]:
# Run the trained model on the first five preview images and plot the result
# together with the inputs and targets.
plt.figure(figsize=(12,7))
_P = model.predict(_X[:5])
# a probabilistic model has twice as many output channels; plot only the first half
_P = _P[..., :_P.shape[-1]//2] if probabilistic else _P
plot_some(_X[:5], _Y[:5], _P, pmax=99.5);

### Save trained model

Save final model in the same way as in `ModelCheckpoint` callback above.

In [None]:
# Save the model as it is after the last epoch (in contrast to 'my_model_best.h5').
model.save('my_model_last.h5')

**Important**: export model to be used with CSBDeep **Fiji** plugins and **KNIME** workflows:

In [None]:
from csbdeep.tf import export_SavedModel
# Export the model under the name 'my_model' using the 'zip' format.
# NOTE(review): the resulting file name is determined by `export_SavedModel` — confirm.
export_SavedModel(model,'my_model',format='zip')

## TODO: refer to https://github.com/CSBDeep/CSBDeep/wiki/Your-Model-in-Fiji

In [None]:
# Long listing of the working directory. The explicit `!` shell escape does not
# depend on IPython's automagic, which is only applied to single-line cells.
!ls -oh

In [None]:
# One file name per line. The explicit `!` shell escape does not depend on
# IPython's automagic, which is only applied to single-line cells.
!ls -1