# **U-Net (2D) for multilabel segmentation (semantic segmentation)**
---

<font size = 4>U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network that acts as a feature extractor on the input images. The other half upsamples these features and restores an image at the input resolution by combining, through skip connections, the results of the downsampling path with the upsampled features.
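
To illustrate the encoder-decoder idea, the minimal sketch below tracks the size and number of channels of the feature maps through the downsampling and upsampling paths. It assumes what the `unet()` function in section 1.3 does: square inputs, 2x2 max-pooling on the way down, 2x upsampling on the way up, `'same'` padding, and 64 filters at the first level, doubling at each level. `unet_feature_sizes` is a hypothetical helper written only for this illustration; it does not build a network.

```python
def unet_feature_sizes(input_size=256, pooling_steps=4, base_filters=64):
    """Return the (size, channels) of the feature maps at each U-Net level.

    Assumes square inputs, 2x2 max-pooling on the way down, 2x upsampling on
    the way up and 'same' padding, so only pooling/upsampling change the size.
    """
    if input_size % (2 ** pooling_steps) != 0:
        raise ValueError(
            f"input_size={input_size} is not divisible by 2**pooling_steps={2 ** pooling_steps}"
        )
    encoder = []
    size, filters = input_size, base_filters
    for _ in range(pooling_steps):
        encoder.append((size, filters))  # kept for the skip connection
        size //= 2                       # 2x2 max-pooling halves height and width
        filters *= 2                     # the number of filters doubles per level
    bottleneck = (size, filters)

    decoder = []
    for skip_size, skip_filters in reversed(encoder):
        size *= 2                        # 2x upsampling doubles height and width
        # the upsampled map must match the encoder map it is concatenated with
        assert size == skip_size
        decoder.append((size, skip_filters))
    return encoder, bottleneck, decoder


encoder, bottleneck, decoder = unet_feature_sizes(256, pooling_steps=4)
print("encoder:   ", encoder)     # [(256, 64), (128, 128), (64, 256), (32, 512)]
print("bottleneck:", bottleneck)  # (16, 1024)
print("decoder:   ", decoder)     # [(32, 512), (64, 256), (128, 128), (256, 64)]
```

With `pooling_steps=4`, a 256x256 patch shrinks to 16x16 at the bottleneck. If the patch size were not divisible by 2**pooling_steps, the upsampled maps would no longer match the encoder maps they are concatenated with, which is presumably why the notebook works with power-of-two patch sizes.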

The main difference between this U-Net and the original one is that the output is a semantic mask rather than a binary mask. This allows the segmentation of different kinds of structures, objects or tissues present in the image by labelling pixels with different values (0, 1, 2, ...) rather than just 0 and 1 (binary segmentation). To use this notebook, you therefore need to provide mask images in which each type of label has a different pixel value, rather than just 0 and 1. For example: 0 = background, 1 = cytoplasm, 2 = lumen and 3 = nuclei.
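
To make the label convention concrete, the sketch below builds a small toy mask with NumPy using the example values above and checks which label values it contains. The `unet()` function in section 1.3 has one softmax output channel per label and is compiled with `sparse_categorical_crossentropy`, so each pixel value is used directly as a class index: for a model with n output channels, values should lie in 0..n-1, and keeping them consecutive avoids unused classes. `check_semantic_mask` is a hypothetical helper for illustration only.

```python
import numpy as np

def check_semantic_mask(mask):
    """Return the sorted label values found in `mask` and whether they are 0..n-1."""
    labels = np.unique(mask)
    consecutive = bool(np.array_equal(labels, np.arange(labels.size)))
    return labels, consecutive

# Toy 6x6 semantic mask: 0 = background, 1 = cytoplasm, 2 = lumen, 3 = nuclei
mask = np.zeros((6, 6), dtype=np.uint8)
mask[1:5, 1:5] = 1  # cytoplasm (16 pixels, partly overwritten below)
mask[2:4, 2:3] = 2  # lumen (2 pixels)
mask[2:4, 3:4] = 3  # nucleus (2 pixels)

labels, consecutive = check_semantic_mask(mask)
print(labels, consecutive)        # [0 1 2 3] True
print(np.bincount(mask.ravel()))  # pixels per label: 20, 12, 2 and 2

# A mask saved as 0/255 (binary style) is not a valid set of class indices
binary_style = (mask > 0).astype(np.uint8) * 255
print(check_semantic_mask(binary_style))
```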

<font size = 4> **This particular notebook enables image segmentation of 2D datasets. If you are interested in 3D datasets, you should use the 3D U-Net notebook instead.**

---
<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki), jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.

<font size = 4>This notebook is largely based on the papers:

<font size = 4>**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)

<font size = 4>and

<font size = 4>**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019
(https://www.nature.com/articles/s41592-018-0261-2)
and on the source code available at https://github.com/zhixuhao/unet by *Zhixuhao*.

<font size = 4>**Please also cite these original papers when using or developing this notebook.**

# **How to use this notebook?**

---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook



---
### **Structure of a notebook**

<font size = 4>The notebook contains two types of cells:

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears) and click it. When execution is done, the play button stops its animation. You can create a new code cell by clicking `+ Code`.

---
### **Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you find three tabs which contain from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click an entry to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here.

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
### **Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.

# **0. Before getting started**
---

<font size = 4>Before you run the notebook, please ensure that you are logged into your Google account and have the training data and/or the data to process in your Google Drive.

<font size = 4>For U-Net to train, **it needs to have access to a paired training dataset made of images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki

<font size = 4>**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.

<font size = 4>Additionally, the corresponding Training_source and Training_target files need to have **the same name**.

<font size = 4>Here's a common data structure that can work:
*   Experiment A
    - **Training dataset**
      - Training_source
        - img_1.tif, img_2.tif, ...
      - Training_target
        - img_1.tif, img_2.tif, ...
    - **Quality control dataset**
      - Training_source
        - img_1.tif, img_2.tif
      - Training_target
        - img_1.tif, img_2.tif
    - **Data to be predicted**
    - **Results**
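
Because images and masks are matched by file name, it can be worth checking a dataset before training. The sketch below lists the files that have no partner of the same name in the other folder. It assumes the `Training_source`/`Training_target` folder names from the example above and skips hidden files (names starting with `.`), as the notebook's own file listings do; `find_unpaired` is a hypothetical helper, not part of the notebook.

```python
import os
import tempfile

def find_unpaired(source_dir, target_dir):
    """Return (missing_masks, missing_images) as sorted lists of file names.

    missing_masks:  files in source_dir with no file of the same name in target_dir
    missing_images: files in target_dir with no file of the same name in source_dir
    Hidden files (names starting with '.') and sub-folders are ignored.
    """
    def visible_files(folder):
        return {f for f in os.listdir(folder)
                if not f.startswith(".") and os.path.isfile(os.path.join(folder, f))}

    source, target = visible_files(source_dir), visible_files(target_dir)
    return sorted(source - target), sorted(target - source)

# Demo on a throw-away folder that mimics Training_source / Training_target
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "Training_source")
    tgt = os.path.join(tmp, "Training_target")
    os.makedirs(src)
    os.makedirs(tgt)
    for name in ["img_1.tif", "img_2.tif", "img_3.tif"]:
        open(os.path.join(src, name), "w").close()
    for name in ["img_1.tif", "img_2.tif", "img_4.tif"]:
        open(os.path.join(tgt, name), "w").close()
    result = find_unpaired(src, tgt)

print(result)  # (['img_3.tif'], ['img_4.tif'])
```

On your own data, you would pass the paths of your `Training_source` and `Training_target` folders (and likewise for the quality control folders); for a correctly paired dataset, both returned lists are empty.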

---
<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.
---

# **1. Install U-Net dependencies**
---


## **1.1. Install key dependencies**
---
<font size = 4>

In [None]:
#@markdown ##Play to install U-Net dependencies
# Install packages which are not included in Google Colab
!pip install data
!pip install -q tifffile # contains tools to operate tiff-files
!pip install -q wget
!pip install -q fpdf2
!pip install -q PTable # Nice tables
!pip install -q zarr
!pip install -q imagecodecs
!pip install -q bioimageio.core==0.6.9


## **1.2. Restart your runtime**
---
<font size = 4>


**<font size = 4> Ignore the following error message. Your runtime has automatically restarted. This is normal.**

<img width="40%" alt ="" src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/session_crash.png"><figcaption>  </figcaption>

## **1.3. Load key dependencies**
---
<font size = 4>

In [None]:
from __future__ import print_function
Notebook_version = '2.1.3'
Network = 'U-Net (2D) multilabel'


from builtins import any as b_any

def get_requirements_path():
    # Store requirements file in 'contents' directory
    current_dir = os.getcwd()
    dir_count = current_dir.count('/') - 1
    path = '../' * (dir_count) + 'requirements.txt'
    return path

def filter_files(file_list, filter_list):
    filtered_list = []
    for fname in file_list:
        if b_any(fname.split('==')[0] in s for s in filter_list):
            filtered_list.append(fname)
    return filtered_list

def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file
    # (read the lines directly: pd.read_csv would treat the first requirement as a header)
    with open(path) as f:
        req_list = [line.strip() for line in f if line.strip()]
    mod_list = [m.split('.')[0] for m in after if m not in before]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]
    mod_name_dict = dict(mod_name_list)
    mod_replace_list = [mod_name_dict.get(s, s) for s in mod_list]
    filtered_list = filter_files(req_list, mod_replace_list)

    # Write one requirement per line (writelines() on a single string adds no newline)
    with open(path, 'w') as file:
        for item in filtered_list:
            file.write(item + '\n')

import sys
before = [str(m) for m in sys.modules]

#@markdown ##Load key U-Net dependencies

#As this notebook depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)
#only the data library needs to be additionally installed.
import tensorflow as tf
# print(tensorflow.__version__)
# print("Tensorflow enabled.")


# Keras imports
from tensorflow.keras import models
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from tensorflow.keras.optimizers import Adam
# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras import backend as keras
from tensorflow.keras.callbacks import Callback

# General import
import numpy as np
import pandas as pd
import os
import glob
from skimage import img_as_ubyte, io, transform
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.pyplot import imread
from pathlib import Path
import shutil
import random
import time
import csv
import sys
from math import ceil
from fpdf import FPDF, HTMLMixin
from pip._internal.operations.freeze import freeze
import subprocess
# Imports for QC
from PIL import Image
from scipy import signal
from scipy import ndimage
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr

# For sliders and dropdown menu and progress bar
from ipywidgets import interact
import ipywidgets as widgets
# from tqdm import tqdm
from tqdm.notebook import tqdm

from sklearn.feature_extraction import image
from skimage import img_as_ubyte, io, transform
from skimage.util.shape import view_as_windows

from datetime import datetime


# Suppressing some warnings
import warnings
warnings.filterwarnings('ignore')

# BioImage Model Zoo
from shutil import rmtree
import bioimageio.spec.model.v0_5 as bioimageio_spec
from bioimageio.spec import save_bioimageio_package
import bioimageio.core as bioimageio_core
from zipfile import ZipFile
import requests
from bioimageio.spec.pretty_validation_errors import (
    enable_pretty_validation_errors_in_ipynb,
)
enable_pretty_validation_errors_in_ipynb()

#Create a variable to get and store relative base path
base_path = os.getcwd()

def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):
  """
  Function creates patches from the Training_source and Training_target images.
  Patches do not overlap: the step between patches is equal to the patch size.
  A patch is kept only if at least (approximately) `min_fraction` of its mask pixels
  are labelled (non-zero); the other patches are saved in a separate 'rejected' folder.
  Saves all created patches in two new directories in the base_path folder.

  Returns: - Two paths to where the patches are now saved
  """
  DEBUG = False

  Patch_source = os.path.join(base_path,'img_patches')
  Patch_target = os.path.join(base_path,'mask_patches')
  Patch_rejected = os.path.join(base_path,'rejected')

  #Here we save the patches, in the /content directory as they will not usually be needed after training
  if os.path.exists(Patch_source):
    shutil.rmtree(Patch_source)
  if os.path.exists(Patch_target):
    shutil.rmtree(Patch_target)
  if os.path.exists(Patch_rejected):
    shutil.rmtree(Patch_rejected)

  os.mkdir(Patch_source)
  os.mkdir(Patch_target)
  os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.

  patch_num = 0
  Training_source_list = [f for f in os.listdir(Training_source) if not f.startswith(".")]
  for file in tqdm(Training_source_list):

    img = io.imread(os.path.join(Training_source, file))
    mask = io.imread(os.path.join(Training_target, file),as_gray=True)

    if DEBUG:
      print(file)
      print(img.dtype)

    # Using view_as_windows with step size equal to the patch size to ensure there is no overlap
    patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))
    patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))

    patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)
    patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)

    if DEBUG:
      print(patches_img.shape)
      print(patches_img.dtype)

    for i in range(patches_img.shape[0]):
      img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')
      mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')
      patch_num += 1

      # If at least min_fraction of the mask pixels are labelled (non-zero), save the patch; otherwise save it in the rejected folder
      pixel_threshold_array = sorted(patches_mask[i].flatten())
      if pixel_threshold_array[int(round((len(pixel_threshold_array)-1)*(1-min_fraction)))]>0:
        io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))
        io.imsave(mask_save_path, patches_mask[i])
      else:
        io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))
        io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', patches_mask[i])

  return Patch_source, Patch_target


def estimatePatchSize(data_path, max_width = 512, max_height = 512):

  files = [f for f in os.listdir(data_path) if not f.startswith(".")]

  # Get the size of the first image found in the folder and initialise the variables to that
  n = 0
  while os.path.isdir(os.path.join(data_path, files[n])):
    n += 1
  (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size

  # Screen the size of all dataset to find the minimum image size
  for file in files:
    if not os.path.isdir(os.path.join(data_path, file)):
      (height, width) = Image.open(os.path.join(data_path, file)).size
      if width < width_min:
        width_min = width
      if height < height_min:
        height_min = height

  # Find the power of patches that will fit within the smallest dataset
  width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))

  # Clip values at maximum permissible values
  if width_min > max_width:
    width_min = max_width

  if height_min > max_height:
    height_min = max_height

  return (width_min, height_min)

def fittingPowerOfTwo(number):
  n = 0
  while 2**n <= number:
    n += 1
  return 2**(n-1)

## TODO: create weighted CE for semantic labels
def getClassWeights(Training_target_path):

  Mask_dir_list = [f for f in os.listdir(Training_target_path) if not f.startswith(".")]

  number_of_dataset = len(Mask_dir_list)

  class_count = np.zeros(2, dtype=int)
  for i in tqdm(range(number_of_dataset)):
    mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))
    mask = normalizeMinMax(mask)
    class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()
    class_count[1] += mask.sum()

  n_samples = class_count.sum()
  n_classes = 2

  class_weights = n_samples / (n_classes * class_count)
  return class_weights

def weighted_binary_crossentropy(class_weights):

    def _weighted_binary_crossentropy(y_true, y_pred):
        binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)
        weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]
        weighted_binary_crossentropy = weight_vector * binary_crossentropy

        return keras.mean(weighted_binary_crossentropy)

    return _weighted_binary_crossentropy


def save_augment(datagen,orig_img,dir_augmented_data = base_path + "/augment"):
  """
  Saves a subset of the augmented data for visualisation, by default in /content.

  This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html

  """
  if os.path.exists(dir_augmented_data):
    # If the preview folder already exists, remove its contents (pictures)
    dir_augmented_data_list = [f for f in os.listdir(dir_augmented_data) if not f.startswith(".")]
    for item in dir_augmented_data_list:
      os.remove(os.path.join(dir_augmented_data, item))
  else:
    os.mkdir(dir_augmented_data)

  # Convert the original image to an array of shape (Nrow, Ncol, Nchannels)
  x = img_to_array(orig_img)
  # Add a sample dimension: (1, Nrow, Ncol, Nchannels)
  x = x.reshape((1,) + x.shape)

  # Randomly generate augmented pictures.
  # We only save 5 images; this can be changed, but note that the
  # visualisation in section 3 currently uses 5.
  Nplot = 5
  i = 0
  for batch in datagen.flow(x, batch_size=1,
                            save_to_dir=dir_augmented_data,
                            save_format='tif',
                            seed=42):
    i += 1
    if i > Nplot - 1:
      break

# Generators
def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size, validatio_split):
  '''
  Generates images and masks at the same time. The same seed is used for image_datagen and mask_datagen to ensure that the same transformation is applied to an image and to its mask.

  image_datagen, mask_datagen: ImageDataGenerator
  subset: can take either 'training' or 'validation'
  '''

  # Build the dict for the ImageDataGenerator
  # non_aug_args = dict(width_shift_range = 0,
  #                     height_shift_range = 0,
  #                     rotation_range = 0, #90
  #                     zoom_range = 0,
  #                     shear_range = 0,
  #                     horizontal_flip = False,
  #                     vertical_flip = False,
  #                     fill_mode = 'reflect')
  # default params of data generator is without augmentation
  mask_load_gen = ImageDataGenerator(dtype='uint8', validation_split=validatio_split)
  image_load_gen = ImageDataGenerator(dtype='float32', validation_split=validatio_split, preprocessing_function = normalizePercentile)

  image_generator = image_load_gen.flow_from_directory(
        os.path.dirname(image_folder_path),
        classes = [os.path.basename(image_folder_path)],
        class_mode = None,
        color_mode = "grayscale",
        target_size = target_size,
        batch_size = batch_size,
        subset = subset,
        interpolation = "bicubic",
        seed = 1)
  mask_generator = mask_load_gen.flow_from_directory(
        os.path.dirname(mask_folder_path),
        classes = [os.path.basename(mask_folder_path)],
        class_mode = None,
        color_mode = "grayscale",
        target_size = target_size,
        batch_size = batch_size,
        subset = subset,
        interpolation = "nearest",
        seed = 1)

  this_generator = zip(image_generator, mask_generator)
  for (img,mask) in this_generator:
      if subset == 'training':
          # Apply the data augmentation
          # the same seed should provide always the same transformation and image loading
          seed = np.random.randint(100000)
          for batch_im in image_datagen.flow(img,batch_size=batch_size, seed=seed):
              break
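          mask = mask.astype(np.float32)
          labels = np.unique(mask)
          if len(labels)>1:
              batch_mask = np.zeros_like(mask, dtype='float32')
              # Augment each label as a separate binary mask with the same seed as the image,
              # then paint it back with its actual label value. Iterating over the label values
              # (rather than over 0..len(labels)-1) keeps the labels correct when a batch
              # does not contain every class.
              for l in labels:
                  aux = (mask==l).astype(np.float32)
                  for batch_aux in mask_datagen.flow(aux,batch_size=batch_size, seed=seed):
                      break
                  # Where transformed labels overlap, the higher label value wins
                  batch_mask[batch_aux>0] = l
          else:
              batch_mask = mask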

          yield (batch_im,batch_mask)

      else:
          yield (img,mask)


def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512), validatio_split = 0.1):
  image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)
  mask_datagen = ImageDataGenerator(**datagen_parameters)

  train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size, validatio_split)
  validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size, validatio_split)

  return (train_datagen, validation_datagen)


# Normalization functions from Martin Weigert
def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """This function is adapted from Martin Weigert"""
    """Percentile-based image normalization."""

    mi = np.percentile(x,pmin,axis=axis,keepdims=True)
    ma = np.percentile(x,pmax,axis=axis,keepdims=True)
    return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32
    """This function is adapted from Martin Weigert"""
    if dtype is not None:
        x   = x.astype(dtype,copy=False)
        mi  = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)
        ma  = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)
        eps = dtype(eps)

    try:
        import numexpr
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x =                   (x - mi) / ( ma - mi + eps )

    if clip:
        x = np.clip(x,0,1)

    return x



# Simple min/max normalization for the mask
def normalizeMinMax(x, dtype=np.float32):
  x = x.astype(dtype,copy=False)
  x = (x - np.amin(x)) / (np.amax(x) - np.amin(x) + 1e-10)
  return x


# This code outlines the architecture of U-Net. The number of pooling steps decides the depth of the network.
def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, labels=2):
    inputs = Input(input_size)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    # Downsampling steps
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)

    if pooling_steps > 1:
      pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
      conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
      conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)

      if pooling_steps > 2:
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)

        if pooling_steps > 3:
          pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
          conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
          conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
          drop5 = Dropout(0.5)(conv5)

          #Upsampling steps
          up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
          merge6 = concatenate([drop4,up6], axis = 3)
          conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
          conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

    if pooling_steps > 2:
      up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))
      if pooling_steps > 3:
        up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
      merge7 = concatenate([conv3,up7], axis = 3)
      conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
      conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    if pooling_steps > 1:
      up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))
      if pooling_steps > 2:
        up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
      merge8 = concatenate([conv2,up8], axis = 3)
      conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
      conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    if pooling_steps == 1:
      up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))
    else:
      up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'

    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'
    conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'
    conv9 = Conv2D(labels, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'
    conv10 = Conv2D(labels, 1, activation = 'softmax')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    model.compile(optimizer = Adam(learning_rate = learning_rate), loss = 'sparse_categorical_crossentropy')

    if verbose:
      model.summary()

    if pretrained_weights:
      model.load_weights(pretrained_weights)

    return model

# Custom callback showing sample prediction
class SampleImageCallback(Callback):

    def __init__(self, model, sample_data, model_path, save=False):
        super().__init__()
        self.model = model
        self.sample_data = sample_data
        self.model_path = model_path
        self.save = save

    def on_epoch_end(self, epoch, logs=None):
        if np.mod(epoch,5) == 0:
            sample_predict = self.model.predict_on_batch(self.sample_data)
            # Number of labels = number of softmax output channels of the model
            labels = sample_predict.shape[-1]

            f=plt.figure(figsize=(16,8))
            plt.subplot(1,labels+1,1)
            plt.imshow(self.sample_data[0,:,:,0], cmap='gray')
            plt.title('Sample source')
            plt.axis('off');
            for i in range(1, labels):
              plt.subplot(1,labels+1,i+1)
              plt.imshow(sample_predict[0,:,:,i], interpolation='nearest', cmap='magma')
              plt.title('Predicted label {}'.format(i))
              plt.axis('off');

            plt.subplot(1,labels+1,labels+1)
            plt.imshow(np.squeeze(np.argmax(sample_predict[0], axis=-1)), interpolation='nearest')
            plt.title('Semantic segmentation')
            plt.axis('off');

            # Save the figure before plt.show(), which leaves an empty figure behind
            if self.save:
                plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')

            plt.show()

def predict_as_tiles(Image_path, model):

  # Read the data in and normalize
  Image_raw = io.imread(Image_path, as_gray = True)
  Image_raw = normalizePercentile(Image_raw)

  # Get the patch size from the input layer of the model
  patch_size = model.layers[0].output_shape[0][1:3]

  # Pad the image with zeros if any of its dimensions is smaller than the patch size
  if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:
    Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))
    Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw
  else:
    Image = Image_raw

  # Calculate the number of patches in each dimension
  n_patch_in_width = ceil(Image.shape[0]/patch_size[0])
  n_patch_in_height = ceil(Image.shape[1]/patch_size[1])

  prediction = np.zeros(Image.shape, dtype = 'uint8')

  for x in range(n_patch_in_width):
    for y in range(n_patch_in_height):
      xi = patch_size[0]*x
      yi = patch_size[1]*y

      # If the patch exceeds the edge of the image shift it back
      if xi+patch_size[0] >= Image.shape[0]:
        xi = Image.shape[0]-patch_size[0]

      if yi+patch_size[1] >= Image.shape[1]:
        yi = Image.shape[1]-patch_size[1]

      # Extract and reshape the patch
      patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]
      patch = np.reshape(patch,patch.shape+(1,))
      patch = np.reshape(patch,(1,)+patch.shape)

      # Get the prediction from the patch and paste it in the prediction in the right place
      predicted_patch = model.predict(patch, batch_size = 1)
      prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = (np.argmax(np.squeeze(predicted_patch), axis = -1)).astype(np.uint8)


  return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]


def saveResult(save_path, nparray, source_dir_list, prefix=''):
  for (filename, image) in zip(source_dir_list, nparray):
      io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), image) # saving as unsigned 8-bit image


def convert2Mask(image, threshold):
  mask = img_as_ubyte(image, force_copy=True)
  mask[mask > threshold] = 255
  mask[mask <= threshold] = 0
  return mask

# BMZ model export functions
def make_author(author_input_info: str):
    """
    Create an Author object from a string input.

    Args:
        author_input_info: A comma-separated string containing the author's name,
            affiliation, email, ORCID and GitHub user name (in that order).

    Returns:
        An Author object
    """
    auth_order = ['name', 'affiliation', 'email', 'orcid', 'github_user']
    auth_dict = {}

    auth_info_split = author_input_info.split(',')

    for i in range(len(auth_info_split)):
        if auth_info_split[i].strip() == 'None' or auth_info_split[i].strip() == '':
            continue
        else:
            auth_dict[auth_order[i]] = auth_info_split[i].strip()
        print(auth_dict)

    return bioimageio_spec.Author(**auth_dict) 

def make_maintainer(maintainer_input_info: str):
    """
    Create a Maintainer object from a string input.

    Args:
        maintainer_input_info: A comma-separated string containing the maintainer's
            GitHub user name, name, affiliation, email and ORCID (in that order).

    Returns:
        A Maintainer object
    """
    maint_order = [ 'github_user', 'name', 'affiliation', 'email', 'orcid']
    maint_dict = {}

    maint_info_split = maintainer_input_info.split(',')

    for i in range(len(maint_info_split)):
        if maint_info_split[i].strip() == 'None' or maint_info_split[i].strip() == '':
            continue
        else:
            maint_dict[maint_order[i]] = maint_info_split[i].strip()
        print(maint_dict)

    return bioimageio_spec.Maintainer(**maint_dict) 


# -------------- Other definitions -----------
W  = '\033[0m'  # white (normal)
R  = '\033[31m' # red
prediction_prefix = 'Predicted_'


print('-------------------')
print('U-Net and dependencies installed.')

# Colors for the warning messages
class bcolors:
  WARNING = '\033[31m'

# Check if this is the latest version of the notebook

All_notebook_versions = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv", dtype=str)
print('Notebook version: '+Notebook_version)
Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Network]['Version'].iloc[0]
print('Latest notebook version: '+Latest_Notebook_version)
if Notebook_version == Latest_Notebook_version:
  print("This notebook is up-to-date.")
else:
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki")


def pdf_export(trained = False, augmentation = False, pretrained_model = False):
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Training report for '+Network+' model ('+model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')
  pdf.ln(1)

  # add another cell
  if trained:
    training_time = "Training time: "+str(hour)+" hour(s) "+str(mins)+" min(s) "+str(round(sec))+" sec(s)"
    pdf.cell(190, 5, txt = training_time, ln = 1, align='L')
  pdf.ln(1)

  Header_2 = 'Information for your materials and method:'
  pdf.cell(190, 5, txt=Header_2, ln=1, align='L')
  pdf.ln(1)

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages+requirement+', '
  #print(all_packages)

  #Main Packages
  main_packages = ''
  version_numbers = []
  for name in ['tensorflow','numpy','keras']:
    find_name=all_packages.find(name)
    main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '
    #Version numbers only here:
    version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])

  try:
    cuda_version = subprocess.run(["nvcc","--version"],stdout=subprocess.PIPE)
    cuda_version = cuda_version.stdout.decode('utf-8')
    cuda_version = cuda_version[cuda_version.find(', V')+3:-1]
  except:
    cuda_version = ' - No cuda found - '
  try:
    gpu_name = subprocess.run(["nvidia-smi"],stdout=subprocess.PIPE)
    gpu_name = gpu_name.stdout.decode('utf-8')
    gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]
  except:
    gpu_name = ' - No GPU found - '
  #print(cuda_version[cuda_version.find(', V')+3:-1])
  #print(gpu_name)
  loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]
  Training_source_list = [f for f in os.listdir(Training_source) if not f.startswith(".")]
  shape = io.imread(Training_source+'/' + Training_source_list[0]).shape
  dataset_size = len(Training_source_list)

  text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'

  if pretrained_model:
    text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'

  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.multi_cell(180, 5, txt = text, align='L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(28, 5, txt='Augmentation: ', ln=1)
  pdf.set_font('')
  if augmentation:
    aug_text = 'The dataset was augmented by'
    if rotation_range != 0:
      aug_text = aug_text+'\n- rotation'
    if horizontal_flip == True or vertical_flip == True:
      aug_text = aug_text+'\n- flipping'
    if zoom_range != 0:
      aug_text = aug_text+'\n- random zoom magnification'
    if horizontal_shift != 0 or vertical_shift != 0:
      aug_text = aug_text+'\n- shifting'
    if shear_range != 0:
      aug_text = aug_text+'\n- image shearing'
  else:
    aug_text = 'No augmentation was used for training.'
  pdf.multi_cell(190, 5, txt=aug_text, align='L')
  pdf.ln(1)
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  if Use_Default_Advanced_Parameters:
    pdf.cell(200, 5, txt='Default Advanced Parameters were enabled', ln=1)
  pdf.cell(200, 5, txt='The following parameters were used for training:')
  pdf.ln(1)
  html = """
  <table width=40% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>patch_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>batch_size</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>number_of_steps</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>percentage_validation</td>
      <td width = 50%>{4}</td>
    </tr>
    <tr>
      <td width = 50%>initial_learning_rate</td>
      <td width = 50%>{5}</td>
    </tr>
    <tr>
      <td width = 50%>pooling_steps</td>
      <td width = 50%>{6}</td>
    </tr>
    <tr>
      <td width = 50%>min_fraction</td>
      <td width = 50%>{7}</td>
    </tr>
  </table>
  """.format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)
  pdf.write_html(html)

  #pdf.multi_cell(190, 5, txt = text_2, align='L')
  pdf.set_font("Arial", size = 11, style='B')
  pdf.ln(1)
  pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_source, align = 'L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_target, align = 'L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')
  pdf.ln(1)
  pdf.cell(60, 5, txt = 'Example Training pair', ln=1)
  pdf.ln(1)
  exp_size = io.imread(base_path + '/TrainingDataExample_Unet2D.png').shape
  pdf.image(base_path + '/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  # if Use_Data_augmentation:
  #   ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. "Augmentor: an image augmentation library for machine learning." arXiv preprint arXiv:1708.04680 (2017).'
  #   pdf.multi_cell(190, 5, txt = ref_3, align='L')
  pdf.ln(3)
  reminder = 'Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo'
  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')

  print('------------------------------')
  print('PDF report exported in '+model_path+'/'+model_name+'/')

def qc_pdf_export():
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  Network = 'Unet 2D'

  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')
  pdf.ln(1)

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages+requirement+', '

  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(2)
  pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')
  pdf.ln(1)
  exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape
  if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):
    pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))
  else:
    pdf.set_font('')
    pdf.set_font('Arial', size=10)
    pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')
  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)
  pdf.ln(1)
  exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape
  pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)

  pdf.ln(1)
  html = """
  <body>
  <font size="10" face="Courier" >
  <table width=60% style="margin-left:0px;">"""
  with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:
    metrics = csv.reader(csvfile)
    header = next(metrics)
    image = header[0]
    IoU = header[-1]
    header = """
    <tr>
    <th width = 33% align="center">{0}</th>
    <th width = 33% align="center">{1}</th>
    </tr>""".format(image,IoU)
    html = html+header
    i=0
    for row in metrics:
      i+=1
      image = row[0]
      IoU = row[-1]
      cells = """
        <tr>
          <td width = 33% align="center">{0}</td>
          <td width = 33% align="center">{1}</td>
        </tr>""".format(image,str(round(float(IoU),3)))
      html = html+cells
    html = html+"""</table></body>"""

  pdf.write_html(html)

  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(3)
  reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'

  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')

  print('------------------------------')
  print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')

# Build requirements file for local run
after = [str(m) for m in sys.modules]
build_requirements_file(before, after)
!pip3 freeze > requirements.txt

# **2. Complete the Colab session**
---

## **2.1. Check for GPU access**
---

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access

if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime setting is correct, then Google did not allocate a GPU for your session.')
  print('Expect slow performance. To access a GPU, try reconnecting later.')

else:
  print('You have GPU access')
  !nvidia-smi

# from tensorflow.python.client import device_lib
# device_lib.list_local_devices()

# print the tensorflow version
print('Tensorflow version is ' + str(tf.__version__))


## **2.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive.

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL.

#@markdown * Sign in to your Google Account.

#@markdown * Copy the authorization code.

#@markdown * Enter the authorization code.

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available here as "drive".

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

**<font size = 4> If you cannot see your files, reactivate your session by connecting to your hosted runtime.**


<img width="40%" alt ="Connect to a hosted runtime." src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/connect_to_hosted.png"><figcaption> Connect to a hosted runtime. </figcaption>

# **3. Select your parameters and paths**
---

## **3.1. Setting main training parameters**
---
<font size = 4>

<font size = 5> **Paths for training data and models**

<font size = 4>**`Training_source`, `Training_target`:** These are the folders containing your source images (e.g. EM images) and target files (semantic segmentation masks). Each mask should be a single 2D image whose pixel values 0, 1, 2, ... each correspond to one semantic class of the content in the image. The label values should be consecutive integers starting at 0, with no gaps (a given label may still be absent from an individual image). Enter the paths to the source and target images for training. **These should be located in the same parent folder.**

<font size = 4>**`model_name`:** Use underscores rather than hyphens (e.g. `my_model`, not `my-model`). If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model folder after training).

<font size = 4>**`model_path`**: Enter the path of the folder where you want to save your model.

<font size = 4>**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**

<font size = 4>**`labels`**: The number of different labels that the network needs to learn, including the background. For example, to segment two different kinds of objects in an image (cats and dogs), set labels = 3 (one label for each kind of object and one for the background).
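Before training, it can help to confirm that the masks actually use the consecutive values 0, 1, ..., `labels`-1 described above. The sketch below does this on in-memory NumPy arrays; `check_label_values` is a hypothetical helper, not part of this notebook, and reading the masks from `Training_target` is left out.

```python
import numpy as np

def check_label_values(masks, labels):
    """Return the sorted label values found in `masks` and whether they
    match the expected consecutive range 0..labels-1 (hypothetical helper)."""
    found = set()
    for mask in masks:
        # Collect every distinct pixel value present in this mask
        found.update(np.unique(mask).tolist())
    found = sorted(int(v) for v in found)
    expected = list(range(labels))
    return found, found == expected

# Two toy 4x4 masks: 0 = background, 1 and 2 = two object classes
mask_a = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [0, 0, 0, 0],
                   [2, 2, 0, 0]], dtype=np.uint8)
mask_b = np.zeros((4, 4), dtype=np.uint8)  # labels 1 and 2 absent here, which is allowed

values, ok = check_label_values([mask_a, mask_b], labels=3)
print(values, ok)  # [0, 1, 2] True
```

Checking across all masks together, rather than per image, matches the rule that a label may be missing from an individual image as long as the overall set of labels has no gaps.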


<font size = 5> **Select training parameters**

<font size = 4>**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**

<font size = 5>**Advanced parameters - experienced users only**

<font size = 4>**`batch_size`**: This parameter sets the number of images that are loaded into the network per step. Smaller batch sizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset, the batch size may be too large; decrease it in this case. **Default: 4**

<font size = 4>**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.
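The default value of `number_of_steps` follows the same formula as the parameter cell of this notebook: the training split, (100 - `percentage_validation`)% of the valid patches, divided by `batch_size` and rounded up. A minimal sketch with illustrative function names, which also reproduces the notebook's validation-step count (never less than 1):

```python
from math import ceil

def default_steps(n_patches, batch_size, percentage_validation):
    """Training steps per epoch so that one epoch covers the training split once."""
    n_train = (100 - percentage_validation) / 100 * n_patches
    return ceil(n_train / batch_size)

def validation_steps(n_patches, batch_size, percentage_validation):
    """Validation steps per epoch (at least 1)."""
    return max(1, ceil(percentage_validation / 100 * n_patches / batch_size))

# e.g. 500 valid patches, batch size 4, 10% held out for validation
print(default_steps(500, 4, 10))     # 450 training patches / 4 = 112.5 -> 113
print(validation_steps(500, 4, 10))  # 50 validation patches / 4 = 12.5 -> 13
```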

<font size = 4> **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step also adds two convolutions. With more pooling steps the network can learn more complex information, but it is also more likely to overfit. Achieving the best performance may require testing different values here. **Default: 2**
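To see how the number of pooling steps interacts with the patch size, the sketch below assumes that each pooling step is a 2x2 max-pooling that halves the width and height, as in the original U-Net of Ronneberger *et al.*; the helper names are hypothetical. Under that assumption, patch sides that are multiples of 2^`pooling_steps` let the upsampled feature maps line up exactly with the skip connections.

```python
def encoder_sizes(patch_width, patch_height, pooling_steps):
    """Spatial size of the feature maps at each encoder level, assuming
    every pooling step is a 2x2 max-pooling that halves width and height."""
    sizes = [(patch_width, patch_height)]
    w, h = patch_width, patch_height
    for _ in range(pooling_steps):
        w, h = w // 2, h // 2
        sizes.append((w, h))
    return sizes

def patch_fits_pooling(patch_width, patch_height, pooling_steps):
    """True if both patch sides are divisible by 2**pooling_steps, so that
    upsampled maps match the skip connections exactly (assumed requirement)."""
    factor = 2 ** pooling_steps
    return patch_width % factor == 0 and patch_height % factor == 0

print(encoder_sizes(320, 320, 2))       # [(320, 320), (160, 160), (80, 80)]
print(patch_fits_pooling(320, 320, 4))  # True: 320 = 16 * 20
print(patch_fits_pooling(100, 100, 3))  # False: 100 is not a multiple of 8
```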

<font size = 4>**`percentage_validation`:**  Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10**

<font size = 4>**`initial_learning_rate`:**  Input the initial value to be used as learning rate. **Default value: 0.0003**

<font size = 4>**`patch_width` and `patch_height`:** The notebook crops the data into patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest image of the dataset is chosen. Patches larger than 512x512 should **NOT** be selected, for network stability.
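The default patch-size rule can be sketched as follows. `largest_pow2_patch` is a hypothetical stand-in for the notebook's own `estimatePatchSize`, whose implementation is not shown here; the 512 cap is an assumption taken from the advice above and may not be applied by the notebook itself.

```python
def largest_pow2_patch(image_shapes, max_size=512):
    """Largest square 2**n patch fitting in the smallest image dimension.

    `max_size` caps the result (assumed, following the 512x512 advice)."""
    smallest = min(min(shape[0], shape[1]) for shape in image_shapes)
    if smallest < 1:
        raise ValueError("image dimensions must be positive")
    size = 1
    # Double the patch side while the next power of two still fits
    while size * 2 <= smallest and size * 2 <= max_size:
        size *= 2
    return size, size

shapes = [(1024, 768), (600, 900), (2048, 2048)]
print(largest_pow2_patch(shapes))        # smallest side is 600 -> (512, 512)
print(largest_pow2_patch([(300, 400)]))  # smallest side is 300 -> (256, 256)
```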

<font size = 4>**`min_fraction`:** Minimum fraction of foreground pixels a patch must contain to be considered valid. It should be between 0 and 1. **Default value: 0.02** (2%)
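A minimal sketch of how `min_fraction` can be used to filter patches, assuming that foreground means any non-zero label and that patches are cut on a non-overlapping grid; the notebook's own `create_patches` function may tile and filter differently, and `valid_patches` is a hypothetical name.

```python
import numpy as np

def valid_patches(image, mask, patch_width, patch_height, min_fraction):
    """Yield (image_patch, mask_patch) pairs from a non-overlapping grid whose
    mask has at least `min_fraction` foreground (non-zero) pixels."""
    rows, cols = mask.shape[:2]
    for y in range(0, rows - patch_height + 1, patch_height):
        for x in range(0, cols - patch_width + 1, patch_width):
            mask_patch = mask[y:y + patch_height, x:x + patch_width]
            # Fraction of pixels carrying any non-background label
            fraction = np.count_nonzero(mask_patch) / mask_patch.size
            if fraction >= min_fraction:
                yield image[y:y + patch_height, x:x + patch_width], mask_patch

image = np.random.default_rng(0).random((8, 8))
mask = np.zeros((8, 8), dtype=np.uint8)
mask[0:2, 0:2] = 1  # 4 labelled pixels in the top-left 4x4 patch (25%)
mask[5, 5] = 2      # 1 labelled pixel in the bottom-right 4x4 patch (6.25%)

kept = list(valid_patches(image, mask, 4, 4, min_fraction=0.1))
print(len(kept))  # 1: only the top-left patch reaches 10% foreground
```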



In [None]:
# ------------- Initial user input ------------
#@markdown ###Path to training images:
Training_source = '' #@param {type:"string"}
Training_target = '' #@param {type:"string"}

model_name = '' #@param {type:"string"}
model_path = '' #@param {type:"string"}

labels = 3 #@param {type:"number"}

#@markdown ###Training parameters:
#@markdown Number of epochs
number_of_epochs =  10#@param {type:"number"}

#@markdown ###Advanced parameters:
Use_Default_Advanced_Parameters = True #@param {type:"boolean"}

#@markdown ###If not, please input:
batch_size =  5#@param {type:"integer"}
number_of_steps =  0#@param {type:"number"}
pooling_steps = 3 #@param [1,2,3,4]{type:"raw"}
percentage_validation =  10#@param{type:"number"}
initial_learning_rate = 0.0001 #@param {type:"number"}

patch_width =  320#@param{type:"number"}
patch_height =  320#@param{type:"number"}
min_fraction = 0.05#@param{type:"number"}


# ------------- Initialising folder, variables and failsafes ------------
#  Create the folders where to save the model and the QC
full_model_path = os.path.join(model_path, model_name)
if os.path.exists(full_model_path):
  print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)

if (Use_Default_Advanced_Parameters):
  print("Default advanced parameters enabled")
  batch_size = 4
  pooling_steps = 2
  percentage_validation = 10
  initial_learning_rate = 0.0003
  patch_width, patch_height = estimatePatchSize(Training_source)
  min_fraction = 0.02


#The create_patches function will create the two folders below
# Patch_source = '/content/img_patches'
# Patch_target = '/content/mask_patches'
print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')

#Create patches
print('Creating patches...')
Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)

number_of_training_dataset = len(os.listdir(Patch_source))
print('Total number of valid patches: '+str(number_of_training_dataset))

if Use_Default_Advanced_Parameters or number_of_steps == 0:
  number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)
print('Number of steps: '+str(number_of_steps))

# Calculate the number of steps to use for validation
validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))
validatio_split = percentage_validation/100

# Here we disable the pre-trained model by default (in case the pre-trained model cell is not run)
Use_pretrained_model = False
# Here we disable data augmentation by default (in case the augmentation cell is not run)
Use_Data_augmentation = False
# Build the default dict for the ImageDataGenerator
data_gen_args = dict(width_shift_range = 0.,
                     height_shift_range = 0.,
                     rotation_range = 0., #90
                     zoom_range = 0.,
                     shear_range = 0.,
                     horizontal_flip = False,
                     vertical_flip = False,
                     validation_split = percentage_validation/100,
                     fill_mode = 'reflect')

# ------------- Display ------------

#if not os.path.exists('/content/img_patches/'):
random_choice = random.choice(os.listdir(Patch_source))
x = io.imread(os.path.join(Patch_source, random_choice))

#os.chdir(Training_target)
y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)

f=plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x, interpolation='nearest',cmap='gray')
plt.title('Training image patch')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(y, interpolation='nearest',cmap='gray')
plt.title('Training mask patch')
plt.axis('off');

plt.savefig(base_path + '/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)



## **3.2. Data augmentation**

---

<font size = 4> Data augmentation can improve training by artificially increasing the variability of the dataset. This is useful when the available dataset is small: without augmentation, a network can quickly memorise every example in the dataset (overfitting). Augmentation is not required for training, and if the dataset is large the values can be set to 0.

<font size = 4> The augmentation options below are to be used as follows:

* <font size = 4> **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%** (horizontal) and **20%** (vertical)
* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**
* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**
* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**
* **rotation_range**: range of allowed rotation angles in degrees; each image is rotated by a random angle between -*value* and +*value*, **default: 180**
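Two points from the list above can be illustrated in a few lines: how a zoom percentage maps to a scale range, and why every random transformation must be applied identically to an image and its mask. The sketch below is a simplified, NumPy-only illustration restricted to flips and 90-degree rotations (which move pixels without interpolation, so label values stay intact); it is not how the notebook augments data, which builds a Keras `ImageDataGenerator` from the options above. The function names are hypothetical.

```python
import numpy as np

def zoom_bounds(zoom_percent):
    """Convert a zoom percentage into the (min, max) scale factors described above."""
    z = zoom_percent / 100
    return (1 - z, 1 + z)

def paired_flip_rotate(image, mask, rng, horizontal_flip=True, vertical_flip=True):
    """Apply the same random flips and 90-degree rotation to an image and its mask.

    Flips and multiples of 90 degrees only move pixels, so integer label values
    in the mask are preserved exactly (no interpolation)."""
    if horizontal_flip and rng.random() < 0.5:
        image, mask = np.fliplr(image), np.fliplr(mask)
    if vertical_flip and rng.random() < 0.5:
        image, mask = np.flipud(image), np.flipud(mask)
    k = int(rng.integers(0, 4))  # number of 90-degree turns, shared by both arrays
    return np.rot90(image, k), np.rot90(mask, k)

print(zoom_bounds(10))  # (0.9, 1.1)

rng = np.random.default_rng(42)
image = np.arange(16, dtype=np.float32).reshape(4, 4)
mask = (image > 7).astype(np.uint8)  # toy two-label mask
aug_img, aug_mask = paired_flip_rotate(image, mask, rng)
# The mask still marks exactly the same pixels of the transformed image
print(np.array_equal(aug_mask, (aug_img > 7).astype(np.uint8)))  # True
```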

In [None]:
#@markdown ##**Augmentation options**

Use_Data_augmentation = True #@param {type:"boolean"}
Use_Default_Augmentation_Parameters = True #@param {type:"boolean"}

if Use_Data_augmentation:
  if Use_Default_Augmentation_Parameters:
    horizontal_shift =  10
    vertical_shift =  20
    zoom_range =  10
    shear_range =  10
    horizontal_flip = True
    vertical_flip = True
    rotation_range =  180
#@markdown ###If you are not using the default settings, please provide the values below:

#@markdown ###**Image shift, zoom, shear and flip (%)**
  else:
    horizontal_shift =  10 #@param {type:"slider", min:0, max:100, step:1}
    vertical_shift =  13 #@param {type:"slider", min:0, max:100, step:1}
    zoom_range =  10 #@param {type:"slider", min:0, max:100, step:1}
    shear_range =  14 #@param {type:"slider", min:0, max:100, step:1}
    horizontal_flip = True #@param {type:"boolean"}
    vertical_flip = True #@param {type:"boolean"}

#@markdown ###**Rotate image within angle range (degrees):**
    rotation_range =  180 #@param {type:"slider", min:0, max:180, step:1}

#given behind the # are the default values for each parameter.

else:
  horizontal_shift =  0
  vertical_shift =  0
  zoom_range =  0
  shear_range =  0
  horizontal_flip = False
  vertical_flip = False
  rotation_range =  0


# Build the dict for the ImageDataGenerator
data_gen_args = dict(width_shift_range = horizontal_shift/100.,
                     height_shift_range = vertical_shift/100.,
                     rotation_range = rotation_range, #90
                     zoom_range = zoom_range/100.,
                     shear_range = shear_range/100.,
                     horizontal_flip = horizontal_flip,
                     vertical_flip = vertical_flip,
                     validation_split = percentage_validation/100,
                     fill_mode = 'reflect')



# ------------- Display ------------
dir_augmented_data_imgs = base_path + "/augment_img"
dir_augmented_data_masks = base_path + "/augment_mask"
random_choice = random.choice(os.listdir(Patch_source))
orig_img = load_img(os.path.join(Patch_source,random_choice))
orig_mask = load_img(os.path.join(Patch_target,random_choice))

augment_view = ImageDataGenerator(**data_gen_args)

if Use_Data_augmentation:
  print("Parameters enabled")
  print("Here is what a subset of your augmentations looks like:")
  save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)
  save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)

  fig = plt.figure(figsize=(15, 7))
  fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)


  ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[])
  new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))
  ax.imshow(new_img)
  ax.set_title('Original Image')
  i = 2
  for imgnm in os.listdir(dir_augmented_data_imgs):
    ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[])
    img = load_img(dir_augmented_data_imgs + "/" + imgnm)
    ax.imshow(img)
    i += 1

  ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[])
  new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))
  ax.imshow(new_mask)
  ax.set_title('Original Mask')
  j=2
  for imgnm in os.listdir(dir_augmented_data_masks):
    ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[])
    mask = load_img(dir_augmented_data_masks + "/" + imgnm)
    ax.imshow(mask)
    j += 1
  plt.show()

else:
  print("No augmentation will be used")




## **3.3. Using weights from a pre-trained model as initial weights**
---
<font size = 4>  Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**.

<font size = 4> **You do not need to run this section if you want to train a network from scratch**.

<font size = 4> This option allows you to use pre-trained models from the [BioImage Model Zoo](https://bioimage.io/#/) and fine-tune them to analyse new data. Choose `BioImage Model Zoo` in `pretrained_model_choice` and provide the ID in `bioimageio_model_id` (e.g., "creative-panda" or "10.5281/zenodo.5817052").

<font size = 4> This option also allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. Choose `Model_from_file` and provide the `pretrained_model_path`.

<font size = 4> In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used.
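The cell below picks either the last learning rate or the one recorded at the lowest validation loss from `training_evaluation.csv`. The same selection can be sketched with the standard library; the column names `learning rate` and `val_loss` are the ones this notebook reads, while the other columns and all values in the toy log are made up for illustration, and `learning_rates` is a hypothetical helper.

```python
import csv
import io

def learning_rates(csv_text):
    """Return (last_lr, best_lr) from a training log with 'learning rate' and
    'val_loss' columns; best_lr is the rate at the lowest validation loss
    (the last such row if the minimum is tied)."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    last_lr = float(rows[-1]["learning rate"])
    min_loss = min(float(r["val_loss"]) for r in rows)
    best_rows = [r for r in rows if float(r["val_loss"]) == min_loss]
    best_lr = float(best_rows[-1]["learning rate"])
    return last_lr, best_lr

# Toy log: validation loss bottoms out at epoch 2, then the rate keeps decaying
log = """epoch,loss,val_loss,learning rate
0,0.90,0.80,0.0003
1,0.60,0.55,0.0003
2,0.45,0.41,0.00015
3,0.40,0.43,0.000075
"""
print(learning_rates(log))  # (7.5e-05, 0.00015)
```

Taking the last row among tied minima mirrors the `.iloc[-1]` used on the filtered rows in the cell below.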

In [None]:
# @markdown ##Loading weights from a pre-trained network

Use_pretrained_model = False #@param {type:"boolean"}
pretrained_model_choice = "BioImage Model Zoo" #@param ["Model_from_file", "BioImage Model Zoo"]
Weights_choice = "best" #@param ["last", "best"]


#@markdown ###If you chose "Model_from_file", please provide the path to the model folder:
pretrained_model_path = "" #@param {type:"string"}

#@markdown ###If you chose "BioImage Model Zoo", please provide the path or the URL to the model:
bioimageio_model_id = "" #@param {type:"string"}

# --------------------- Check if we load a previously trained model ------------------------
if Use_pretrained_model:

# --------------------- Load the model from the chosen path ------------------------
  if pretrained_model_choice == "Model_from_file":
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")
    qc_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')
  elif pretrained_model_choice == "BioImage Model Zoo":

    biomodel = load_resource_description(bioimageio_model_id)
    if "keras_hdf5" not in biomodel.weights:
      print("Invalid bioimageio model")
      h5_file_path = "no-model"
      qc_path = "no-qc"
    else:
      h5_file_path = str(biomodel.weights["keras_hdf5"].source)
      try:
        attachments = biomodel.attachments.files
        qc_path = [fname for fname in attachments if fname.endswith("training_evaluation.csv")][0]
        qc_path = os.path.join(base_path + "//bioimageio_pretrained_model", qc_path)
      except Exception:
        qc_path = "no-qc"

# --------------------- Check the model exist ------------------------
# If the chosen model path does not contain a pretrained model, Use_pretrained_model is disabled
  if not os.path.exists(h5_file_path):
    print(R+'WARNING: pretrained model does not exist')
    Use_pretrained_model = False


# If the model path contains a pretrained model, we load the learning rate
  if os.path.exists(h5_file_path):
#Here we check if the learning rate can be loaded from the quality control folder
    # if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):
    if os.path.exists(qc_path):

      # with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:
      with open(qc_path,'r') as csvfile:
        csvRead = pd.read_csv(csvfile, sep=',')
        #print(csvRead)

        if "learning rate" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)
          print("pretrained network learning rate found")
          #find the last learning rate
          lastLearningRate = csvRead["learning rate"].iloc[-1]
          #Find the learning rate corresponding to the lowest validation loss
          min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]
          bestLearningRate = min_val_loss['learning rate'].iloc[-1]

          if Weights_choice == "last":
            print('Last learning rate: '+str(lastLearningRate))

          if Weights_choice == "best":
            print('Learning rate of best validation loss: '+str(bestLearningRate))

        if "learning rate" not in csvRead.columns: # If the column does not exist, the initial learning rate is used instead
          bestLearningRate = initial_learning_rate
          lastLearningRate = initial_learning_rate
          print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)

# Compatibility with models trained outside ZeroCostDL4Mic: the default learning rate is used
    if not os.path.exists(qc_path):
      print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)
      bestLearningRate = initial_learning_rate
      lastLearningRate = initial_learning_rate


# Display info about the pretrained model to be loaded (or not)
if Use_pretrained_model:
  print('Weights found in:')
  print(h5_file_path)
  print('will be loaded prior to training.')

else:
  print(R+'No pretrained network will be used.')
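The learning-rate recovery above boils down to reading two values from `training_evaluation.csv`: the last logged learning rate, and the learning rate of the epoch with the lowest validation loss. The minimal sketch below reproduces that logic as a standalone function. The function name `recover_learning_rates` and the CSV values are made up for illustration; only the column names (`loss`, `val_loss`, `learning rate`) follow what the training cell in Section 4.2 writes.

```python
from io import StringIO  # imported by name so that skimage's `io` is not shadowed

import pandas as pd

def recover_learning_rates(csv_file, default_lr):
    """Return (last_lr, best_lr) from a loss log, or the default twice if the column is missing."""
    df = pd.read_csv(csv_file)
    if "learning rate" not in df.columns:
        return default_lr, default_lr
    last_lr = float(df["learning rate"].iloc[-1])
    # Rows with the lowest validation loss; like the notebook, keep the last such row
    best_rows = df[df["val_loss"] == df["val_loss"].min()]
    best_lr = float(best_rows["learning rate"].iloc[-1])
    return last_lr, best_lr

# Made-up log (hypothetical values): the best validation loss is at the second epoch
csv_text = """loss,val_loss,learning rate
0.90,0.80,0.001
0.60,0.50,0.001
0.40,0.52,0.0005
0.30,0.53,0.00025
"""
last_lr, best_lr = recover_learning_rates(StringIO(csv_text), default_lr=0.0003)
print(last_lr, best_lr)
```

With `Weights_choice == "best"` the second value would be used, with `"last"` the first, which is why the two can differ once the learning rate has been reduced during training.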




# **4. Train the network**
---
#### **Troubleshooting:** If you receive a time-out or resource-exhausted (out-of-memory) error, try reducing the batch size of your training set. This reduces the amount of data loaded into the model at once.

## **4.1. Prepare the training data and model for training**
---
<font size = 4>Here, we use the information from Section 3 to build the model and convert the training data into a suitable format for training.

In [None]:
#@markdown ##Play this cell to prepare the model for training


# ------------------ Set the generators, model and logger ------------------
# This takes the image size and sets it as the patch size (debatable...)
# Read the image size (without actually reading the data)

(train_datagen, validation_datagen) = prepareGenerators(Patch_source,
                                                        Patch_target,
                                                        data_gen_args,
                                                        batch_size,
                                                        target_size = (patch_width, patch_height),
                                                        validatio_split = validatio_split)


# This modelcheckpoint will only save the best model from the validation loss point of view
model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'),
                                   monitor='val_loss',verbose=1, save_best_only=True)

# --------------------- Using pretrained model ------------------------
#Here we ensure that the learning rate is set correctly when using pre-trained models
if Use_pretrained_model:
  if Weights_choice == "last":
    initial_learning_rate = lastLearningRate

  if Weights_choice == "best":
    initial_learning_rate = bestLearningRate
else:
  h5_file_path = None

# --------------------- ---------------------- ------------------------

# --------------------- Reduce learning rate on plateau ------------------------

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, verbose=1,
                              mode='auto', patience=20, min_lr=0)
# --------------------- ---------------------- ------------------------

# Define the model
model = unet(input_size = (patch_width,patch_height,1),
            pooling_steps = pooling_steps,
            learning_rate = initial_learning_rate,
            labels = labels)

# --------------------- Using pretrained model ------------------------
# Load the pretrained weights
if Use_pretrained_model:
  try:
      model.load_weights(h5_file_path)
  except Exception:
      print(bcolors.WARNING + "The pretrained model could not be loaded as the configuration of the network is different.")
      print("Please, read the model specifications and check the parameters in Section 3.1" + W)

config_model= model.optimizer.get_config()
print(config_model)


# ------------------ Failsafes ------------------
if os.path.exists(full_model_path):
  print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)
  shutil.rmtree(full_model_path)

os.makedirs(full_model_path)
os.makedirs(os.path.join(full_model_path,'Quality Control'))


# ------------------ Display ------------------
print('---------------------------- Main training parameters ----------------------------')
print('Number of epochs: '+str(number_of_epochs))
print('Batch size: '+str(batch_size))
print('Number of training dataset: '+str(number_of_training_dataset))
print('Number of training steps: '+str(number_of_steps))
print('Number of validation steps: '+str(validation_steps))
print('---------------------------- ------------------------ ----------------------------')

pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)
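The `reduce_lr` callback defined above halves the learning rate (`factor=0.5`) whenever `val_loss` has not improved for `patience=20` epochs. To make that behaviour concrete, the sketch below re-implements a simplified version in plain Python. It is not the Keras source: it assumes Keras' default `min_delta=1e-4` and `cooldown=0`, the function name `simulate_reduce_lr_on_plateau` is hypothetical, and the demo uses a shorter `patience=3` so the effect is visible on a short loss history.

```python
def simulate_reduce_lr_on_plateau(val_losses, initial_lr, factor=0.5, patience=20,
                                  min_lr=0.0, min_delta=1e-4):
    """Simplified plateau schedule for a metric to minimise (no cooldown)."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    lrs = []
    for loss in val_losses:
        lrs.append(lr)  # learning rate in effect during this epoch
        if loss < best - min_delta:
            # Improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Plateau: reduce the rate, never going below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
    return lrs

losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]
lrs = simulate_reduce_lr_on_plateau(losses, initial_lr=0.001, patience=3)
print(lrs)
# [0.001, 0.001, 0.001, 0.001, 0.001, 0.0005, 0.0005, 0.0005]
```

With `min_lr=0` as in the notebook, the rate can keep halving for as long as the validation loss stays flat.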


## **4.2. Start Training**
---
<font size = 4>When playing the cell below you should see updates after each epoch (round). Network training can take some time.

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.
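To stay under the 12-hour limit mentioned above, one option is to time a single epoch and derive how many epochs fit. The helper below is a hypothetical back-of-the-envelope calculation, not part of the notebook; the 80 % safety margin (`safety_fraction`) is an arbitrary assumption meant to leave room for data loading, QC and saving.

```python
import math

def max_epochs_within_budget(seconds_per_epoch, budget_hours=12.0, safety_fraction=0.8):
    """Largest whole number of epochs expected to fit in safety_fraction of the time budget."""
    usable_seconds = budget_hours * 3600 * safety_fraction
    return math.floor(usable_seconds / seconds_per_epoch)

# e.g. if the first epoch took about 2 minutes
print(max_epochs_within_budget(120))  # 288
```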

In [None]:
#@markdown ##Start training

start = time.time()



random_choice = random.choice(os.listdir(Patch_source))
x = io.imread(os.path.join(Patch_source, random_choice))
sample_batch = np.expand_dims(normalizePercentile(x), axis = [0, -1])
sample_img = SampleImageCallback(model, sample_batch, os.path.join(model_path, model_name))

history = model.fit(train_datagen, steps_per_epoch = number_of_steps,
                    epochs = number_of_epochs,
                    callbacks=[model_checkpoint, reduce_lr, sample_img],
                    validation_data = validation_datagen,
                    validation_steps = validation_steps, verbose=1)

# Save the last model
model.save(os.path.join(full_model_path, 'weights_last.hdf5'))


# convert the history.history dict to a pandas DataFrame:
lossData = pd.DataFrame(history.history)

# The training_evaluation.csv is saved (overwritten if it already exists).
lossDataCSVpath = os.path.join(full_model_path, 'Quality Control', 'training_evaluation.csv')
with open(lossDataCSVpath, 'w') as f:
  writer = csv.writer(f)
  writer.writerow(['loss','val_loss', 'learning rate'])
  for i in range(len(history.history['loss'])):
    writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])



# Displaying the time elapsed for training
print("------------------------------------------")
dt = time.time() - start
mins, sec = divmod(dt, 60)
hour, mins = divmod(mins, 60)
print("Time elapsed:", int(hour), "hour(s)", int(mins), "min(s)", round(sec), "sec(s)")
print("------------------------------------------")

#Create a pdf document with training summary

pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)
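The training cell builds `sample_batch` for `SampleImageCallback` by normalising one image and adding a batch axis in front and a channel axis at the end, giving the `(1, height, width, 1)` layout the 2D U-Net expects. The sketch below shows that reshaping on a synthetic image. `normalizePercentile` is defined earlier in the notebook and is not shown here, so `normalize_percentile_sketch` is a hypothetical stand-in whose percentiles (1 and 99.8) are assumptions.

```python
import numpy as np

def normalize_percentile_sketch(x, pmin=1.0, pmax=99.8, eps=1e-20):
    """Hypothetical stand-in for normalizePercentile: map the pmin/pmax percentiles to 0/1."""
    lo = np.percentile(x, pmin)
    hi = np.percentile(x, pmax)
    return (x.astype(np.float32) - lo) / (hi - lo + eps)

# Synthetic 2D grayscale image (height 64, width 48)
image = np.arange(64 * 48, dtype=np.uint16).reshape(64, 48)

# Add a batch axis in front and a channel axis at the end, as the training cell does
sample_batch = np.expand_dims(normalize_percentile_sketch(image), axis=(0, -1))
print(sample_batch.shape)  # (1, 64, 48, 1)
```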


# **5. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model.

<font size = 4>**We highly recommend performing quality control on all newly trained models.**



In [None]:
#@markdown ###Do you want to assess the model you just trained?

Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model folder:

QC_model_folder = "" #@param {type:"string"}

#Here we define the loaded model name and path
QC_model_name = os.path.basename(QC_model_folder)
QC_model_path = os.path.dirname(QC_model_folder)


if (Use_the_current_trained_model):
  print("Using current trained network")
  QC_model_name = model_name
  QC_model_path = model_path
else:
  # These are used in section 6
  model_name = QC_model_name
  model_path = QC_model_path

full_QC_model_path = os.path.join(QC_model_path, QC_model_name)
if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):
  print("The "+QC_model_name+" network will be evaluated")
else:
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')



## **5.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>**Validation loss** describes the same error value between the model's prediction on a validation image and its target.

<font size = 4>During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.

<font size = 4>If both the **Training loss** and the **Validation loss** are still decreasing, further training is likely beneficial and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side just because of the y-axis scaling; the log-scale plot helps here. The network has reached convergence once the curves flatten out; after this point, no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, the network is overfitting to the training data. In other words, the network is memorizing the exact patterns from the training data and no longer generalizes well to unseen data. In this case, the training dataset should be enlarged.

In [None]:
#@markdown ##Play the cell to show a plot of training errors vs. epoch number

epochNumber = []
lossDataFromCSV = []
vallossDataFromCSV = []

with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    next(csvRead)
    for row in csvRead:
        lossDataFromCSV.append(float(row[0]))
        vallossDataFromCSV.append(float(row[1]))

epochNumber = range(len(lossDataFromCSV))

plt.figure(figsize=(15,10))

plt.subplot(2,1,1)
plt.plot(epochNumber,lossDataFromCSV, label='Training loss')
plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (linear scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()

plt.subplot(2,1,2)
plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')
plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (log scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()
plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)
plt.show()
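The rules of thumb from Section 5.1 can be turned into a rough automatic check on the loss values read above. The heuristic below is an illustration, not part of the notebook: it reports the epoch with the lowest validation loss, how many epochs have passed since then, and flags possible overfitting when the validation loss has not improved for `window` epochs while the training loss kept falling. The function name `summarize_loss_curves` and the default `window=5` are arbitrary assumptions.

```python
def summarize_loss_curves(train_loss, val_loss, window=5):
    """Heuristic summary: (best_epoch, epochs_since_best, possible_overfitting)."""
    best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i])
    epochs_since_best = len(val_loss) - 1 - best_epoch
    # Did the training loss keep falling after the best validation epoch?
    train_kept_falling = train_loss[-1] < train_loss[best_epoch]
    possible_overfitting = epochs_since_best >= window and train_kept_falling
    return best_epoch, epochs_since_best, possible_overfitting

# Made-up curves: validation loss bottoms out at epoch 3, training loss keeps dropping
train = [0.9, 0.6, 0.4, 0.3, 0.2, 0.15, 0.1, 0.07, 0.05, 0.03]
val = [0.8, 0.55, 0.45, 0.42, 0.44, 0.47, 0.5, 0.55, 0.6, 0.66]
print(summarize_loss_curves(train, val))  # (3, 6, True)
```

In the notebook, it could be called as `summarize_loss_curves(lossDataFromCSV, vallossDataFromCSV)` after the cell above has run.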



## **5.2. Error mapping and quality metrics estimation**
---
<font size = 4>This section calculates the Intersection over Union (IoU) score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the images is also displayed.

<font size = 4>The **Intersection over Union** metric quantifies the overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** Here, it is computed separately for each label, so it can be used to assess how accurately your model segments each type of structure.

<font size = 4>The Input, Ground Truth, Prediction and Overlay (with the IoU score in its title) are shown below for the example selected in the drop-down menu.

<font size = 4> The results for all QC examples can be found in the "*Quality Control*" folder which is located inside your "model_folder".
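For a single label value $l$, the IoU is $|GT_l \cap P_l| / |GT_l \cup P_l|$, where $GT_l$ and $P_l$ are the pixels with value $l$ in the ground truth and in the prediction. The QC cell below computes this for each label and then averages over labels. The minimal sketch here shows the same computation on a tiny made-up example. One difference is a deliberate design choice: when a label is absent from both images the union is empty, so the notebook's division gives NaN with a runtime warning; this sketch returns NaN explicitly and averages with `np.nanmean`, which ignores such labels. The function name `per_label_iou` is hypothetical.

```python
import numpy as np

def per_label_iou(ground_truth, prediction, n_labels):
    """IoU for each label value 0..n_labels-1; NaN where a label is absent from both images."""
    ious = []
    for label in range(n_labels):
        gt = ground_truth == label
        pred = prediction == label
        intersection = np.logical_and(gt, pred).sum()
        union = np.logical_or(gt, pred).sum()
        ious.append(float(intersection / union) if union > 0 else float("nan"))
    return ious

gt = np.array([[0, 0, 1], [0, 2, 1], [2, 2, 1]])
pred = np.array([[0, 1, 1], [0, 2, 1], [2, 0, 1]])
ious = per_label_iou(gt, pred, n_labels=3)
mean_iou = float(np.nanmean(ious))
print([round(v, 3) for v in ious])  # [0.5, 0.75, 0.667]
print(round(mean_iou, 3))  # 0.639
```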

### **Per-label IoU scores**

<font size = 4> Since the output of this U-Net is a label map (0, 1, 2, ...) rather than a binary mask, no thresholding is applied. Instead, the IoU is computed separately for each label value and then averaged across labels. Both the per-label and the averaged scores are saved in the `QC_metrics_` CSV file in the "*Quality Control*" folder.

In [None]:
# ------------- User input ------------
#@markdown ##Choose the folders that contain your Quality Control dataset
Source_QC_folder = "" #@param{type:"string"}
Target_QC_folder = "" #@param{type:"string"}


# ------------- Initialise folders ------------
# Create a quality control/Prediction Folder
prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')
if os.path.exists(prediction_QC_folder):
  shutil.rmtree(prediction_QC_folder)

os.makedirs(prediction_QC_folder)


# ------------- Prepare the model and run predictions ------------

# Load the model
unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})
labels =  unet.output_shape[-1]
Input_size = unet.layers[0].output_shape[0][1:3]
print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))

# Create a list of sources
source_dir_list = os.listdir(Source_QC_folder)
number_of_dataset = len(source_dir_list)
print('Number of dataset found in the folder: '+str(number_of_dataset))

predictions = []
for i in tqdm(range(number_of_dataset)):
  predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))


# Save the predicted label masks in the QC prediction folder
saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix)

#-----------------------------Calculate Metrics----------------------------------------#

# Here we start testing the differences between GT and predicted masks

with open(os.path.join(full_QC_model_path, 'Quality Control', 'QC_metrics_'+QC_model_name+".csv"), "w", newline='') as file:
    writer = csv.writer(file, delimiter=",")
    stats_columns = ["image"]

    for l in range(labels):
        stats_columns.append("Prediction v. GT IoU label = {}".format(l))
    stats_columns.append("Prediction v. GT averaged IoU")
    writer.writerow(stats_columns)
    # Initialise the lists
    filename_list = []
    iou_score_list = []
    for filename in os.listdir(Source_QC_folder):
        if not os.path.isdir(os.path.join(Source_QC_folder, filename)):
            print('Running QC on: '+filename)
            test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)
            test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)
            test_prediction = io.imread(os.path.join(prediction_QC_folder, prediction_prefix + filename))

            iou_labels = [filename]
            iou_score = 0.
            for l in range(labels):
                aux_gt = (test_ground_truth_image==l).astype(np.uint8)
                aux_pred = (test_prediction==l).astype(np.uint8)
                intersection = np.logical_and(aux_gt, aux_pred)
                union = np.logical_or(aux_gt, aux_pred)
                label_iou = np.sum(intersection) / np.sum(union)

                iou_labels.append(str(label_iou))
                iou_score += label_iou
            filename_list.append(filename)
            iou_score_list.append(iou_score/labels)
            iou_labels.append(str(iou_score/labels))
            writer.writerow(iou_labels)

## Create a display of the results

# Table with metrics as dataframe output
pdResults = pd.DataFrame(index = filename_list)
pdResults["IoU"] = iou_score_list

# ------------- For display ------------
print('--------------------------------------------------------------')
@interact
def show_QC_results(file=os.listdir(Source_QC_folder)):

  plt.figure(figsize=(25,5))
  #Input
  plt.subplot(1,4,1)
  plt.axis('off')
  plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')
  plt.title('Input')

  #Ground-truth
  plt.subplot(1,4,2)
  plt.axis('off')
  test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)
  plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')
  plt.title('Ground Truth')

  #Prediction
  plt.subplot(1,4,3)
  plt.axis('off')
  test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))
  plt.imshow(test_prediction, aspect='equal', cmap='Purples')
  plt.title('Prediction')

  #Overlay
  plt.subplot(1,4,4)
  plt.axis('off')
  plt.imshow(test_ground_truth_image, cmap='Greens')
  plt.imshow(test_prediction, alpha=0.5, cmap='Purples')
  metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file]["IoU"],3)) + ')'
  plt.title(metrics_title)
  plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)

qc_pdf_export()

pdResults.head()

## **5.3. Export your model into the BioImage Model Zoo format**
---
<font size = 4>This section exports the model into the [BioImage Model Zoo](https://bioimage.io/#/) format so it can be used directly with deepImageJ or ilastik. The new files will be stored in the model folder specified at the beginning of Section 5.

<font size = 4>Once the cell is executed, you will find a new folder named `<Trained_model_name>.bioimage.io.model` inside the model folder, containing the exported zip file `<Trained_model_name>.zip`.

<font size = 4>To use it with deepImageJ, download it and install it using DeepImageJ Install Model > Install from a local file.

<font size = 4>To try the model in ImageJ, go to Plugins > DeepImageJ > DeepImageJ Run, choose this model from the list and click on Test Model.

<font size = 4>The exported model contains an additional ImageJ macro (`Contours2InstanceSegmentation.ijm`) to obtain a single 2D image with the different labels and to identify each individual object in the image.

<font size = 4>  More information at https://deepimagej.github.io/deepimagej/
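The `Contours2InstanceSegmentation.ijm` macro runs in ImageJ and is not reproduced here. As a rough Python analogue of the idea of identifying each independent object, the sketch below uses a different, simpler technique: connected-component labelling with `scipy.ndimage.label` on the pixels of one semantic label, so that each separate object gets its own integer ID. The label value 3 for nuclei follows the example labelling given at the top of this notebook; the function name `instances_for_label` is hypothetical.

```python
import numpy as np
from scipy import ndimage

def instances_for_label(label_map, label_value):
    """Give every connected region of one semantic label its own ID (0 = background)."""
    mask = label_map == label_value
    # Default structuring element: 4-connectivity in 2D
    instances, n_objects = ndimage.label(mask)
    return instances, n_objects

# Made-up semantic mask: 0 = background, 1 = cytoplasm, 3 = nuclei
semantic = np.array([
    [0, 3, 3, 0, 0],
    [0, 3, 3, 0, 3],
    [0, 0, 0, 0, 3],
    [1, 1, 0, 0, 0],
])
nuclei, count = instances_for_label(semantic, 3)
print(count)  # 2
print(nuclei)
```

Touching objects of the same label would merge into one instance with this approach, which is one reason a contour-based method such as the macro can be preferable.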

In [None]:
# ------------- User input ------------
# information about the model
#@markdown ##Insert the information to document your model:
Trained_model_name    = "" #@param {type:"string"}
Trained_model_description = "" #@param {type:"string"}

#@markdown ###Author(s) - insert information separated by commas:
Trained_model_author_1 =  "Author 1 name, *Author 1 affiliation, *Author 1 email, *Author 1 ORCID, *Author 1 Github User" #@param {type:"string"}
Trained_model_author_2 =  "Author 2 name, *Author 2 affiliation, *Author 2 email, *Author 2 ORCID, *Author 2 Github User" #@param {type:"string"}

# @markdown - ###Model Packager:
packager_same_as_author = True #@param {type:"boolean"}
#@markdown - If not, please, provide the following information:
Trained_model_packager = "Packager name, *Packager affiliation, *Packager email, *Packager ORCID, *Packager Github User" #@param {type:"string"}

# @markdown ###Model Maintainer:
maintainer_same_as_author = True #@param {type:"boolean"}
#@markdown - If not, please, provide the following information:
Trained_model_maintainer = "Maintainer Github User, *Maintainer name, *Maintainer affiliation, *Maintainer email, *Maintainer ORCID" #@param {type:"string"}

#@markdown - License:
Trained_model_license = 'CC-BY-4.0' #@param {type:"string"}

Trained_model_references = ["Falk et al. Nature Methods 2019", "Ronneberger et al. arXiv in 2015", "Lucas von Chamier et al. bioRxiv 2020"]
Trained_model_DOI = ["https://doi.org/10.1038/s41592-018-0261-2","https://doi.org/10.1007/978-3-319-24574-4_28", "https://doi.org/10.1101/2020.03.20.000133"]

# Training data
# ---------------------------------------
#@markdown ##Include information about training data (optional):
include_training_data = True #@param {type: "boolean"}
#@markdown ### - If it is published in the BioImage Model Zoo, please, provide the ID
data_from_bioimage_model_zoo = True #@param {type: "boolean"}
training_data_ID = ''#@param {type:"string"}
#@markdown ### - If not, please provide the URL to the data and a short description to be added to the README.md file
training_data_source = ''#@param {type:"string"}
training_data_description = ''#@param {type:"string"}

# Add input image information
# ---------------------------------------
#@markdown ##Indicate the minimum x/y size of the image (in pixels) and step size (in pixels) to be used for block/tiling:
# information about the example image
min_size = 64 #@param {type:"number"}
step_size = 16 #@param {type:"number"}
#@markdown ##Do you want to choose the example image?
default_example_image = True #@param {type:"boolean"}
#@markdown ###If not, please input:
fileID    =  "" #@param {type:"string"}
if default_example_image:
    fileID = os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0])

# Load the model and process the example image
# ---------------------------------------
# Load the model
model = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'),
                  custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})

# ------------- Execute bioimage model zoo configuration ------------
# Create a model without compilation so it can be used in any other environment.
# remove the custom loss function from the model, so that it can be used outside of this notebook
unet = Model(model.input, model.output)
weight_path = os.path.join(full_QC_model_path, 'keras_weights.hdf5')
unet.save(weight_path)


# create the author/maintainer/packager spec input
author_1_spec = make_author(Trained_model_author_1)
author_2_spec = make_author(Trained_model_author_2)

authors = [author_1_spec]
# check if author 2 was filled
if author_2_spec.name != 'Author 2 name' and author_2_spec.name != '':
    authors.append(author_2_spec)

if packager_same_as_author:
    packager_spec = author_1_spec
else:
    packager_spec = make_author(Trained_model_packager)

if maintainer_same_as_author:
    if author_1_spec.github_user is not None:
        maintainer_spec = author_1_spec
    else:
        raise ValueError('Please, provide the author GitHub username in the author information')
else:
    maintainer_spec = make_maintainer(Trained_model_maintainer)


# CC-BY-4.0 is the recommended licence
license = Trained_model_license

# where to save the model
output_root = os.path.join(full_QC_model_path, Trained_model_name + '.bioimage.io.model')
os.makedirs(output_root, exist_ok=True)
output_path = os.path.join(output_root, f"{Trained_model_name}.zip")

# create a markdown readme with information
documentation_path = os.path.join(output_root, "README.md")
with open(documentation_path, "w") as f:
  f.write("Visit https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki \n\n This was an automatically generated README.md. \n\n")

# create the citation input spec
assert len(Trained_model_DOI) == len(Trained_model_references)
citations = [{'text': text, 'doi': doi.replace('https://doi.org/', '')} for text, doi in zip(Trained_model_references, Trained_model_DOI)]
citation_spec = [bioimageio_spec.CiteEntry(**c) for c in citations]

# create the training data
if include_training_data:
    if data_from_bioimage_model_zoo:
      training_data = {"id": training_data_ID}
    else:
      with open(documentation_path, "a") as f:
        f.write(f'Training data: {training_data_source} \n\n and description: {training_data_description} \n\n')
      training_data = None
else:
    training_data = None


# load the input image, crop it if necessary, and save as numpy file
# The crop is centred to get an image with some content.
input_img = io.imread(fileID, as_gray = True).astype(np.float32)
assert input_img.ndim == 2,'Example input image is not a 2D grayscale image. Please, provide a 2D grayscale image.'

# The Keras model expects tensors shaped (batch, y, x, channel)
_, crop_y, crop_x, _ = unet.input_shape
if crop_y is None or crop_x is None:
    # No fixed spatial input size: keep the whole image
    crop_y, crop_x = input_img.shape
assert input_img.shape[0] >= crop_y and input_img.shape[1] >= crop_x, 'The example image is smaller than the model input size.'

y_start = (input_img.shape[0] - crop_y) // 2
x_start = (input_img.shape[1] - crop_x) // 2
test_img = input_img[y_start : y_start + crop_y,
                     x_start : x_start + crop_x]
assert test_img.shape == (crop_y, crop_x), f"{test_img.shape}, {(crop_y, crop_x)}"

# Save the test image with explicit batch and channel axes: (1, y, x, 1)
test_input = test_img[np.newaxis, :, :, np.newaxis]
test_input_path = os.path.join(output_root, "test_input.npy")
np.save(test_input_path, test_input)

# run prediction on the input image and save the result as expected output: (1, y, x, labels)
test_output = unet.predict(test_input)
test_output_path = os.path.join(output_root, "test_output.npy")
np.save(test_output_path, test_output)

# Drop the batch axis for the channel names and cover image below: (y, x, labels)
test_prediction = np.squeeze(test_output, axis=0)
assert test_prediction.ndim == 3

# create the channel names for the output
channel_names = []
shape_pred = test_prediction.shape
channel_pred_idx = shape_pred.index(min(shape_pred))
n_channels = shape_pred[channel_pred_idx]

for idx in range(n_channels):
   channel_names.append(f'channel{idx}')

# create the input tensor; the axes follow the Keras (channels-last) layout: batch, y, x, channel
input_tensor = bioimageio_spec.InputTensorDescr(id=bioimageio_spec.TensorId('input0'),
                                             description= 'This is the test input tensor created from the example image.',
                                             axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),
                                                   bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='y', description='', type='space', unit=None, scale=1.0, concatenable=False),
                                                   bioimageio_spec.SpaceInputAxis(size=bioimageio_spec.ParameterizedSize(min=min_size, step=step_size), id='x', description='', type='space', unit=None, scale=1.0, concatenable=False),
                                                   bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=['channel0'])],
                                             test_tensor = bioimageio_spec.FileDescr(source = test_input_path),
                                               )


# create the output tensor (same channels-last layout)
output_tensor = bioimageio_spec.OutputTensorDescr( axes=[bioimageio_spec.BatchAxis(id='batch', description='', type='batch', size=None),
                                                      bioimageio_spec.SpaceOutputAxis(size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='y', offset=0), id='y', description='', type='space', unit=None, scale=1.0),
                                                      bioimageio_spec.SpaceOutputAxis(size=bioimageio_spec.SizeReference(tensor_id=bioimageio_spec.TensorId('input0'), axis_id='x', offset=0), id='x', description='', type='space', unit=None, scale=1.0),
                                                      bioimageio_spec.ChannelAxis(id='channel', description='', type='channel', channel_names=channel_names)],
                                                      test_tensor = bioimageio_spec.FileDescr(source = test_output_path) )


# attach the QC report to the model (if it exists)
qc_path = os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')
attachment_files = []
if os.path.exists(qc_path):
  attachment_files.append(qc_path)

# Include a post-processing deepImageJ macro
macro = "Contours2InstanceSegmentation.ijm"
url = f"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/{macro}"
path = os.path.join(output_root, macro)
with requests.get(url) as r:
  r.raise_for_status()  # stop with an error if the download failed (e.g. 404)
  with open(path, "w") as f:
      f.write(r.text)
attachment_files.append(path)
attachments = {"files": attachment_files}

# make cover image
cover = test_img

for idx in range(n_channels):
    if channel_pred_idx == 0:
        cover = np.concatenate((cover, test_prediction[idx,:,:]), axis=1)
    elif channel_pred_idx == 1:
        cover = np.concatenate((cover, test_prediction[:,idx,:]), axis=1)
    elif channel_pred_idx == 2:
        cover = np.concatenate((cover, test_prediction[:,:,idx]), axis=1)

cover_path = os.path.join(output_root, "cover.png")
plt.imsave(cover_path, cover)

# make weights description
unet_tf_weights = bioimageio_spec.KerasHdf5WeightsDescr(source=weight_path, tensorflow_version=tf.__version__)
unet_weights = bioimageio_spec.WeightsDescr(keras_hdf5=unet_tf_weights)

# create model description for export
model_description = bioimageio_spec.ModelDescr(name=Trained_model_name,
                                                description=Trained_model_description,
                                                covers=[cover_path],
                                                authors=authors,
                                                attachments=attachments,
                                                cite=citation_spec,
                                                license=license,
                                                maintainers=[maintainer_spec],
                                                tags=['zerocostdl4mic', 'deepimagej', 'segmentation', 'unet'],
                                                documentation=documentation_path,
                                                inputs=[input_tensor],
                                                outputs=[output_tensor],
                                                packaged_by=[packager_spec],
                                                weights=unet_weights,
                                                training_data=training_data,
                                              )


# test model
summary = bioimageio_core.test_model(model_description, weight_format="keras_hdf5")
summary.display()

success = summary.status == "passed" 

save_bioimageio_package(model_description, output_path=Path(output_path))
   
if success:
  print("The bioimage.io model was successfully exported to", output_path)
else:
  print("The bioimage.io model was exported to", output_path)
  print("Some tests of the model did not pass!")
  print("You can still download and test the model, but it may not work as expected.")

# **6. Using the trained model**

---
<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are loaded and prepared for prediction. Then the trained model from section 4 is applied to them, and the resulting predictions are saved to your Google Drive.

## **6.1 Generate prediction(s) from unseen dataset**
---

<font size = 4>The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the path to that model's folder in **Prediction_model_folder**. Predicted output images are saved in your **Results_folder** folder.
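<font size = 4>In multilabel segmentation, every pixel of the output mask receives one of the label values (0, 1, 2, ...) described at the top of this notebook. Assuming the network ends with a softmax layer that outputs one score per class, a common way to obtain that mask is to take the arg-max over the class axis. The minimal sketch below illustrates this step; `probabilities_to_labels` is a hypothetical helper, and it is an assumption, not a statement about the notebook's code, that `predict_as_tiles` works exactly this way.

```python
import numpy as np


def probabilities_to_labels(probs: np.ndarray) -> np.ndarray:
    """Turn an (H, W, n_classes) array of class scores into an (H, W) label mask.

    Assumption: the last axis holds one score per class, in label order
    (index 0 = background, 1 = first structure, 2 = second structure, ...).
    """
    if probs.ndim != 3:
        raise ValueError("expected an array of shape (H, W, n_classes)")
    # Each pixel receives the label of its highest-scoring class
    return np.argmax(probs, axis=-1).astype(np.uint8)


# Toy 2x2 image with 3 classes (0 = background, 1 and 2 = two structures)
probs = np.array([
    [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]],
    [[0.1, 0.2, 0.7], [0.3, 0.3, 0.4]],
])
labels = probabilities_to_labels(probs)
print(labels)
```

Storing labels as `uint8` is a design choice that keeps masks small and supports up to 256 label values, which is ample for the handful of classes typical of this notebook.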

<font size = 4>**`Data_folder`:** This folder should contain the images that you want to process with your trained network.
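<font size = 4>Operating systems sometimes create hidden files inside data folders (for example `.DS_Store` on macOS), which are not images and would break prediction if treated as such. A minimal sketch of listing only the image files in a folder; `list_input_images` is a hypothetical helper and the accepted extensions are an assumption to adapt to your data.

```python
import os


def list_input_images(folder, extensions=(".tif", ".tiff", ".png")):
    """Return the sorted image file names in `folder`, skipping hidden entries.

    Assumption: the extension list is illustrative; adjust it to your dataset.
    """
    names = []
    for name in sorted(os.listdir(folder)):
        # Skip hidden files such as .DS_Store created by macOS
        if name.startswith("."):
            continue
        # Skip sub-folders
        if not os.path.isfile(os.path.join(folder, name)):
            continue
        # Keep only files with an accepted extension (case-insensitive)
        if name.lower().endswith(extensions):
            names.append(name)
    return names
```

For example, `list_input_images(Data_folder)` would return the image names that are safe to pass on for prediction.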

<font size = 4>**`Results_folder`:** This folder will contain the predicted output images.
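<font size = 4>Each prediction is written to `Results_folder` under the source file name preceded by `prediction_prefix`, which is how the display code in the prediction cell reads it back. Because existing files with the same name are overwritten, you may want to check the folder before running predictions. A minimal sketch; the helper names are hypothetical and the prefix value `"Predicted_"` used when trying it out is an assumption.

```python
import os


def planned_output_paths(results_folder, source_names, prefix):
    """Map each input file name to the path its prediction will be saved to.

    Assumption: predictions are saved as prefix + original file name.
    """
    return {name: os.path.join(results_folder, prefix + name) for name in source_names}


def existing_outputs(results_folder, source_names, prefix):
    """Return the output paths that already exist and would be overwritten."""
    paths = planned_output_paths(results_folder, source_names, prefix)
    return [p for p in paths.values() if os.path.exists(p)]
```

If `existing_outputs(Results_folder, source_dir_list, prediction_prefix)` returns a non-empty list, consider choosing a new results folder or backing up those files first.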

<font size = 4>Once the predictions are complete, the cell displays an interactive viewer in which you can select any image to show it beside its predicted mask for visual inspection.

<font size = 4>**Troubleshooting:** A low-contrast image warning when saving the images may indicate that the model has overfitted to the training data, which can produce images containing only a single colour. In that case, train the network again with different hyperparameters.
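<font size = 4>To check whether a low-contrast warning really comes from a single-colour result, you can count the label values present in each predicted mask. A minimal sketch with NumPy, assuming the masks are loaded as integer arrays; `is_degenerate_mask` and `summarize_labels` are hypothetical helper names.

```python
import numpy as np


def is_degenerate_mask(mask, min_labels=2):
    """Return True if a label mask contains fewer than `min_labels` distinct values.

    A mask made of a single label (e.g. all background) is saved as a
    single-colour, low-contrast image.
    """
    return np.unique(mask).size < min_labels


def summarize_labels(mask):
    """Return a dict {label value: pixel count} for a label mask."""
    values, counts = np.unique(mask, return_counts=True)
    return {int(v): int(c) for v, c in zip(values, counts)}
```

If most predictions are degenerate, or one label dominates every image, the model has likely not learned to separate the classes.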

In [None]:


# ------------- Initial user input ------------
#@markdown ###Provide the path to your dataset and to the folder where the predicted masks will be saved (Result folder), then play the cell to predict the output on your unseen images and store it.
Data_folder = '' #@param {type:"string"}
Results_folder = '' #@param {type:"string"}

#@markdown ###Do you want to use the current trained model?
Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model folder:

Prediction_model_folder = "" #@param {type:"string"}

#Here we find the loaded model name and parent path
Prediction_model_name = os.path.basename(Prediction_model_folder)
Prediction_model_path = os.path.dirname(Prediction_model_folder)


# ------------- Failsafes ------------
if (Use_the_current_trained_model):
  print("Using current trained network")
  Prediction_model_name = model_name
  Prediction_model_path = model_path

full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)
if os.path.exists(full_Prediction_model_path):
  print("The "+Prediction_model_name+" network will be used.")
else:
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')


# ------------- Prepare the model and run predictions ------------

# Load the trained model; the custom loss is passed via custom_objects so Keras can deserialize it
unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})
Input_size = unet.layers[0].output_shape[0][1:3]
print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))

# Create a list of sources, skipping hidden files such as .DS_Store
source_dir_list = [f for f in os.listdir(Data_folder) if not f.startswith('.')]
number_of_dataset = len(source_dir_list)
print('Number of images found in the folder: '+str(number_of_dataset))

predictions = []
for i in tqdm(range(number_of_dataset)):
  predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))
  # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))


# Save the predicted masks to the results folder
saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix)


# ------------- For display ------------
print('--------------------------------------------------------------')


def show_prediction_mask(file=source_dir_list):

  plt.figure(figsize=(10,6))
  # Source image
  plt.subplot(1,2,1)
  plt.axis('off')
  img_Source = plt.imread(os.path.join(Data_folder, file))
  plt.imshow(img_Source, cmap='gray')
  plt.title('Source image',fontsize=15)
  # Prediction
  plt.subplot(1,2,2)
  plt.axis('off')
  img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))
  plt.imshow(img_Prediction, cmap='gray')
  plt.title('Prediction',fontsize=15)

interact(show_prediction_mask);
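<font size = 4>The prediction cell reads the model's input size (`Input_size`) and calls `predict_as_tiles`, defined earlier in the notebook, so that images larger than that input can be processed. The idea of tiled prediction can be sketched as follows. `predict_in_tiles` is a hypothetical helper, and the non-overlapping tiles and reflect padding are assumptions, not necessarily what `predict_as_tiles` does; in practice `predict_fn` would wrap a call to the model's `predict` followed by the arg-max over classes.

```python
import numpy as np


def predict_in_tiles(image, tile_shape, predict_fn):
    """Apply `predict_fn` tile by tile and stitch the results back together.

    Assumptions:
    - `image` is a 2D array and `predict_fn` maps a (th, tw) tile to a (th, tw) label tile;
    - tiles do not overlap, and the image is reflect-padded up to a multiple of the tile size.
    """
    th, tw = tile_shape
    h, w = image.shape
    # Amount of padding needed so both dimensions are multiples of the tile size
    pad_h = (-h) % th
    pad_w = (-w) % tw
    padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
    out = np.zeros(padded.shape, dtype=np.uint8)
    for y in range(0, padded.shape[0], th):
        for x in range(0, padded.shape[1], tw):
            out[y:y + th, x:x + tw] = predict_fn(padded[y:y + th, x:x + tw])
    # Crop the padding away so the mask matches the input size
    return out[:h, :w]
```

Non-overlapping tiles keep the sketch simple, but U-Net predictions are often less reliable near tile borders; overlapping tiles with blending are a common refinement.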



## **6.2. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading them from Google Drive, then clean up the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Otherwise, the notebook will **OVERWRITE** any files that have the same name.
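<font size = 4>Before cleaning the folder tree, it can help to bundle the results into a single archive, which is easier to download from Google Drive than many individual files. A minimal sketch using Python's standard `shutil.make_archive`; `archive_results` and any paths you pass to it are hypothetical.

```python
import os
import shutil


def archive_results(results_folder, archive_base):
    """Zip the contents of `results_folder` into `archive_base` + '.zip'.

    Returns the path of the created archive.
    """
    if not os.path.isdir(results_folder):
        raise FileNotFoundError(results_folder)
    # shutil.make_archive appends the '.zip' extension itself
    return shutil.make_archive(archive_base, "zip", root_dir=results_folder)
```

For example, `archive_results(Results_folder, Results_folder + "_backup")` would create a zip file next to the results folder, leaving the original files untouched.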

# **7. Version log**

---
<font size = 4>**v2.1.3**:  

*    Updated Bioimage.io model export to latest version (core-0.6.8, spec-0.5.3.2)
*    Bug fixes

<font size = 4>**v2.1.2**:  

*    Corrected data loading to skip .DS_Store and similar hidden files

<font size = 4>**v2.1.1**:  

*    Replaced all absolute pathing with relative pathing

<font size = 4>**v2.1**:
* Updated to TensorFlow 2.11
* Updated to `fpdf2` and added lines to ensure proper formatting. Corrected Keras package version parsing.
---

# **Thank you for using 2D U-Net multilabel segmentation!**