# **2D segmentation with a U-Net**
---
#### Bioimage analysis course, Gothenburg University, September 2022
#### Daniel Sage, EPFL, Lausanne, Switzerland
#### Anaïs Badoual, Inria, France
---

This notebook is a modified version of the **U-Net 2D** notebook of ZeroCostDL4Mic (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). The goal of this notebook is to train a U-Net neural network for an image segmentation task by pixel classification (only two classes here). The model is stored in the Bioimage Model Zoo format and is directly usable in deepImageJ. 

---
**References**
- [U-Net](https://arxiv.org/abs/1505.04597) Ronneberger et al. MICCAI 2015.
- [ZeroCostDL4Mic](https://www.nature.com/articles/s41467-021-22518-0): L. von Chamier et al., Nature Methods, 2021. Developed by the [G. Jacquemet Lab](https://cellmig.org/) and [R. Henriques Lab](https://henriqueslab.github.io/).
- [Bioimage Model Zoo (BMZ)](https://www.biorxiv.org/content/10.1101/2022.06.07.495102v1) W. Ouyang et al., biorxiv  2022.
- [BMZ exporter](https://github.com/esgomezm): E. Gómez-de-Mariscal 2021.
- [DeepImageJ](https://doi.org/10.1038/s41592-021-01262-9): E. Gómez-de-Mariscal et al., Nature Methods 2021. deepImage team

*Please also cite this original paper when using or developing this notebook.* 

# **1. Initialization**

To clean all outputs, use the command of the menu: Edit -> Clear all outputs



In [None]:
#@markdown ### **1.1 Install TensorFlow 1.15 and other libraries**
# Replace the preinstalled TensorFlow with the pinned 1.15 release
# (set-up used to run TF 1.15 on Colab in 2022).
!pip uninstall -y -q tensorflow
!pip install -q tensorflow-probability==0.8
!pip install -q kapre==0.1.7
!pip install -q tensorflow==1.15
# Install packages which are not included in Google Colab

!pip install data
!pip install -q tifffile # contains tools to operate tiff-files
#!pip install edt # improves STARDIST performances
!pip install -q wget
!pip install -q fpdf
!pip install -q PTable # Nice tables 
!pip install -q zarr
!pip install -q imagecodecs
!pip install h5py==2.10
!pip install "bioimageio.core>=0.5,<0.6"
!pip install pyyaml==5.4.1
#!pip uninstall -y -q keras-nightly

!pip install codecarbon
# Force a session restart so the freshly installed versions get imported.
# NOTE(review): the statements after exit(0) are presumably never reached in
# the run that performs the installation - confirm whether they are needed.
exit(0)

from IPython. display import clear_output
clear_output(wait=False)

print("The required packages are installed.")


Ignore the following error message. Your runtime has automatically restarted. This is normal.

<img width="40%" alt ="" src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/session_crash.png"><figcaption>  </figcaption>

In [None]:
from __future__ import print_function

# Suppressing some warnings
import warnings
warnings.filterwarnings('ignore')

from builtins import any as b_any
import sys
before = [str(m) for m in sys.modules]

Notebook_version = "1.2"
#@markdown ### **1.2 Import packages and define functions**

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

#%tensorflow_version 1.x
import tensorflow as tf
print('TensorFlow Version: ', tf.__version__)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


# Keras imports
from tensorflow.keras import models
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from tensorflow.keras.optimizers import Adam
# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras import backend as keras

# General import
import numpy as np
import pandas as pd
import glob
from skimage import img_as_ubyte, img_as_float32, io, transform
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.pyplot import imread
from pathlib import Path
import shutil
import random
import time
import csv
import sys
from math import ceil
from fpdf import FPDF, HTMLMixin
from pip._internal.operations.freeze import freeze
import subprocess

# Imports for QC
from PIL import Image
from scipy import signal
from scipy import ndimage
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr

# For sliders and dropdown menu and progress bar
from ipywidgets import interact
import ipywidgets as widgets
# from tqdm import tqdm
from tqdm.notebook import tqdm

from sklearn.feature_extraction import image
from skimage import img_as_ubyte, io, transform
from skimage.util.shape import view_as_windows
from datetime import datetime

# BioImage Model Zoo dependencies
from bioimageio.core import load_raw_resource_description, load_resource_description
from zipfile import ZipFile
from shutil import rmtree
from bioimageio.core.build_spec import build_model, add_weights
from bioimageio.core.resource_tests import test_model
from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle
import requests

# Deepiction
from codecarbon import EmissionsTracker
from google.colab import data_table

Network = "U-net"

def get_requirements_path():
    """Return a relative path to 'requirements.txt' based on the working directory depth.

    One '../' is prepended for every '/' found in the current working
    directory beyond the first one, so from a first-level directory such as
    '/content' the result is simply 'requirements.txt'.
    """
    levels_up = os.getcwd().count('/') - 1
    return ''.join(['../'] * levels_up) + 'requirements.txt'

def filter_files(file_list, filter_list):
    """Return the entries of file_list whose package name occurs in filter_list.

    The package name is the part of an entry before '=='. It is tested with
    `in` against every element of filter_list, which means a substring match
    for string elements and a membership test for list elements.
    The order of file_list is preserved.
    """
    return [
        entry
        for entry in file_list
        if b_any(entry.split('==')[0] in candidate for candidate in filter_list)
    ]

def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file
    df = pd.read_csv(path, delimiter = "\n")
    mod_list = [m.split('.')[0] for m in after if not m in before]
    req_list_temp = df.values.tolist()
    req_list = [x[0] for x in req_list_temp]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]
    mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] 
    filtered_list = filter_files(req_list, mod_replace_list)

    file=open(path,'w')
    for item in filtered_list:
        file.writelines(item + '\n')

    file.close()

from builtins import any as b_any

def getClassWeights(Training_target_path):
  """Compute balancing weights for the two classes from the masks in a folder.

  Every .tif/.tiff/.png file of Training_target_path is read, rescaled with
  normalizeMinMax and summed: the sum is counted as class 1 and the rest of
  the shape[0]*shape[1] pixels as class 0. Files starting with '.' are
  reported and skipped; other extensions are skipped silently.

  Returns an array of two weights, n_samples / (n_classes * class_count).
  """
  accepted_suffixes = ('.tif', '.tiff', '.png')
  pixel_totals = np.zeros(2, dtype=int)

  for mask_name in os.listdir(Training_target_path):
    if mask_name.startswith('.'):
      print('Rejected ', mask_name)
      continue
    if not mask_name.endswith(accepted_suffixes):
      continue
    mask = normalizeMinMax(io.imread(os.path.join(Training_target_path, mask_name)))
    pixel_totals[0] += mask.shape[0]*mask.shape[1] - mask.sum()
    pixel_totals[1] += mask.sum()

  n_classes = 2
  return pixel_totals.sum() / (n_classes * pixel_totals)

def weighted_binary_crossentropy(class_weights):
    """Return a loss function computing a class-weighted binary cross-entropy.

    class_weights: indexable with two entries; entry 0 scales the terms where
    y_true is 0 and entry 1 the terms where y_true is 1.
    """

    def _weighted_binary_crossentropy(y_true, y_pred):
        # Element-wise cross-entropy, then a weight blended from y_true
        plain_loss = keras.binary_crossentropy(y_true, y_pred)
        element_weights = y_true * class_weights[1] + (1. - y_true) * class_weights[0]
        return keras.mean(element_weights * plain_loss)

    return _weighted_binary_crossentropy

def save_augment(datagen,orig_img,dir_augmented_data="/content/augment"):
  """
  Saves a subset of the augmented data for visualisation, by default in /content/augment.

  datagen: generator object whose flow() method writes the augmented images
           (an ImageDataGenerator in this notebook).
  orig_img: image to augment; it is converted with img_to_array.
  dir_augmented_data: output folder. It is created if missing; if it already
           exists, the files it contains are deleted first.

  This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html
  """
  try:
    os.mkdir(dir_augmented_data)
  except FileExistsError:
    # The preview folder already exists: remove the pictures it contains.
    # Any other failure (e.g. a missing parent folder) is left to propagate.
    for item in os.listdir(dir_augmented_data):
      os.remove(os.path.join(dir_augmented_data, item))

  # Convert the original image to an array and add a leading sample axis
  x = img_to_array(orig_img)
  x = x.reshape((1,) + x.shape)

  # We will just save 5 images,
  # but this can be changed, but note the visualisation in 3. currently uses 5.
  Nplot = 5
  augmented_batches = datagen.flow(x,batch_size=1,
                                   save_to_dir=dir_augmented_data,
                                   save_format='tif',
                                   seed=42)
  # The flow is iterated only for its side effect of writing files;
  # stop once Nplot batches have been drawn.
  for drawn, _ in enumerate(augmented_batches, start=1):
    if drawn >= Nplot:
      break

# Generators
def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):
  '''
  Yield (image_batch, mask_batch) pairs read in parallel from two folders.

  Both flows are created with the same seed so that the transformation
  applied to an image and to its mask is the same.

  image_datagen, mask_datagen: ImageDataGenerator
  subset: can take either 'training' or 'validation'
  '''
  # Options common to the image flow and the mask flow
  shared_options = dict(
      class_mode = None,
      color_mode = "grayscale",
      target_size = target_size,
      batch_size = batch_size,
      subset = subset,
      seed = 1)

  # The parent folder is the directory to scan and the folder itself the only class
  image_flow = image_datagen.flow_from_directory(
      os.path.dirname(image_folder_path),
      classes = [os.path.basename(image_folder_path)],
      interpolation = "bicubic",
      **shared_options)

  mask_flow = mask_datagen.flow_from_directory(
      os.path.dirname(mask_folder_path),
      classes = [os.path.basename(mask_folder_path)],
      interpolation = "nearest",
      **shared_options)

  yield from zip(image_flow, mask_flow)

def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):
  """Create the paired (image, mask) generators for training and validation.

  Images are preprocessed with normalizePercentile and masks with
  normalizeMinMax; datagen_parameters is forwarded to both
  ImageDataGenerator objects.

  Returns (train_datagen, validation_datagen).
  """
  image_datagen = ImageDataGenerator(preprocessing_function = normalizePercentile, **datagen_parameters)
  mask_datagen = ImageDataGenerator(preprocessing_function = normalizeMinMax, **datagen_parameters)

  def paired_flow(subset):
    return buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size)

  return (paired_flow('training'), paired_flow('validation'))

