<a href="https://colab.research.google.com/github/KPAllard/BDC---PTA/blob/master/FineTune_DistNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialize GPU

The next cell checks which GPU was allocated: training is faster when the GPU is not a K80. If a K80 is listed, choose Runtime > Manage sessions, terminate the current session, refresh the page and re-run the cell.
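
The same check can be done programmatically instead of reading the table by eye. The snippet below is a minimal sketch, not part of the original notebook: the helper names `is_k80` and `query_gpu_names` are hypothetical, and it assumes that `nvidia-smi` is on the path, as it is when a GPU runtime is active.

```python
import subprocess


def is_k80(gpu_name: str) -> bool:
    """Return True if the reported GPU name looks like a Tesla K80."""
    return "k80" in gpu_name.lower()


def query_gpu_names() -> list:
    """Ask nvidia-smi for the GPU names; return [] if the tool is unavailable."""
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return []
    return [line.strip() for line in out.splitlines() if line.strip()]


names = query_gpu_names()
if not names:
    print("No GPU detected: check that a GPU runtime is selected.")
elif any(is_k80(n) for n in names):
    print("K80 detected: restart the session to try to get a faster GPU.")
else:
    print("GPU OK:", ", ".join(names))
```

With the session shown in the output below, the name reported by `nvidia-smi` contains `P100`, so the last branch would be taken.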

In [1]:
!nvidia-smi 

Sat Dec  5 15:19:40 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    25W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Download dataset & model weights
Set the first five variables of the next cell.
The .h5 dataset should follow [this file architecture](https://github.com/jeanollion/dataset_iterator). See [this page](https://github.com/jeanollion/bacmman/wiki/FineTune-DistNet) to generate it with the BACMMAN software.

Follow the instructions in the output of the cell to mount the Google Drive that contains the .h5 dataset file.

Model weights will be downloaded. 
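
Before training, it can be worth checking that the .h5 file really contains the selections and features named in the cell below. The following is a minimal sketch, not part of the original notebook. It assumes that each dataset path contains the selection name and ends with the feature name, which is how the keywords in the next cell appear to be used; the helper names and the example paths are hypothetical.

```python
def missing_features(dataset_paths, selection, features):
    """Return the features that never occur under `selection`.

    A feature counts as present if some dataset path contains the selection
    name and ends with the feature name (assumed layout, see the text above).
    """
    in_selection = [p for p in dataset_paths if selection in p]
    return [f for f in features if not any(p.endswith(f) for p in in_selection)]


def list_dataset_paths(h5_file):
    """Collect the full path of every dataset stored in an HDF5 file."""
    import h5py  # imported lazily so that missing_features works without h5py

    paths = []

    def visit(name, obj):
        # visititems passes paths relative to the root, without a leading "/"
        if isinstance(obj, h5py.Dataset):
            paths.append("/" + name)

    with h5py.File(h5_file, "r") as f:
        f.visititems(visit)
    return paths


# Hypothetical paths, only to illustrate the check
example_paths = [
    "/exp1/ds_noError/pos0/raw",
    "/exp1/ds_noError/pos0/regionLabels",
    "/exp1/ds_noError_test/pos0/raw",
]
wanted = ["/raw", "/regionLabels", "/prevRegionLabels"]
print(missing_features(example_paths, "ds_noError/", wanted))
```

In the notebook, the real paths would come from `list_dataset_paths("/home/bacteriaSegTrack.h5")` once the dataset has been copied locally; an empty result for both selections means every expected feature was found.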

In [None]:
drive_folder = "Segmentation" # folder containing dataset in google drive
dataset_name = "bacteriaSegTrackEDM.h5" # name of the hdf5 dataset file
training_selection_name = "ds_noError/" # name of the training selection
validation_selection_name = "ds_noError_test/" # name of the validation selection
saved_model_file = "distNet_fine-tuned.zip" # name of the exported zipped model file

raw_feature_name = "/raw" # raw input channel
label_feature_name = "/regionLabels" # bacteria labels
pre_label_feature_name = "/prevRegionLabels" # label of previous bacteria

# mount google drive: follow the link, choose the account that contains the dataset and paste token in the displayed field
from google.colab import drive
import os
import sys
def mount_drive():
  drive.mount('/content/driveDL', force_remount=True)
  os.chdir("/content/driveDL/My Drive/"+drive_folder)
mount_drive()

# install dependencies
!pip install --upgrade h5py==2.9
!pip install edt
!pip install git+https://github.com/jeanollion/distnet.git
!pip install git+https://github.com/jeanollion/dataset_iterator.git
# copy dataset locally
!cp "$dataset_name" "/home/bacteriaSegTrack.h5"
# download weights
!gdown "https://drive.google.com/uc?export=download&id=1-3VmrlUINU-OpC1JnaNA3iM-dztltDCJ" -O "/home/distNet_weights.h5"

# Initialize dataset iterator & data augmentation & model

- Data augmentation parameters (the parameters of `ImageDataGeneratorMM`) should be adapted to the dataset. Use the command in the Visualize section to check that data augmentation does not produce unrealistic images; pay particular attention to the aspect ratio limits, so that transformed cells are neither too thin nor too fat.

- The amount of data should be carefully chosen, according to the task. It should contain all the diversity that has to be processed, as well as a few empty channels. 

- We achieved performance similar to that obtained on the original training dataset using ~2500 microchannels for training and ~750 for validation.

- Training and validation sets should be chosen among distinct microchannels to avoid bias. 

- The starting learning rate should also be chosen carefully. We achieved good results with 1e-5.
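
The requirement that training and validation sets come from distinct microchannels can be illustrated with a short sketch. It is not part of the original notebook: the function `split_microchannels` and the identifier format are hypothetical, and only the counts (about 2500 and 750) come from the list above. The split is done on microchannel identifiers, never on individual frames, so no microchannel contributes to both sets.

```python
import random


def split_microchannels(channel_ids, n_validation, seed=0):
    """Split microchannel identifiers into disjoint training and validation sets."""
    ids = sorted(set(channel_ids))  # drop duplicates and fix the starting order
    if not 0 < n_validation < len(ids):
        raise ValueError("n_validation must be between 1 and len(ids) - 1")
    random.Random(seed).shuffle(ids)  # reproducible shuffle
    return ids[n_validation:], ids[:n_validation]


# Hypothetical identifiers: 13 positions x 250 microchannels = 3250 microchannels
all_ids = [f"pos{p:02d}_mc{m:03d}" for p in range(13) for m in range(250)]
train_ids, val_ids = split_microchannels(all_ids, n_validation=750)

print(len(train_ids), len(val_ids))     # 2500 750
print(set(train_ids).isdisjoint(val_ids))  # True
```

How these identifiers map to the training and validation selections depends on how the dataset was exported.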

In [None]:
import h5py
from dataset_iterator import DyIterator
from distnet.data_generator import ImageDataGeneratorMM
import distnet.keras_models as km
from tensorflow.keras.optimizers import Adam
from distnet.utils.losses import weighted_loss_by_category
from tensorflow.keras.losses import sparse_categorical_crossentropy, mean_absolute_error, mean_squared_error


label_datagen = ImageDataGeneratorMM(
    horizontal_flip=True,
    width_shift_range=5, height_shift_range=40,
    width_zoom_range=[0.5, 1.5],
    height_zoom_range=[0.8, 1.2],
    min_zoom_aspectratio=0.5,
    max_zoom_aspectratio=2.5,
    rotation_range=3,
    shear_range=20,
    bacteria_swim_distance=40,
    perform_illumination_augmentation=False,
    interpolation_order=0
)
image_datagen = ImageDataGeneratorMM(interpolation_order=1, perform_illumination_augmentation=True)
params = dict(dataset='/home/bacteriaSegTrack.h5', 
              channel_keywords=[raw_feature_name, label_feature_name, pre_label_feature_name], # channel keyword must correspond to the name of the extracted features
              image_data_generators=[image_datagen, label_datagen, label_datagen],
              output_channels=[1, 2],
              mask_channels=[1, 2],
              channels_prev=[True, True, False],
              channels_next=[False]*3,
              compute_edm="all",
              erase_cut_cell_length=20,
              closed_end=True,
              return_categories=True,
              batch_size=64,
              perform_data_augmentation=True,
              shuffle=True)

train_it = DyIterator(group_keyword=training_selection_name, **params)
test_it = DyIterator(group_keyword=validation_selection_name, **params) # alternatively one can use train_test_split method to split the dataset. 

dy_loss = mean_absolute_error
edm_loss = mean_squared_error
class_loss=weighted_loss_by_category(sparse_categorical_crossentropy, [1, 1, 5, 5])

model = km.get_distnet_model()

model.compile(optimizer=Adam(1e-5), loss=[dy_loss, class_loss, edm_loss], loss_weights=[0.5, 1, 1]) # training from scratch can start at a higher learning rate such as 2e-4
model.load_weights("/home/distNet_weights.h5")

# Perform fine-tuning
Run this cell to perform fine-tuning. The tab should stay open during the whole training.

Intermediate weights are stored in Google Drive in the file *distNet_fineTuned_cp.h5*. If the runtime gets disconnected, they can be loaded with the command: `model.load_weights("distNet_fineTuned_cp.h5")`
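
When re-running the notebook after a disconnection, the choice between the checkpoint and the original weights can be automated. The following is a minimal sketch, not part of the original notebook; the helper `weights_to_load` is hypothetical. The relative checkpoint path assumes that the working directory is the Drive folder, as set by `mount_drive()`.

```python
import os


def weights_to_load(checkpoint_file, pretrained_file):
    """Prefer the fine-tuning checkpoint when it exists, else the pretrained weights."""
    return checkpoint_file if os.path.isfile(checkpoint_file) else pretrained_file


path = weights_to_load("distNet_fineTuned_cp.h5", "/home/distNet_weights.h5")
print("Weights to load:", path)
# In the notebook, this would be followed by: model.load_weights(path)
```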


In [None]:
from tensorflow.keras.callbacks import ReduceLROnPlateau
import numpy as np
from distnet.utils import PatchedModelCheckpoint

train_it._close_datasetIO()
test_it._close_datasetIO()
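# keep only the weights with the lowest validation loss; as noted above, the checkpoint ends up in Google Drive as distNet_fineTuned_cp.h5
checkpoint = PatchedModelCheckpoint("/home/model_cpV_{epoch:02d}.h5", filepath_dest="distNet_fineTuned_cp.h5", timeout_function=mount_drive, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True)
# halve the learning rate when the validation loss has not improved by at least 0.001 for 5 epochs, down to a minimum of 1e-6
lr_schedule = ReduceLROnPlateau(min_lr=1e-6, factor=0.5, patience=5, verbose=1, min_delta=0.001)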

model.fit_generator(train_it, epochs=30, validation_data=test_it, callbacks=[lr_schedule, checkpoint])

# Save model for prediction

Open the following notebook in Colab to convert the .h5 weight file into a model that TensorFlow can use for prediction:
 [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LVrBoazWq9xcrt2Ea3ArILvMApfTyvwp)

# Visualize

In [None]:
import matplotlib.pyplot as plt
import random
import numpy as np

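batch_size = test_it.batch_size  # remember the original batch size; it is restored at the end of the cell
test_it.batch_size = 10  # only 10 samples are displayed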
x_test, [dy_test, category_test, edm_test] = test_it.next()
[dy_pred, category_pred, edm_pred] = model.predict(x_test)
plot_category = False
plt.figure(figsize=(22, 8))
n_im_show = 2 + 2 + 2 + (5 if plot_category else 0 ) + 1
n_subdiv = n_im_show * test_it.batch_size // 2
for i in range(0, test_it.batch_size):
    j=1
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(x_test[i,:,:,0], cmap="gray")
    plt.axis("off")
    plt.title("Prev")
    j+=1
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(x_test[i,:,:,1], cmap="gray")
    plt.axis("off")
    plt.title("Cur")
    j+=1
    
    # dY: ground truth (GT) and prediction, shown on the same symmetric color scale
    vdis = max(abs(dy_test[i,:,:, 0].min()),abs(dy_test[i,:,:, 0].max()))
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(dy_test[i,:,:, 0], cmap="bwr", vmin=-vdis, vmax=vdis)
    plt.axis("off")
    plt.title("GT")
    j+=1
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(dy_pred[i,:,:, 0], cmap="bwr", vmin=-vdis, vmax=vdis)
    plt.axis("off")
    plt.title("dY")
    j+=1
    
    # predicted EDM for the previous and current frames, both scaled to the maximum of the current frame
    vedm = max(0, edm_pred[i,:,:, 1].max())
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(edm_pred[i,:,:,0], cmap="gray", vmin=0, vmax=vedm)
    plt.axis("off")
    plt.title("Prev")
    j+=1
    plt.subplot(2, n_subdiv, n_im_show*i + j)
    plt.imshow(edm_pred[i,:,:,1], cmap="gray", vmin=0, vmax=vedm)
    plt.axis("off")
    plt.title("Cur")
    j+=1

    if plot_category:
      plt.subplot(2, n_subdiv, n_im_show*i + j)
      plt.imshow(category_test[i,:,:,0], cmap="rainbow", vmin=0, vmax=3)
      plt.axis("off")
      j+=1
      for c in range(4):
        plt.subplot(2, n_subdiv, n_im_show*i + j)
        plt.imshow(category_pred[i,:,:,c], cmap="gray", vmin=0, vmax=1)
        plt.axis("off")
        j+=1

plt.show()
test_it.batch_size = batch_size