<a href="https://colab.research.google.com/github/matjesg/DeepFLaSH/blob/master/tune_and_predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is optimized to run on Google Colab (https://colab.research.google.com).


*   Please read the instructions carefully.
*   Press the *play* button to execute a cell. It appears between \[     \] on the left side of each code cell.
*   Run the cells consecutively.

**Note:** You can predict your images without fine-tuning the model. To do so, skip Section 2, *Train model on the new data*.

>[Configuration](#scrollTo=rPRgwBd5tmGp)

>>[Set up Google Colab Environment](#scrollTo=Zm5wDM15c_cw)

>>[Choose base model](#scrollTo=c5ICgEznc_dB)

>[Train model on the new data](#scrollTo=6e59NnbQc_dL)

>>[Provide image training data](#scrollTo=V1588m6AtbhK)

>>>[Upload your images and masks (segmentation maps)](#scrollTo=5pLJdaY1Ymmh)

>>>[Use example images](#scrollTo=uB9vyeR3bVI0)

>>>[Plot images and masks](#scrollTo=8cD0AY6uZn71)

>>[Model training](#scrollTo=fV6yuPVuuL9P)

>>>[Check results on train data](#scrollTo=hv5bXNnJc_dX)

>>>[Plot all images and joined mask](#scrollTo=5n3FfCHnc_do)

>[Create segmentation maps for new images](#scrollTo=MAVMhDs1c_dg)

>>[Compare segmentation results](#scrollTo=Tg09Afx3Mwra)

>>[Save and download predicted masks](#scrollTo=NZ7kRhs9c_ds)



# Configuration
In this section, you will set up the training environment and choose your base model.

## Set up Google Colab Environment

In [None]:
!git clone https://github.com/matjesg/DeepFLaSH.git
import os
import sys
ROOT_DIR = os.path.abspath("DeepFLaSH")
sys.path.append(ROOT_DIR)
    
import numpy as np
from unet import utils
from unet import colab_utils
from google.colab import files

## Choose base model

Look at the images and masks (segmentation maps) below. Which are most similar to yours?

* [cFOS_Wue](https://drive.google.com/open?id=1u1jAqxRpQh2hjE0W2vdHNCyhQsM5uAis): 
Trained on 36 image-mask pairs.
    
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Wue.png" width="250" height="250" alt="cFOS_Wue">
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Wue_mask.png" width="250" height="250" alt="cFOS_Wue_mask">

* [cFOS_Inns1](https://drive.google.com/open?id=1n6oGHaIvhbcBtzrkgWT6igg8ZXSOvE0D): Fine-tuned on [cFOS_Wue](https://drive.google.com/open?id=1u1jAqxRpQh2hjE0W2vdHNCyhQsM5uAis) 
with five image-mask pairs.

    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Inns1.png" width="250" height="250" alt="cFOS_Inns1">
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Inns1_mask.png" width="250" height="250" alt="cFOS_Inns1_mask">

* [cFOS_Inns2](https://drive.google.com/open?id=1TGxZC93YUP1kp1xmboxl6fJEqU4oDRzP):
Fine-tuned on [cFOS_Wue](https://drive.google.com/open?id=1u1jAqxRpQh2hjE0W2vdHNCyhQsM5uAis) 
with five image-mask pairs.

    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Inns2.png" width="250" height="250" alt="cFOS_Inns2">
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Inns2_mask.png" width="250" height="250" alt="cFOS_Inns2_mask">

* [cFOS_Mue](https://drive.google.com/open?id=1GFOsnLFY8nKDVcBTX7MvMTjoiYfhs91b):
Fine-tuned on [cFOS_Wue](https://drive.google.com/open?id=1u1jAqxRpQh2hjE0W2vdHNCyhQsM5uAis) 
with five image-mask pairs.

    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Mue.png" width="250" height="250" alt="cFOS_Mue">
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/cFOS_Mue_mask.png" width="250" height="250" alt="cFOS_Mue_mask">

* [Parv](https://drive.google.com/open?id=1VtxyOXhuYVDAC8pkzx3SG9sZfvXqHDZI):
Trained on 36 images.
    
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/Parv.png" width="250" height="250" alt="Parv">
    <img src="https://raw.githubusercontent.com/matjesg/DeepFLaSH/master/assets/Parv_mask.png" width="250" height="250" alt="Parv_mask">

**Which model do you want to choose:**

* pretrained: 'cFOS_Wue', 'cFOS_Inns1', 'cFOS_Inns2', 'cFOS_Mue' or 'Parv'
* untrained model (no pretrained weights): 'new'

In [None]:
model = utils.load_unet('cFOS_Wue')

# Train model on the new data 
In this section, you can train your own model.

## Provide image training data
Either **upload your own images** or use the **transfer learning images from our repository**.

### Upload your images and masks (segmentation maps)
The training images and masks should reflect the diversity of your dataset.
* For fine-tuning a model, we recommend at least five image-mask pairs.
* For training a new model from scratch, we recommend about 30 image-mask pairs.
* Make sure that both images and masks follow the same naming conventions, e.g. '01_img.tif' and '01_mask.tif'.
* Images will be resized to a 1024x1024 pixel resolution (greyscale, one channel).
* Typical filetypes are allowed (e.g., tif, png)
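
Because images and masks are uploaded separately, it can help to check the pairing before training. The following is a minimal sketch, not part of DeepFLaSH: `pair_by_prefix` is a hypothetical helper, and it assumes that the text before the first underscore (e.g. `01` in `01_img.tif`) identifies a pair.

```python
import os

def pair_by_prefix(img_names, msk_names, sep="_"):
    """Pair image and mask filenames that share the text before `sep`.

    Hypothetical helper for illustration; assumes names like
    '01_img.tif' and '01_mask.tif'.
    """
    def prefix(name):
        # '01_img.tif' -> '01'
        return os.path.basename(name).split(sep)[0]

    masks = {prefix(m): m for m in msk_names}
    pairs, missing = [], []
    for img in sorted(img_names):
        key = prefix(img)
        if key in masks:
            pairs.append((img, masks[key]))
        else:
            missing.append(img)  # image without a matching mask
    return pairs, missing

pairs, missing = pair_by_prefix(
    ["02_img.tif", "01_img.tif", "03_img.tif"],
    ["01_mask.tif", "02_mask.tif"],
)
print(pairs)
print(missing)
```

Any filename listed in `missing` has no mask with the same prefix and should be renamed or uploaded again.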

**Images:**

In [None]:
img_names, img_list = colab_utils.upload_files()

**Masks**:

In [None]:
msk_names, msk_list = colab_utils.upload_files()

### Use example images

In [None]:
img_names, img_list = colab_utils.load_samples(path = 'transfer_learning/train', suffix = 'new')
msk_names, msk_list = colab_utils.load_samples(path = 'transfer_learning/train', suffix = 'expert')

### Plot images and masks 

Check if images and masks are correctly assigned. If not, adjust your filenames and upload the images and masks again.

In [None]:
utils.plot_image_and_mask(img_names = img_names, img_list = img_list,
                          msk_names = msk_names, msk_list = msk_list)

## Model training

**Training duration (epochs)**

One epoch is one full pass of the (augmented) dataset through the neural network during training.
* We recommend about 50 epochs for fine-tuning and at least 100 epochs for a new model. 
* Choose a higher number if your images are very dissimilar to the sample images above.

In [None]:
epochs = 50

**Data Augmentation**
Create a data augmentation generator for images and masks.

In [None]:
train_generator = utils.create_generator(img_list, msk_list)
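
The internals of `utils.create_generator` are not shown in this notebook. As an illustration of the general idea only, the sketch below applies the same random flips and rotation to an image and its mask, so that the mask stays aligned with the image. The function `augment_pair` and the chosen transforms are assumptions, not the DeepFLaSH implementation.

```python
import numpy as np

def augment_pair(img, msk, rng):
    """Apply identical random flips and a 90-degree rotation to image and mask.

    Illustrative only; the transforms used by DeepFLaSH may differ.
    """
    if rng.random() < 0.5:
        img, msk = np.fliplr(img), np.fliplr(msk)
    if rng.random() < 0.5:
        img, msk = np.flipud(img), np.flipud(msk)
    k = int(rng.integers(0, 4))  # number of 90-degree rotations
    return np.rot90(img, k), np.rot90(msk, k)

rng = np.random.default_rng(0)
img = np.arange(16, dtype=float).reshape(4, 4)
msk = (img > 7).astype(np.uint8)  # toy mask derived from the toy image

aug_img, aug_msk = augment_pair(img, msk, rng)

# The augmented mask still marks exactly the pixels whose value exceeds 7
print(np.array_equal(aug_msk, (aug_img > 7).astype(np.uint8)))
```

Drawing the random parameters once and using them for both arrays is what keeps each mask pixel on top of the image pixel it labels.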

**Train model**

In [None]:
model.fit_generator(train_generator,
                    steps_per_epoch=int(np.ceil(len(img_list)/4.)),
                    epochs=epochs)
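
The `steps_per_epoch` value in the training call divides the number of training images by 4 and rounds up. Assuming that 4 is the batch size of the generator (the notebook does not state this), the arithmetic can be sketched as follows, with `steps_per_epoch` as a hypothetical helper:

```python
import math

def steps_per_epoch(n_samples, batch_size=4):
    """Number of batches needed to cover all samples once.

    batch_size=4 is an assumption based on the divisor in the training call.
    """
    return math.ceil(n_samples / batch_size)

for n in (5, 8, 30):
    print(n, steps_per_epoch(n))
# 5 2
# 8 2
# 30 8
```

Rounding up means that the last, possibly smaller, batch is still counted as a step.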

### Check results on train data

Predict masks for the training images and calculate the Jaccard similarity (intersection over union) between each prediction and its expert mask.

In [None]:
# Predict masks with the U-net
pred_train = model.predict(np.asarray(img_list))
msk_train_list = [pred_train[i] for i in range(pred_train.shape[0])]

# Calculate pixelwise Jaccard Similarity
jac_pixel_results = utils.jaccard_sim(msk_list, msk_train_list)

# Calculate ROI-wise Jaccard Similarity
regions_train_list = [utils.analyze_regions(msk, img) for msk, img in zip(msk_train_list, img_list)]
# Expert regions come from the uploaded masks, not from the predictions
regions_exp_list = [utils.analyze_regions(msk, img) for msk, img in zip(msk_list, img_list)]
jaccard_roi_results = [utils.jaccard_images(a,b) for a,b in zip(regions_train_list,regions_exp_list)]
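
For reference, the pixelwise Jaccard similarity of two binary masks is the size of their intersection divided by the size of their union. The sketch below shows this definition with NumPy. It is not the code behind `utils.jaccard_sim`; the function name `jaccard_index`, the 0.5 threshold, and the handling of two empty masks are assumptions.

```python
import numpy as np

def jaccard_index(mask_a, mask_b, threshold=0.5):
    """Intersection over union of two masks after thresholding.

    Illustrative definition; threshold and empty-mask handling are assumptions.
    """
    a = np.asarray(mask_a) > threshold
    b = np.asarray(mask_b) > threshold
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0  # both masks empty: treated here as perfect agreement
    return float(np.logical_and(a, b).sum() / union)

expert = np.array([[1, 1, 0],
                   [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])

# 2 pixels are set in both masks, 4 pixels are set in at least one
print(jaccard_index(expert, pred))  # 0.5
```

A value of 1 means the two masks are identical, and 0 means they do not overlap at all.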

### Plot all images and joined mask

The joined mask consists of the manual/expert segmentation mask and U-net prediction.

Color code: 
- white = merge
- magenta = U-net only
- green = original/expert only
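
One way to obtain this color code is to write the prediction into the red and blue channels and the expert mask into the green channel of an RGB image. The sketch below illustrates that idea; `join_masks_rgb` is a hypothetical function and may differ from what `utils.join_masks` does internally.

```python
import numpy as np

def join_masks_rgb(pred, expert, threshold=0.5):
    """Combine two masks into one RGB overlay.

    Illustrative only: red and blue follow the prediction,
    green follows the expert mask.
    """
    p = np.asarray(pred) > threshold
    e = np.asarray(expert) > threshold
    return np.stack([p, e, p], axis=-1).astype(np.uint8) * 255

pred = np.array([[1, 1],
                 [0, 0]])
expert = np.array([[1, 0],
                   [1, 0]])
rgb = join_masks_rgb(pred, expert)

print(rgb[0, 0])  # both masks: white
print(rgb[0, 1])  # prediction only: magenta
print(rgb[1, 0])  # expert only: green
print(rgb[1, 1])  # neither: black
```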

In [None]:
join_list = [utils.join_masks(msk_train_list[i], msk_list[i]) for i in range(len(msk_list))]
utils.plot_image_and_mask(img_names = img_names, img_list = img_list,
                          msk_names = jac_pixel_results, msk_list = join_list,
                          msk_head = 'Jaccard Similarity')
                    

# Create segmentation maps for new images
In this section, you can upload unlabelled images and predict the segmentation map (mask).

**Upload images**

*   Images will be resized to a 1024x1024 pixel resolution (greyscale, one channel).
*   Typical filetypes are allowed (e.g., tif, png)

In [None]:
img_new_names, img_new_list = colab_utils.upload_files()

**Predict masks (segmentation maps) with the U-net**

In [None]:
pred_new = model.predict(np.asarray(img_new_list))
pred_new_list = [pred_new[i] for i in range(pred_new.shape[0])]

**Plot results**
Look at the segmentation results of the U-net.

In [None]:
utils.plot_image_and_mask(img_names = img_new_names, img_list = img_new_list,
                          msk_names = img_new_names, msk_list = pred_new_list)

## Compare segmentation results
If you already have segmentation maps of the above images at your disposal, you can upload them here for comparison.

**Upload new segmentation maps (masks)**

* Make sure that both images and masks follow the same naming conventions, e.g. '01_img.tif' and '01_mask.tif'.
* Images will be resized to a 1024x1024 pixel resolution (greyscale, one channel).

In [None]:
msk_new_names, msk_new_list = colab_utils.upload_files()

**Plot comparison**

Color code: 
- white = merge
- magenta = U-net only
- green = original/expert only

In [None]:
# Calculate Jaccard Similarity
jac_test = utils.jaccard_sim(pred_new_list, msk_new_list)

# Join each prediction with its uploaded mask (iterate over the new masks, not the training masks)
join_new_list = [utils.join_masks(pred_new_list[i], msk_new_list[i]) for i in range(len(msk_new_list))]
utils.plot_image_and_mask(img_names = img_new_names, img_list = img_new_list,
                          msk_names = jac_test, msk_list = join_new_list,
                          msk_head = 'Jaccard Similarity')

## Save and download predicted masks

In [None]:
utils.saveMasks(pred_new_list, img_new_names)
!zip -r masks.zip masks
files.download('masks.zip')
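
The `!zip` line above relies on a shell command that is available on Colab. If the notebook is run somewhere without the `zip` tool, the archive could be built with Python's standard `zipfile` module instead. The sketch below uses a temporary directory with placeholder files, and `zip_directory` is a hypothetical helper, not part of DeepFLaSH.

```python
import os
import tempfile
import zipfile

def zip_directory(src_dir, zip_path):
    """Write every file under src_dir into zip_path, keeping the folder name."""
    src_dir = os.path.abspath(src_dir)
    base = os.path.dirname(src_dir)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, filenames in os.walk(src_dir):
            for name in sorted(filenames):
                full = os.path.join(root, name)
                # Store paths relative to the parent, e.g. 'masks/01_mask.png'
                zf.write(full, os.path.relpath(full, base))
    return zip_path

# Demonstration with placeholder files in a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    mask_dir = os.path.join(tmp, "masks")
    os.makedirs(mask_dir)
    for name in ("01_mask.png", "02_mask.png"):
        with open(os.path.join(mask_dir, name), "wb") as fh:
            fh.write(b"placeholder")
    archive = zip_directory(mask_dir, os.path.join(tmp, "masks.zip"))
    with zipfile.ZipFile(archive) as zf:
        archived = sorted(zf.namelist())

print(archived)
```

In the notebook itself, the call would be `zip_directory('masks', 'masks.zip')`, assuming the masks were saved to a folder named `masks` as the cell above suggests.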