# Normalization functions from Martin Weigert
def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """Percentile-based image normalization (adapted from Martin Weigert).

    The pmin-th and pmax-th percentiles of x (computed along `axis`, keeping
    dimensions) are used as the lower and upper bounds handed to
    normalize_mi_ma together with clip, eps and dtype.
    """
    lower, upper = (np.percentile(x, p, axis=axis, keepdims=True) for p in (pmin, pmax))
    return normalize_mi_ma(x, lower, upper, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):
    """Rescale x as (x - mi) / (ma - mi + eps); adapted from Martin Weigert.

    When dtype is not None, x, mi, ma and eps are converted to it first.
    numexpr is used for the arithmetic when it can be imported, plain NumPy
    otherwise. With clip=True the result is limited to [0, 1].
    """
    if dtype is not None:
        x = x.astype(dtype, copy=False)
        mi, ma = (
            dtype(bound) if np.isscalar(bound) else bound.astype(dtype, copy=False)
            for bound in (mi, ma)
        )
        eps = dtype(eps)

    try:
        import numexpr
        # The expression refers to this function's local variables by name
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x = (x - mi) / (ma - mi + eps)

    return np.clip(x, 0, 1) if clip else x

# Simple min/max normalization for the mask
def normalizeMinMax(x, dtype=np.float32):
  """Rescale x linearly so that its minimum maps to 0 and its maximum to 1.

  x: NumPy array; it is converted to `dtype` before scaling.
  Returns a new array of type `dtype`.

  A constant input has no value range to stretch. It is returned as an
  all-zero array instead of the NaN values a 0/0 division would produce.
  """
  x = x.astype(dtype,copy=False)
  lowest = np.amin(x)
  value_range = np.amax(x) - lowest
  if value_range == 0:
    # Constant input: avoid dividing by zero
    return np.zeros_like(x)
  return (x - lowest) / value_range

