In [None]:
%matplotlib inline
import matplotlib
# Remember the active matplotlib backend so it can be restored after the imports below.
backend = matplotlib.get_backend()
import matplotlib.pyplot as plt

import numpy as np
import importlib

from pathlib import Path

import os
# This line will suppress the TF (C++) log output. It has to be set before
# tensorflow is imported to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from midap.data.tf_pipeline import TFPipeFamilyMachine
# Restore the backend captured above. NOTE(review): presumably one of the imports
# above switches the backend (which would break inline plotting) — confirm.
matplotlib.use(backend)

## Custom training with MIDAP

This notebook can be used to train and fine-tune either standard models from the pipeline or custom models. It is essentially a step-by-step version of `train.py` in the parent directory of this notebook, and you should use the script to actually train the model. All parameters set in the cell below can also be set as command-line arguments of the training script. Run `python train.py --help` to get the full signature.


**IMPORTANT:**

We adjusted some parameters so that the notebook runs quickly for demonstration purposes. Running the notebook in its default state will not lead to properly trained models. If you want to train a model properly yourself, you need to:

1. Increase the number of epochs below (50 is recommended)
2. Use the whole dataset for the training, i.e. increase `n_take` (set in the cell that fits the model)

In [None]:
#########################
# define the parameters #
#########################

# Required parameters
#####################

# Where to save the results (weights only, or the full model if `save_model` is True)
save_path = "./my_models/model1"
# The training images. You can provide a list of files or a glob (star) expression with pathlib.
# E.g. this loads all images of all bacteria in the PH channel
eval_files = list(Path("../midap_training/").glob("**/PH/**/*raw.tif"))
# Fail early with a clear message instead of deep inside the data pipeline
if len(eval_files) == 0:
    raise FileNotFoundError("No training files matched the glob expression, please check the path and pattern.")
# the size of the cutouts for the training (presumably (height, width, channels) — confirm)
image_size = (128, 128, 1)

# Model choice
##############

# Choose a model to train
custom_model = None # Use the standard UNet
# custom_model = "CustomUNet" # Use a model (class name) defined in ../custom_model.py

# Choose a restore path for fine tuning. The stored weights have to be compatible with
# the chosen model. E.g. for a classic UNet (i.e. "custom_model = None" above) all
# pretrained weights from the pipeline can be used, e.g.:
# restore_path = "../../model_weights/model_weights_family_mother_machine/model_weights_ZF270g.h5"
restore_path = None # Start from scratch (default)

# Choices for the training
##########################

# batch size for the training
batch_size = 2
# Number of epochs
epochs = 5 # (default 50)
# save the full model, this should only be true for custom models that you want to add to the pipeline
save_model = False
# Tensorboard callback. You can provide a logdir to create a tensorboard log to see the performance of the training.
# You can view it inside a jupyter notebook or you can start a new process from the terminal via
# python -m tensorboard.main --logdir <logdir>
# Note that running it in a separate process is recommended and you should not run it on the cluster.
# It is good practice to set it to either None (no callback) or ./logs (create logs dir for the logs)
tfboard_logdir = "./logs" # None

In [None]:
# load the training data
########################

# Build the TF input pipeline (training and validation datasets) from the selected files
tf_pipe = TFPipeFamilyMachine(
    eval_files,
    batch_size=batch_size,
    image_size=image_size,
)

In [None]:
# Load the model
################

if custom_model is None:
    # load the standard UNet
    from midap.networks.unets import UNetv1 as ModelClass
else:
    # `import importlib` alone does not guarantee that the `util` submodule is loaded
    import importlib.util

    # Load the class named by `custom_model` from ../custom_model.py
    spec = importlib.util.spec_from_file_location("module", "../custom_model.py")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if not hasattr(module, custom_model):
        raise AttributeError(f"../custom_model.py does not define a model class named {custom_model!r}")
    ModelClass = getattr(module, custom_model)

# initialize the model
model = ModelClass(input_size=image_size, dropout=0.5)

In [None]:
# Restore the weights
#####################

# Fine-tuning only happens when a restore path was given; None means training from scratch
fine_tune = restore_path is not None
if fine_tune:
    # the stored weights have to be compatible with the chosen model architecture
    model.load_weights(restore_path)

In [None]:
# Callback for logging including TFBoard
########################################

# Stop training once the validation loss has not improved for 5 epochs
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
# Optional TensorBoard logging, only when a log directory was configured
tfboard_callbacks = (
    []
    if tfboard_logdir is None
    else [tf.keras.callbacks.TensorBoard(log_dir=tfboard_logdir)]
)
callbacks = [early_stopping, *tfboard_callbacks]

In [None]:
# Fit the model
###############

# Number of elements (presumably batches) taken from the training dataset for this quick
# demo. Set it to None to train on the whole dataset (recommended for a real training).
n_take = 10
train_data = tf_pipe.dset_train if n_take is None else tf_pipe.dset_train.take(n_take)

# `shuffle` is not passed: Keras ignores it for tf.data.Dataset inputs, so any shuffling
# has to happen inside the input pipeline itself.
history = model.fit(x=train_data,
                    epochs=epochs,
                    validation_data=tf_pipe.dset_val,
                    callbacks=callbacks)

In [None]:
# Save the results
##################

# Make sure the target directory exists, writing the h5 file fails otherwise
Path(save_path).parent.mkdir(parents=True, exist_ok=True)

if save_model:
    # full model, only needed for custom models that should be added to the pipeline
    model.save(save_path, save_format="h5")
else:
    # weights only, they can be restored with `model.load_weights`
    model.save_weights(save_path, save_format="h5")

### Tensorboard callback

If you set `tfboard_logdir = "./logs"` above, you can visualize the training and the model using TensorBoard. It is recommended to start TensorBoard in a separate process via the terminal, but it can also be used inside the jupyter notebook.

**IMPORTANT**

1. If the jupyter notebook is running in an environment that does not contain a `tensorboard` executable, or if jupyter was started from one environment and the kernel was changed, you might need to explicitly set the executable. The logic is shown in the cells below.
2. If you set `tfboard_logdir = "bla"`, i.e. something other than `./logs`, change the `--logdir` argument of the magic function below accordingly.

In [None]:
# set the executable if necessary
os.environ["TENSORBOARD_BINARY"] = "/home/janis/anaconda3/envs/midap/bin/tensorboard"

if tfboard_logdir is not None:
    %load_ext tensorboard
    # change logs to custom directory if chosen
    %tensorboard --logdir logs