# This code outlines the architecture of the U-Net. The number of pooling steps determines the depth of the network.
def unet(pretrained_weights = None, nchannels=64, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):
    """Build and compile a U-Net Keras model with 1 to 4 pooling steps.

    pretrained_weights: when truthy, passed to model.load_weights after compiling.
    nchannels: number of filters of the first level; deeper levels use 2x, 4x,
        8x and 16x this number.
    input_size: shape given to the Input layer.
    pooling_steps: number of max-pooling levels. Values above 4 build the same
        network as 4. Values below 1 are not handled: the final `else` branch
        then refers to conv8, which is never created.
    learning_rate: learning rate given to the Adam optimizer.
    verbose: not used in this function.
    class_weights: the two weights given to weighted_binary_crossentropy.

    Returns the compiled model, whose output is a single-channel sigmoid map.
    """
    n = nchannels
    inputs = Input(input_size)
    conv1 = Conv2D(n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    # Downsampling steps
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(2*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(2*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    
    # Each additional pooling step nests one more level in the contracting path
    if pooling_steps > 1:
      pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
      conv3 = Conv2D(4*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
      conv3 = Conv2D(4*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)

      if pooling_steps > 2:
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        conv4 = Conv2D(8*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(8*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
      
        if pooling_steps > 3:
          # Deepest level, followed directly by the first upsampling stage
          pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
          conv5 = Conv2D(16*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
          conv5 = Conv2D(16*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
          drop5 = Dropout(0.5)(conv5)
          up6 = Conv2D(8*n, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
          merge6 = concatenate([drop4,up6], axis = 3)
          conv6 = Conv2D(8*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
          conv6 = Conv2D(8*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
          
    # Upsampling steps. up7 is first built from drop4 and then, when a deeper
    # level exists, rebuilt from conv6; only the last assignment is used below.
    if pooling_steps > 2:
      up7 = Conv2D(4*n, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))
      if pooling_steps > 3:
        up7 = Conv2D(4*n, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
      merge7 = concatenate([conv3,up7], axis = 3)
      conv7 = Conv2D(4*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
      conv7 = Conv2D(4*n, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    # Same pattern for up8: built from conv3, then rebuilt from conv7 when available
    if pooling_steps > 1:
      up8 = Conv2D(2*n, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))
      if pooling_steps > 2:
        up8 = Conv2D(2*n, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
      merge8 = concatenate([conv2,up8], axis = 3)
      conv8 = Conv2D(2*n, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
      conv8 = Conv2D(2*n, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
      
    # Last upsampling stage: fed by conv2 with a single pooling step, by conv8 otherwise
    if pooling_steps == 1:
      up9 = Conv2D(n, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))
    else:
      up9 = Conv2D(n, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'
    
    # The convolutions of the last level are created without an activation
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(n, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'
    conv9 = Conv2D(n, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'
    conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'
    # 1x1 convolution with a sigmoid gives one value per pixel
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])
    model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))

    if(pretrained_weights):
    	model.load_weights(pretrained_weights);

    return model

def predict_as_tiles(Image_path, model):
  """Predict a whole image by running the model tile by tile.

  The image is read in grayscale and normalized with normalizePercentile.
  The tile size is taken from the model input shape. An image smaller than a
  tile in any dimension is zero-padded, and tiles that would cross the image
  border are shifted back inside it.

  Returns the prediction cropped to the size of the image that was read.
  """
  # Read the data in and normalize
  raw = normalizePercentile(io.imread(Image_path, as_gray = True))
  raw_rows, raw_cols = raw.shape[0], raw.shape[1]

  # Get the patch size from the input layer of the model
  tile_rows, tile_cols = model.input_shape[1:3]

  # Pad the image with zeros if any of its dimensions is smaller than the patch size
  if raw_rows < tile_rows or raw_cols < tile_cols:
    canvas = np.zeros((max(raw_rows, tile_rows), max(raw_cols, tile_cols)))
    canvas[0:raw_rows, 0:raw_cols] = raw
  else:
    canvas = raw
  canvas_rows, canvas_cols = canvas.shape[0], canvas.shape[1]

  prediction = np.zeros(canvas.shape)

  for row_index in range(ceil(canvas_rows/tile_rows)):
    # If the patch exceeds the edge of the image shift it back
    r0 = tile_rows*row_index
    if r0+tile_rows >= canvas_rows:
      r0 = canvas_rows-tile_rows

    for col_index in range(ceil(canvas_cols/tile_cols)):
      c0 = tile_cols*col_index
      if c0+tile_cols >= canvas_cols:
        c0 = canvas_cols-tile_cols

      # Extract the patch and add the batch and channel axes
      tile = canvas[r0:r0+tile_rows, c0:c0+tile_cols]
      tile = np.reshape(tile, (1,)+tile.shape+(1,))

      # Get the prediction from the patch and paste it in the prediction in the right place
      predicted_tile = model.predict(tile, batch_size = 1)
      prediction[r0:r0+tile_rows, c0:c0+tile_cols] = np.squeeze(predicted_tile)

  return prediction[0:raw_rows, 0:raw_cols]
  
def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):
  """Save each array of nparray as an 8-bit .tif named after its source file.

  The output name is prefix + source file name without extension + '.tif'.
  When threshold is given, a thresholded mask is saved as well, with 'mask_'
  inserted after the prefix.
  """
  for source_name, image in zip(source_dir_list, nparray):
      stem = os.path.splitext(source_name)[0]
      # saving as unsigned 8-bit image
      io.imsave(os.path.join(save_path, prefix+stem+'.tif'), img_as_ubyte(image))

      if threshold is None:
        continue

      # For masks, threshold the images and return 8 bit image
      binary = convert2Mask(image, threshold)
      io.imsave(os.path.join(save_path, prefix+'mask_'+stem+'.tif'), binary)


def convert2Mask(image, threshold):
  """Return an 8-bit copy of image binarized at threshold.

  After conversion with img_as_ubyte, values above threshold become 255 and
  values at or below it become 0.
  """
  mask = img_as_ubyte(image, force_copy=True)
  above = mask > threshold
  at_or_below = mask <= threshold
  mask[above] = 255
  mask[at_or_below] = 0
  return mask


def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):
  """Compute the intersection over union (IoU) for every 8-bit threshold.

  prediction_filepath: path of the predicted image.
  groud_truth_filepath: path of the ground-truth mask (read in grayscale);
      every non-zero pixel counts as foreground.

  Returns (threshold_list, IoU_scores_list), with one score per threshold
  from 0 to 255. A prediction pixel is foreground when its 8-bit value is
  strictly above the threshold. A score is NaN when both the thresholded
  prediction and the ground truth are empty.
  """
  prediction = io.imread(prediction_filepath)
  ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)

  # Convert to 8-bit once: the conversion does not depend on the threshold
  prediction_8bit = np.squeeze(img_as_ubyte(prediction))
  truth_foreground = ground_truth_image.astype(bool)

  threshold_list = list(range(0,256))
  IoU_scores_list = []

  for threshold in threshold_list:
    predicted_foreground = prediction_8bit > threshold

    # Intersection over Union metric
    intersection = np.logical_and(truth_foreground, predicted_foreground)
    union = np.logical_or(truth_foreground, predicted_foreground)
    IoU_scores_list.append(np.sum(intersection) / np.sum(union))

  return (threshold_list, IoU_scores_list)

# Prefix for prediction file names (not used in this cell; presumably read by later cells)
prediction_prefix = 'Predicted_'
print()
print('The requires function are imported or defined.')

# Build requirements file for local run: compare the modules loaded now
# with the snapshot `before` taken at the top of this cell
after = [str(m) for m in sys.modules]
build_requirements_file(before, after)

# Allow TensorBoard
#%load_ext tensorboard
#!rm -rf ./logs/
#log_dir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)


# **2. Google Colab Setting Up**

In [None]:
#@markdown ### **2.1. Check GPU**
# An empty device name is treated as "no GPU available"
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  # Show the details of the allocated GPU
  !nvidia-smi

# from tensorflow.python.client import device_lib 
# device_lib.list_local_devices()

# print the tensorflow version
print('Tensorflow version is ' + str(tf.__version__))


In [None]:
#@markdown ### **2.2 Connect to your Google Drive**

#@markdown If you cannot see your files, reactivate your session by connecting to your hosted runtime.

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')



# **3. Data**

In [None]:
#@markdown ### **3.1 Define path to the dataset**
#@markdown **`dataset_path`:** should contain 2 folders **training_source** and **training_target** for the training or fine tuning (and optionally 2 other folders for the quality control **test_source** and **test_target**).
Dataset_path = '' #@param {type:"string"}

#@markdown **`percentage_validation`:**  Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** 
percentage_validation =  10 #@param{type:"number"}

# Reset the four folder variables so a missing folder is detected explicitly
# (and so values left over from a previous run of this cell are not reused).
Training_source = None
Training_target = None
Source_QC = None
Target_QC = None

# Folder names are matched case-insensitively against the accepted aliases
for f in os.listdir(Dataset_path):
    folder = f.lower()
    if folder in ('training_source', 'source', 'train_source'):
      Training_source = os.path.join(Dataset_path, f)
    if folder in ('training_target', 'target', 'train_target'):
      Training_target = os.path.join(Dataset_path, f)
    if folder in ('test_source', 'qc_source'):
      Source_QC = os.path.join(Dataset_path, f)
    if folder in ('test_target', 'qc_target'):
      Target_QC = os.path.join(Dataset_path, f)

# The training folders are mandatory
if Training_source is None or Training_target is None:
    raise FileNotFoundError(
        'The dataset folder must contain the folders training_source and training_target: ' + Dataset_path)

print(os.listdir(Dataset_path))
print('Number of files in train_source', len(os.listdir(Training_source)))
print('Number of files in train_target', len(os.listdir(Training_target)))
# The quality control folders are optional: report them only when present
if Source_QC is not None:
    print('Number of files in test_source', len(os.listdir(Source_QC)))
else:
    print('No test_source folder found (optional)')
if Target_QC is not None:
    print('Number of files in test_target', len(os.listdir(Target_QC)))
else:
    print('No test_target folder found (optional)')

# Measure the image size on the first file that is a supported image,
# skipping hidden files and other file types
file = os.listdir(Training_source)
first_image = next((name for name in file
                    if not name.startswith('.') and name.endswith(('.tif', '.tiff', '.png'))), None)
if first_image is None:
    raise FileNotFoundError('No .tif, .tiff or .png image found in ' + Training_source)
img = io.imread(os.path.join(Training_source, first_image))
patch_width = img.shape[0]
patch_height = img.shape[1]

print('Size of image:', patch_width, ' x ', patch_height)

In [None]:
#@markdown ### **3.2 Read data**

#@markdown Read the .tif, .tiff or .png files.

def read_data(Training_source, Training_target, min_fraction):
  """
  Copy the valid image/mask pairs of a dataset into fresh patch folders.

  Every .tif/.tiff/.png file of Training_source is read together with the
  file of the same name in Training_target. The patch size is the size of
  the image itself (its first two dimensions). A patch is saved only when
  the mask value at the (1 - min_fraction) quantile is above 0; otherwise a
  rejection message is printed. Files starting with '.' are reported and
  skipped.

  min_fraction is the minimum fraction of pixels that need to be foreground
  to be considered as a valid patch.

  Returns: - The path where the image patches are saved
           - The path where the mask patches are saved
           - A DataFrame with one summary row per file that was read
  """
  Patch_source = os.path.join('/content','img_patches')
  Patch_target = os.path.join('/content','mask_patches')
  # Always start from empty output folders
  for folder in (Patch_source, Patch_target):
    if os.path.exists(folder):
      shutil.rmtree(folder)
    os.mkdir(folder)

  patch_num = 0
  table_files = []
  table_image = []
  table_mask = []
  table_area = []
  table_index = []
  for file in tqdm(os.listdir(Training_source)):
    if file.startswith('.'):
      print('Rejected ', file)
      continue
    if not file.endswith(('.tif', '.tiff', '.png')):
      continue

    img = io.imread(os.path.join(Training_source, file))
    mask = io.imread(os.path.join(Training_target, file),as_gray=True)
    patch_width = img.shape[0]
    patch_height = img.shape[1]
    area = mask.shape[0] * mask.shape[1]
    unique, counts = np.unique(mask, return_counts=True)

    # One summary row per file: dimensions, value range and class areas (in %)
    table_files.append(file)
    table_image.append(str(img.shape) + '  [' + str(np.min(img))+ ', ' + str(np.max(img)) + ']')
    table_mask.append(str(mask.shape) + '  [' + str(np.min(mask))+ ', ' + str(np.max(mask)) + ']')
    table_area.append(str(unique) + ' ' + str(np.round(100*counts/area, 2)))
    table_index.append(len(table_files))

    # Window and step both equal the image size
    patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))
    patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))
    patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)
    patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)

    for i in range(patches_img.shape[0]):
      img_save_path = os.path.join(Patch_source,'image_'+str(patch_num)+'.tif')
      mask_save_path = os.path.join(Patch_target,'image_'+str(patch_num)+'.tif')
      patch_num += 1

      # Value found at the (1 - min_fraction) quantile of the mask pixels.
      # The rank is clamped to a valid index so that min_fraction = 0 (or 1)
      # does not index outside the array.
      mask_values = patches_mask[i].ravel()
      rank = int(round(mask_values.size*(1-min_fraction)))
      rank = max(0, min(rank, mask_values.size - 1))
      if np.partition(mask_values, rank)[rank] > 0:
        io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))
        io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))
      else:
        print('Rejected min-frac ', file, ' less than ', min_fraction)

  list_of_files = pd.DataFrame({
          'Filename': table_files,
          'Source (dim, type, min, max)': table_image,
          'Target (dim, type, min, max)': table_mask,
          'Target (classes / area)': table_area
          }, index = table_index)

  return Patch_source, Patch_target, list_of_files


# The minimum fraction of pixels being foreground for a selected patch to be considered valid is set to the default value: 2%.
min_fraction = 0.02
Patch_source, Patch_target, list_of_files = read_data(Training_source, Training_target, min_fraction)

# NOTE(review): this expression is not the last statement of the cell, so the
# table is presumably not displayed here - confirm whether it can be removed.
data_table.DataTable(list_of_files, num_rows_per_page=10)

# Number of patches that were actually written to the image patch folder
number_of_training_dataset = len(os.listdir(Patch_source))
print('Total number of valid images: '+str(number_of_training_dataset))


In [None]:
#@markdown ### **3.3 List training images (optional)**

# Paginated view of the per-file summary returned by read_data
data_table.DataTable(list_of_files, num_rows_per_page=25)

In [None]:
#@markdown ### **3.4 Display source image and target (optional)**

#random_choice = random.choice(os.listdir(Patch_source))
#x = io.imread(os.path.join(Patch_source, random_choice))
#y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)
#f=plt.figure(figsize=(16,8))
#plt.subplot(1,2,1)
#plt.imshow(x, interpolation='nearest',cmap='gray')
#plt.title('Training image patch')
#plt.axis('off');
#plt.subplot(1,2,2)
#plt.imshow(y, interpolation='nearest',cmap='gray')
#plt.title('Training mask patch')
#plt.axis('off');
#plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)

# NOTE(review): not the last statement of the cell, so this table is
# presumably not displayed - confirm.
data_table.DataTable(list_of_files, num_rows_per_page=25)


# ------------- For display ------------
@interact
def show_QC_results(file=os.listdir(Patch_source)):
  """Show the selected source patch and its target mask side by side.

  file: name of a file present in Patch_source; the mask with the same name
        is read from Patch_target. The list of choices is built once, when
        this cell runs.
  """
  plt.figure(figsize=(16,8))
  #Input
  plt.subplot(1,2,1)
  plt.axis('off')
  plt.imshow(plt.imread(os.path.join(Patch_source, file)), aspect='equal', cmap='gray', interpolation='nearest')
  plt.title('Source')
  # Target mask, read in grayscale
  plt.subplot(1,2,2)
  plt.axis('off')
  test_ground_truth_image = io.imread(os.path.join(Patch_target, file),as_gray=True)
  plt.imshow(test_ground_truth_image, aspect='equal', cmap='gray')
  plt.title('Target')

In [None]:
#@markdown ### **3.5 Data Augmentation (Recommended)**

Use_Data_augmentation = True #@param {type:"boolean"}

# Slider values are percentages (or degrees for the rotation); they are
# converted to fractions when the dictionary is built below.
if Use_Data_augmentation:
    horizontal_shift =  6 #@param {type:"slider", min:0, max:100, step:1}
    vertical_shift =  10 #@param {type:"slider", min:0, max:100, step:1}
    zoom_range =  10 #@param {type:"slider", min:0, max:100, step:1}
    shear_range =  10 #@param {type:"slider", min:0, max:100, step:1}
    horizontal_flip = True #@param {type:"boolean"}
    vertical_flip = True #@param {type:"boolean"}
    rotation_range =  180 #@param {type:"slider", min:0, max:180, step:1}
else:
  # Neutral values: the generator then applies no geometric transformation
  horizontal_shift =  0 
  vertical_shift =  0 
  zoom_range =  0
  shear_range =  0
  horizontal_flip = False
  vertical_flip = False
  rotation_range =  0

# Build the dict for the ImageDataGenerator
# NOTE(review): Keras documents shear_range as an angle in degrees, so the
# division by 100 may give a much smaller shear than intended - confirm.
data_gen_args = dict(width_shift_range = horizontal_shift/100.,
                     height_shift_range = vertical_shift/100.,
                     rotation_range = rotation_range, #90
                     zoom_range = zoom_range/100.,
                     shear_range = shear_range/100.,
                     horizontal_flip = horizontal_flip,
                     vertical_flip = vertical_flip,
                     validation_split = percentage_validation/100,
                     fill_mode = 'reflect')

# ------------- Display ------------
# Folders receiving the augmented previews of the image and of its mask
dir_augmented_data_imgs="/content/augment_img"
dir_augmented_data_masks="/content/augment_mask"
# Pick one patch at random; the mask has the same file name as the image
random_choice = random.choice(os.listdir(Patch_source))
orig_img = load_img(os.path.join(Patch_source,random_choice))
orig_mask = load_img(os.path.join(Patch_target,random_choice))

augment_view = ImageDataGenerator(**data_gen_args)

if Use_Data_augmentation:
  print("Parameters enabled")
  print("Here is what a subset of your augmentations looks like:")
  # Both calls use the same generator and save_augment uses a fixed seed,
  # presumably so that image and mask previews get matching transformations
  save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)
  save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)

  # 2 x 6 grid: first row for the image, second row for the mask;
  # the first column shows the original, the others the augmented versions
  fig = plt.figure(figsize=(15, 7))
  fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)

  ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[])        
  new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))
  ax.imshow(new_img)
  ax.set_title('Original Image')
  i = 2
  for imgnm in os.listdir(dir_augmented_data_imgs):
    ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) 
    img = load_img(dir_augmented_data_imgs + "/" + imgnm)
    ax.imshow(img)
    i += 1

  # Second row starts at subplot 7
  ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[])        
  new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))
  ax.imshow(new_mask)
  ax.set_title('Original Mask')
  j=2
  for imgnm in os.listdir(dir_augmented_data_masks):
    ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) 
    mask = load_img(dir_augmented_data_masks + "/" + imgnm)
    ax.imshow(mask)
    j += 1
  plt.show()

else:
  print("No augmentation will be used")

# **4. U-Net Configuration**

In [None]:
#@markdown ### **4.1 Define name and hyper-parameters**

# Folder and name under which the trained model is stored
model_path = "" #@param {type:"string"}
model_name = '' #@param {type:"string"}
number_of_epochs =  5#@param {type:"number"}

Use_pretrained_model = False
#@markdown **Advanced parameters - experienced users only**

#@markdown **`nchannels`** This is the number of filters of the UNet at the first scale **Default value: 64**
nchannels = 64 #@param {type:"integer"}
#@markdown **`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time.  **Default: 4**
batch_size =  4 #@param {type:"integer"}
# NOTE(review): presumably a placeholder that later cells recompute - confirm
number_of_steps =  0
#@markdown **`pooling_steps`**: Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**
# NOTE(review): the form text above announces a default of 2 while the value set here is 3
pooling_steps = 3 #@param [1,2,3,4]{type:"raw"}
#@markdown **`initial_learning_rate`:**  Input the initial value to be used as learning rate. **Default value: 0.0003**
initial_learning_rate = 0.0003 #@param {type:"number"}

#  Create the folders where to save the model and the QC
# (only the path is built here; no folder is created in this cell)
full_model_path = os.path.join(model_path, model_name)

print('Model path: ', full_model_path)
# Here we disable pre-trained model by default (in case the next cell is not ran)
Use_pretrained_model = False
# Here we disable data augmentation by default (in case the cell is not ran)
# Use_Data_augmentation = False
# Build the default dict for the ImageDataGenerator
# data_gen_args = dict(width_shift_range = 0.,
#                      height_shift_range = 0.,
#                      rotation_range = 0., #90
#                      zoom_range = 0.,
#                      shear_range = 0.,
#                      horizontal_flip = False,
#                      vertical_flip = False,
#                      validation_split = percentage_validation/100,
#                      fill_mode = 'reflect')


In [None]:
#@markdown ### **4.2 Fine-tuning (optional)**

#@markdown Loading weights from a pre-trained network. This is not required to train a network from scratch

Use_pretrained_model = False #@param {type:"boolean"}
pretrained_model_choice = "Model_from_file" #@param ["Model_from_file", "bioimageio_model"]
Weights_choice = "best" #@param ["last", "best"]

#@markdown If you chose "Model_from_file", please provide the path to the model folder and the model id:
pretrained_model_path = "" #@param {type:"string"}
bioimageio_model_id = "" #@param {type:"string"}

# --------------------- Check if we load a previously trained model ------------------------
if Use_pretrained_model:

# --------------------- Load the model from the choosen path ------------------------
  if pretrained_model_choice == "Model_from_file":
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")
    qc_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')

# --------------------- Load the model from a bioimageio model (can be path on drive or url / doi) ---
  elif pretrained_model_choice == "bioimageio_model":
    biomodel = load_resource_description(bioimageio_model_id)  
    if "keras_hdf5" not in biomodel.weights:
      print("Invalid bioimageio model")
      h5_file_path = "no-model"
      qc_path = "no-qc"
    else:
      h5_file_path = "/content/bioimage_model_zoo/"
      %rm -rf $h5_file_path
      
      h5_file_path = str(biomodel.weights["keras_hdf5"].source)
      try:
        attachments = biomodel.attachments.files
        qc_path = [fname for fname in attachments if fname.endswith("training_evaluation.csv")][0]
        qc_path = os.path.join("/content/bioimageio_pretrained_model", qc_path)
      except Exception:
        qc_path = "no-qc"
# --------------------- Check the model exist ------------------------
# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, 
  if not os.path.exists(h5_file_path):
    print(R+'WARNING: pretrained model does not exist')
    Use_pretrained_model = False
    
# If the model path contains a pretrain model, we load the training rate, 
  if os.path.exists(h5_file_path):
#Here we check if the learning rate can be loaded from the quality control folder
    if os.path.exists(qc_path):

      with open(qc_path,'r') as csvfile:
        csvRead = pd.read_csv(csvfile, sep=',')
        #print(csvRead)
    
        if "learning rate" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)
          print("pretrained network learning rate found")
          #find the last learning rate
          lastLearningRate = csvRead["learning rate"].iloc[-1]
          #Find the learning rate corresponding to the lowest validation loss
          min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]
          #print(min_val_loss)
          bestLearningRate = min_val_loss['learning rate'].iloc[-1]

          if Weights_choice == "last":
            print('Last learning rate: '+str(lastLearningRate))

          if Weights_choice == "best":
            print('Learning rate of best validation loss: '+str(bestLearningRate))

        if not "learning rate" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead
          bestLearningRate = initial_learning_rate
          lastLearningRate = initial_learning_rate
          print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)

#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used
    if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):
      print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)
      bestLearningRate = initial_learning_rate
      lastLearningRate = initial_learning_rate


# Display info about the pretrained model to be loaded (or not)
if Use_pretrained_model:
  print('Weights found in:')
  print(h5_file_path)
  print('will be loaded prior to training.')
else:
  print('No pretrained network will be used.')



In [None]:
#@markdown ### **4.3 Prepare network**
#@markdown Creation of the network and configuration of the optimizer
warnings.filterwarnings('ignore')

# A value of 0 means "not set": the number of steps is then derived from the
# share of the training patches that is not kept for validation.
if number_of_steps == 0:
  number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)

(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))

# Number of steps to use for validation (at least one)
validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))

# This checkpoint only saves the best model from the validation loss point of view
model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)

class_weights = getClassWeights(Training_target)

# --------------------- Using pretrained model ------------------------
# The training restarts from the learning rate recovered from the pretrained model
if Use_pretrained_model:
  if Weights_choice == "last":
    initial_learning_rate = lastLearningRate
  elif Weights_choice == "best":
    initial_learning_rate = bestLearningRate
else:
  h5_file_path = None

# Reduce learning rate on plateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto', patience=10, min_lr=0)

# Define the model
model = unet(input_size = (patch_width,patch_height,1),
             nchannels = nchannels, pooling_steps = pooling_steps,
             learning_rate = initial_learning_rate, class_weights = class_weights)

# --------------------- Using pretrained model ------------------------
# Load the pretrained weights. A failure is reported but does not stop the cell:
# the training then starts from the freshly initialised network.
if Use_pretrained_model:
  try:
    model.load_weights(h5_file_path)
  except Exception as load_error:
    print("The pretrained model could not be loaded as the configuration of the network is different.")
    print("Reported error: " + str(load_error))

config_model= model.optimizer.get_config()

print()
print("Optimizer configuration : " )
print(config_model)
print()

# ------------------ Failsafes ------------------
# The output folder is always recreated empty before training
if os.path.exists(full_model_path):
  print('!! WARNING: Model folder already existed and has been removed !!')
  print()
  shutil.rmtree(full_model_path)

os.makedirs(full_model_path)
os.makedirs(os.path.join(full_model_path,'Quality Control'))

# ------------------ Display ------------------
print('---------------------------- Main training parameters ----------------------------')
print('Number of epochs: '+str(number_of_epochs))
print('Batch size: '+str(batch_size))
print('Number of training dataset: '+str(number_of_training_dataset))
print('Number of training steps: '+str(number_of_steps))
print('Number of validation steps: '+str(validation_steps))
print('---------------------------- ------------------------ ----------------------------')

def pdf_export(trained = False, augmentation = False, pretrained_model = False):
  """Write a PDF report describing the training configuration of the model.

  The report is assembled from notebook-level variables (Network, model_name,
  model_path, Training_source, Training_target, the training parameters, ...)
  and saved as <model_path>/<model_name>/<model_name>_model_configuration.pdf.

  Args:
    trained: if True, a "Training time" line is added; it reads the
      notebook-level variables hour, mins and sec.
    augmentation: if True, the augmentations are listed according to the
      notebook-level augmentation settings (rotation_range, zoom_range, ...).
    pretrained_model: if True, the text states that the model was re-trained
      from a pretrained model instead of trained from scratch.
  """
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Training report for '+Network+' model ('+model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')

  if trained:
    training_time = "Training time: "+str(hour)+ "hour(s) "+str(mins)+"min(s) "+str(round(sec))+"sec(s)"
    pdf.cell(190, 5, txt = training_time, ln = 1, align='L')
  pdf.ln(1)

  Header_2 = 'Information for your materials and method:'
  pdf.cell(190, 5, txt=Header_2, ln=1, align='L')

  # All installed packages as one 'name==version, name==version, ' string
  all_packages = ''.join(requirement+', ' for requirement in freeze(local_only=True))

  # Version numbers of the main packages, in the order tensorflow, numpy, Keras.
  # The lookup is a plain substring search in the string built above.
  version_numbers = []
  for name in ['tensorflow','numpy','Keras']:
    find_name=all_packages.find(name)
    version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])

  cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)
  cuda_version = cuda_version.stdout.decode('utf-8')
  cuda_version = cuda_version[cuda_version.find(', V')+3:-1]
  # The GPU name is the 10 characters starting at 'Tesla' in the nvidia-smi output
  gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)
  gpu_name = gpu_name.stdout.decode('utf-8')
  gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]
  loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]

  # The image dimensions are read from the second entry of the training folder,
  # or from the only entry when the folder holds a single file.
  training_files = os.listdir(Training_source)
  example_file = training_files[1] if len(training_files) > 1 else training_files[0]
  shape = io.imread(Training_source+'/'+example_file).shape

  if pretrained_model:
    text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+'  and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'
  else:
    text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'

  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.multi_cell(180, 5, txt = text, align='L')
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(1)
  pdf.cell(28, 5, txt='Augmentation: ', ln=1)
  pdf.set_font('')
  if augmentation:
    aug_text = 'The dataset was augmented by'
    if rotation_range != 0:
      aug_text = aug_text+'\n- rotation'
    if horizontal_flip == True or vertical_flip == True:
      aug_text = aug_text+'\n- flipping'
    if zoom_range != 0:
      aug_text = aug_text+'\n- random zoom magnification'
    if horizontal_shift != 0 or vertical_shift != 0:
      aug_text = aug_text+'\n- shifting'
    if shear_range != 0:
      aug_text = aug_text+'\n- image shearing'
  else:
    aug_text = 'No augmentation was used for training.'
  pdf.multi_cell(190, 5, txt=aug_text, align='L')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.cell(200, 5, txt='The following parameters were used for training:')
  pdf.ln(1)
  # Table of the training parameters (every row is closed, including the last one)
  html = """ 
  <table width=40% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>image_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>batch_size</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>number_of_steps</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>percentage_validation</td>
      <td width = 50%>{4}</td>
    </tr>
    <tr>
      <td width = 50%>initial_learning_rate</td>
      <td width = 50%>{5}</td>
    </tr>
    <tr>
      <td width = 50%>pooling_steps</td>
      <td width = 50%>{6}</td>
    </tr>
    <tr>
      <td width = 50%>nchannels</td>
      <td width = 50%>{7}</td>
    </tr>
  </table>
  """.format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size,
             number_of_steps, percentage_validation, initial_learning_rate,
             pooling_steps, nchannels)
  pdf.write_html(html)

  pdf.set_font("Arial", size = 11, style='B')
  pdf.ln(1)
  pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_source, align = 'L')
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = Training_target, align = 'L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')
  pdf.ln(1)
  # Only the heading is written: no example image is embedded in the report
  pdf.cell(60, 5, txt = 'Example Training pair', ln=1)
  pdf.ln(1)
  pdf.ln(1)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(3)
  reminder = 'Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo'
  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')

  pdf.output(model_path+'/'+model_name+'/'+model_name+'_model_configuration.pdf')
  print('PDF report exported in '+model_path+'/'+model_name+'/')

# Print the summary of the network returned by model.summary()
model.summary()

# Optional steps, currently disabled:
# if Draw_network_model:
#   tf.keras.utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)


# if Save_PDF_network_model:
#   pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)


# **5. Training**

In [None]:
#@markdown ### **5.1 Start training**

start = time.time()
tracker = EmissionsTracker(log_level='error')
tracker.start()

# The tracker is stopped even when the training fails or is interrupted
try:
  history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps,
    epochs = number_of_epochs,
    callbacks=[model_checkpoint, reduce_lr],
    validation_data = validation_datagen,
    validation_steps = validation_steps,
    shuffle=True, verbose=1)
finally:
  emissions: float = tracker.stop()

print()
print("------------------------------------------")
print(f"[CodeCarbon] Emissions: {emissions} kg")

# The model of the last epoch is saved next to the best one written by the checkpoint
model.save(os.path.join(full_model_path, 'weights_last.hdf5'))

# Convert the history.history dict to a pandas DataFrame
lossData = pd.DataFrame(history.history)

# The training_evaluation.csv is saved (overwrites the file if needed).
# newline='' lets the csv module control the line endings, as its documentation requires.
lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')
with open(lossDataCSVpath, 'w', newline='') as f:
  writer = csv.writer(f)
  writer.writerow(['loss','val_loss', 'learning rate'])
  writer.writerows(zip(history.history['loss'], history.history['val_loss'], history.history['lr']))

# Displaying the time elapsed for training
dt = time.time() - start
mins, sec = divmod(dt, 60)
hour, mins = divmod(mins, 60)
# Whole numbers of hours and minutes, so that they are not reported as '0.0'
hour, mins = int(hour), int(mins)
print("------------------------------------------")
print("Training time:", hour, "hour(s)", mins,"min(s)",round(sec),"sec(s)")

print("------------------------------------------")
print('Last loss: ', history.history['loss'][-1],'Last val_loss:', history.history['val_loss'][-1])
print('Best loss: ', min(history.history['loss']),'Best val_loss:', min(history.history['val_loss']))

print("------------------------------------------")
pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)


In [None]:
#@markdown ### **5.2 Plot the learning curves**
model_weights = 'weights_best.hdf5'
model_save = os.path.join(model_path, model_name)
model_file = os.path.join(model_save, model_weights)

if os.path.exists(model_file):
  print("The "+model_weights+" network is evaluated")
else:
  print(R+'!! WARNING: The chosen model '+model_file+' does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')

# Read back the losses saved at the end of the training: column 0 holds the
# training loss and column 1 the validation loss.
lossDataFromCSV = []
vallossDataFromCSV = []

with open(os.path.join(model_save, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    next(csvRead)  # skip the header row
    for row in csvRead:
        if not row:
            continue  # ignore blank lines
        lossDataFromCSV.append(float(row[0]))
        vallossDataFromCSV.append(float(row[1]))

epochNumber = range(len(lossDataFromCSV))

plt.figure(figsize=(15,10))

# Top panel: linear scale
plt.subplot(2,1,1)
plt.plot(epochNumber,lossDataFromCSV, label='Training loss')
plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (linear scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()

# Bottom panel: logarithmic scale
plt.subplot(2,1,2)
plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')
plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (log scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()
plt.savefig(os.path.join(model_save, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)
plt.show()

# **6. Evaluate the trained model**

In [None]:

#@markdown  ### **6.1 Prediction on the test dataset**

#@markdown This section generates the probability maps of all images provided in the test_source folder

# Data_folder: folder with the images on which the trained network is applied.
# Probability_folder: folder that receives the probability maps (float32 images).

warnings.filterwarnings('ignore')

Data_folder = Source_QC
# The prediction folder is always recreated empty
Probability_folder = os.path.join(model_save,'Prediction')
if os.path.exists(Probability_folder):
  shutil.rmtree(Probability_folder)
os.makedirs(Probability_folder)


full_Probability_model_path = os.path.join(model_save, model_weights)
if os.path.exists(full_Probability_model_path):
  print("The model "+full_Probability_model_path+" will be used.")
else:
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')

print()
# Load the model and prepare generator
model_save = full_model_path
model_weights =  "weights_best.hdf5"

# NOTE: `unet` is rebound here to the loaded Keras model, while the network is
# built earlier in the notebook by calling unet(...). The cell defining that
# function has to be run again before a new network can be built.
unet = load_model(os.path.join(model_save, model_weights), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})
Input_size = unet.layers[0].output_shape[0]

# Create a list of sources
source_dir_list = os.listdir(Data_folder)
number_of_dataset = len(source_dir_list)
print('Number of dataset found in the folder: '+str(number_of_dataset))

for source_name in tqdm(source_dir_list):
  # Hidden entries (e.g. '.ipynb_checkpoints') are not images
  if source_name.startswith('.'):
    continue
  root_filename = os.path.splitext(source_name)
  prediction = predict_as_tiles(os.path.join(Data_folder, source_name), unet)
  # The probability map is saved as a float32 TIFF named 'proba_<source name>.tif'
  io.imsave(os.path.join(Probability_folder, 'proba_'+root_filename[0]+'.tif'), img_as_float32(prediction))
      


---

# Exercise session

1.   Do a new training with a different number of epochs (between 10 and 50)
*ATTENTION: You only need to re-run cell 1.2 and the cells from 4.1 to 6.1 with the appropriate parameters*

2.   With ImageJ, compute the **IoU** score for **3** predictions:
    *   Do a **threshold** to obtain a binary mask
    *   Perform a **watershed** to properly separate touching cells
    *   Compute the **IoU** score:
$$\frac{Target \ \cap \ Prediction}{Target \ \cup \ Prediction}$$

    *   Save the results in the Excel file that is on the Google Drive of the course

3.   To better visualize the quality of the obtained segmentation, with ImageJ, do a composite image to superimpose the binary masks on the corresponding test images.

4.   If you have time, you can repeat it for different numbers of training epochs


---

In [None]:
# ------------- User input ------------
#@markdown ### **6.2 Quality Control** 
print('The model ' + os.path.join(model_save, model_weights) + ' will be used')
#@markdown This section calculates the **IoU** score for all the images provided in the test_source and test_target. 
#@markdown During this step, a threshold of 0.5 is applied on each probability map.

warnings.filterwarnings('ignore')

# ------------- Prepare the model and run predictions ------------

# Load the model
unet = load_model( os.path.join(model_save, model_weights), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})
Input_size = unet.input_shape[1:3]

# Create a list of sources
source_dir_list = os.listdir(Source_QC)
number_of_dataset = len(source_dir_list)
print('Number of images found in the folder: '+str(number_of_dataset))


# Compute the metrics
intersection_list = []
union_list = []
iou_list = []
print()
print('Filename, Intersection, Union, IoU')

for filename in tqdm(source_dir_list):

  if filename.startswith('.'):
    print('Rejected ', filename)
    continue
  # Only TIFF and PNG images are evaluated; other entries are silently skipped
  if not filename.endswith(('.tif', '.tiff', '.png')):
    continue

  # The target is read from Target_QC under the same file name as the source
  prediction_image = predict_as_tiles(os.path.join(Source_QC, filename), unet)
  test_target_image = io.imread(os.path.join(Target_QC, filename), as_gray=True)

  # Binary masks: any non-zero 8-bit value is foreground in the target, and
  # values above 128 (out of 255) are foreground in the prediction.
  mask_target = img_as_ubyte(test_target_image, force_copy=True) > 0
  mask_prediction = img_as_ubyte(prediction_image, force_copy=True) > 128

  intersection = np.sum(np.logical_and(mask_target, mask_prediction))
  union = np.sum(np.logical_or(mask_target, mask_prediction))
  iou_score = intersection / union

  intersection_list.append(intersection)
  union_list.append(union)
  iou_list.append(iou_score)

  print(filename, intersection, union, round(iou_score,5))

# The averages are taken over the evaluated images only, not over all the
# entries of the folder (rejected and skipped entries have no score).
number_of_evaluated = len(iou_list)

print('--------------------------------------------------------------')
if number_of_evaluated == 0:
  print('No image could be evaluated: the averages are not available.')
else:
  average_intersection = sum(intersection_list)/number_of_evaluated
  average_union = sum(union_list)/number_of_evaluated
  average_iou = sum(iou_list)/number_of_evaluated
  print('Average', round(average_intersection,3), round(average_union,3), round(average_iou,5))
print('--------------------------------------------------------------')


# **7. Save your model**


In [None]:
#@markdown ### **7.1. Export the model as BMZ format**

#@markdown This section exports the model into the BioImage Model Zoo (BMZ) format so it can be used directly with DeepImageJ. 

warnings.filterwarnings('ignore')

# information about the model
Trained_model_name    = model_name
Trained_model_authors =  "" #@param {type:"string"}
Trained_model_authors_affiliation =  "" #@param {type:"string"}
Trained_model_description = "Segmentation_UNet_Gothenburg2022" 
Trained_model_license = 'MIT'
# References and DOIs are paired by position: both lists must keep the same order
Trained_model_references = ["O. Ronneberger et al. MICCAI 2015",
                            "L. von Chamier et al., Nature Communications, 2021",
                            "W. Ouyang et al., biorxiv  2022",
                            "E. Gómez-de-Mariscal et al., Nature Methods 2021"]
Trained_model_DOI = ["https://arxiv.org/abs/1505.04597",
                     "https://www.nature.com/articles/s41467-021-22518-0",
                     "https://www.biorxiv.org/content/10.1101/2022.06.07.495102v1",
                     "https://doi.org/10.1038/s41592-021-01262-9"]

print('Trained_model_name: ', Trained_model_name)
print('Trained_model_authors: ', Trained_model_authors)
print('Trained_model_authors_affiliation: ', Trained_model_authors_affiliation)
print('Trained_model_description: ', Trained_model_description)
print('Trained_model_license: ', Trained_model_license)
print('Trained_model_references: ', Trained_model_references)
print('Trained_model_DOI: ', Trained_model_DOI)

# Training data
include_training_data = False
data_from_bioimage_model_zoo = False
training_data_ID = ''
training_data_source = ''
training_data_description = ''
apply_threshold = True
Use_The_Best_Average_Threshold = True
# Threshold given on the 0-255 scale; it is divided by 255 below
average_best_threshold = 128
threshold = average_best_threshold
PixelSize = 1
default_example_image = True

QC_model_folder = model_save

if Use_The_Best_Average_Threshold:
    threshold = average_best_threshold
threshold /= 255.0

# The example image is the first entry of Source_QC that is not hidden
# (e.g. not '.ipynb_checkpoints'). When default_example_image is False,
# fileID has to be set to the path of the example image instead.
if default_example_image:
    source_dir_list = os.listdir(Source_QC)
    visible_files = [name for name in source_dir_list if not name.startswith('.')]
    if visible_files:
        fileID = os.path.join(Source_QC, visible_files[0])
    else:
        fileID = os.path.join(Source_QC, source_dir_list[0])

print()
print("------------------------------------------")
print('PARAMETERS')
print('threshold: ', threshold)
print('PixelSize: ', PixelSize)
print('Test image: ', fileID)

# load the model
compiled_weight_path = os.path.join(model_save, model_weights)
print('Weights: ', compiled_weight_path)
unet = load_model(compiled_weight_path, custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})
print()

# remove the custom loss function from the model, so that it can be used outside of this notebook
unet = Model(unet.input, unet.output)
weight_path = os.path.join(model_save, 'keras_weights.hdf5')
unet.save(weight_path)


def _split_form_list(text):
  """Split a comma-separated form field into a list of stripped items.

  Both "[Name 1, Name 2]" and "Name 1, Name 2" are accepted: the brackets are
  removed only when they are present, so that plain text does not lose its
  first and last characters. An empty field gives [''].
  """
  text = text.strip()
  if text.startswith('[') and text.endswith(']'):
    text = text[1:-1]
  return [item.strip() for item in text.split(',')]


# create the author spec input
auth_names = _split_form_list(Trained_model_authors)
auth_affs = _split_form_list(Trained_model_authors_affiliation)
if len(auth_names) != len(auth_affs):
  raise ValueError('Trained_model_authors and Trained_model_authors_affiliation must contain '
                   'the same number of comma-separated entries ('
                   + str(len(auth_names)) + ' and ' + str(len(auth_affs)) + ' found).')
authors = [{"name": auth_name, "affiliation": auth_aff} for auth_name, auth_aff in zip(auth_names, auth_affs)]
license = Trained_model_license

# where to save the model
output_root = os.path.join(model_path, Trained_model_name + '.bioimage.io.model')
os.makedirs(output_root, exist_ok=True)
output_path = os.path.join(output_root, f"{Trained_model_name}.zip")

# create a markdown readme with information
readme_path = os.path.join(output_root, "README.md")
with open(readme_path, "w") as f:
  f.write("Visit https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki")

# create the citation input spec (references and DOIs are paired by position)
assert len(Trained_model_DOI) == len(Trained_model_references)
citations = [{'text': text, 'doi': doi} for text, doi in zip(Trained_model_references, Trained_model_DOI)]

# Describe the training data: either an id, or a source with a description;
# an empty dict when no training data is declared.
if include_training_data:
    if data_from_bioimage_model_zoo:
      training_data = {"id": training_data_ID}
    else:
      training_data = {"source": training_data_source,
                       "description": training_data_description}
else:
    training_data={}

# create the input spec
min_percentile = 1
max_percentile = 99.8
# Static input shape of the model as a list (None for an unconstrained axis)
shape = [sh.value for sh in unet.input.shape]

# batch should never be constrained
assert shape[0] is None
shape[0] = 1  # batch is set to 1 for bioimage.io
assert all(sh is not None for sh in shape)  # make sure all other shapes are fixed
pixel_size = {"x": PixelSize, "y": PixelSize}
# Keyword arguments describing the input: they are completed with the output
# spec below and forwarded to build_model(**kwargs) later in this cell.
kwargs = dict(
  input_names=["input"],
  input_axes=["bxyc"],
  pixel_sizes=[pixel_size],
  preprocessing=[[{"name": "scale_range", "kwargs": {"min_percentile": min_percentile, "max_percentile": max_percentile, "mode": "per_sample","axes": "xyc"}}]]
)
shape = tuple(shape)

print()
print("------------------------------------------")
print('INFORMATION')

# Optional binarization of the output with the threshold defined above
if apply_threshold:
  print("The model output is thresholded")
  postprocessing = [[{"name": "binarize", "kwargs": {"threshold": threshold}}]]
else:
  print("The model output is not thresholded")
  postprocessing = None

# create the output spec and merge it with the input spec
output_spec = dict(
  output_names=["output"],
  output_axes=["bxyc"],
  postprocessing=postprocessing
)
kwargs.update(output_spec)

# load the input image and crop it to the input size of the model;
# the image must be 2D and at least as large as the model input
test_img = io.imread(fileID)
assert test_img.ndim == 2
test_img = test_img[:shape[1], :shape[2]]
assert test_img.shape == shape[1:3], f"{test_img.shape}, {shape}"

# Save the test image (not normalized) as numpy file
test_in_path = os.path.join(output_root, "test_input.npy")

np.save(test_in_path, test_img[None, ..., None])  # add batch and channel axis
# Normalize the image before adding batch and channel dimensions
# NOTE(review): presumably normalizePercentile matches the scale_range preprocessing declared above - confirm
test_img = normalizePercentile(test_img.astype("float32"))
test_img = test_img[None, ..., None] 
# run prediction on the normalized input image
test_prediction = unet.predict(test_img)

# apply the same thresholding as declared in the output spec
if apply_threshold:
  test_prediction = np.squeeze(test_prediction) > threshold
  test_prediction = test_prediction.astype(np.uint8)
else:
  test_prediction = np.squeeze(test_prediction)
assert test_prediction.ndim == 2
test_prediction = test_prediction[None, ..., None]  # add batch and channel axis

# save the result as expected output
test_out_path = os.path.join(output_root, "test_output.npy")
np.save(test_out_path, test_prediction)

# attach the QC report to the model (if it exists)
# The model-specific file name is tried first, then the training_evaluation.csv
# that the training cell writes in the 'Quality Control' folder.
qc_candidates = [
  os.path.join(model_save, 'Quality Control', 'training_evaluation-' + model_name + '.csv'),
  os.path.join(model_save, 'Quality Control', 'training_evaluation.csv'),
]
qc_path = next((candidate for candidate in qc_candidates if os.path.exists(candidate)), qc_candidates[0])
if os.path.exists(qc_path):
  attachments = {"files": [qc_path]}
else:
  attachments = None

# export the model with keras weights
build_model(
    weight_uri=weight_path,
    test_inputs=[test_in_path],
    test_outputs=[test_out_path],
    name=Trained_model_name,
    description=Trained_model_description,
    authors=authors,
    tags=['zerocostdl4mic', 'deepimagej', 'segmentation', 'tem', 'unet'],
    license=license,
    documentation=readme_path,
    cite=citations,
    output_path=output_path,
    add_deepimagej_config=True,
    tensorflow_version=tf.__version__,
    attachments=attachments,
    training_data = training_data,
    **kwargs
)


# convert the keras weights to tensorflow and add them to the model
tf_weight_path = os.path.join(QC_model_folder, "tf_weights")

# we need to make sure that the tf weight folder does not exist
if os.path.exists(tf_weight_path):
  shutil.rmtree(tf_weight_path)

convert_weights_to_tensorflow_saved_model_bundle(output_path, tf_weight_path + ".zip")
add_weights(output_path, tf_weight_path + ".zip", output_path, tensorflow_version=tf.__version__)

# check that the model works for keras and tensorflow
success = True
for weight_format, weight_label in (("keras_hdf5", "keras"),
                                    ("tensorflow_saved_model_bundle", "tensorflow")):
  res = test_model(output_path, weight_format=weight_format)
  if res["error"] is not None:
    success = False
    print("test-model failed for " + weight_label + " weights:", res["error"])

if success:
  print("The bioimage.io model was successfully exported to", output_path)
else:
  print("The bioimage.io model was exported to", output_path)
  print("Some tests of the model did not work!")
  print("You can still download and test the model, but it may not work as expected.")