# **StarDist (3D)**
---

<font size = 4>**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages. It was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending the 2D approach of [Schmidt *et al.* (2018)](https://arxiv.org/abs/1806.03535) to 3D. It represents each nucleus as a star-convex polyhedron and uses this shape representation to predict both the presence and the shape of the nuclei in an image. This StarDist 3D network is based on an adapted ResNet network architecture.
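<font size = 4>To make the star-convex representation more concrete, the sketch below describes an object by the distance from a centre voxel to the object's edge along a set of fixed ray directions; StarDist predicts such distances along its rays. This is a simplified illustration, not StarDist's implementation: `ray_distances` is a hypothetical helper that marches voxel by voxel, so its resolution is limited to `step`.

```python
import numpy as np


def ray_distances(labels, center, rays, step=1.0):
    """Distance from `center` to the edge of its object along each unit ray.

    labels: 3D integer label image (0 = background)
    center: (z, y, x) voxel inside the object
    rays:   (n_rays, 3) array of unit vectors in (z, y, x) order
    step:   marching step in voxels (limits the resolution of the result)
    """
    center = np.asarray(center, dtype=float)
    shape = np.array(labels.shape)
    obj = labels[tuple(center.astype(int))]  # label of the object we start in
    dists = np.zeros(len(rays))
    for k, ray in enumerate(np.asarray(rays, dtype=float)):
        d = 0.0
        while True:
            # Voxel reached by taking one more step along this ray
            p = np.round(center + (d + step) * ray).astype(int)
            if np.any(p < 0) or np.any(p >= shape) or labels[tuple(p)] != obj:
                break  # the next step would leave the image or the object
            d += step
        dists[k] = d
    return dists


# Example: a sphere of radius 10 centred in a 41 x 41 x 41 volume
zz, yy, xx = np.mgrid[:41, :41, :41]
sphere = ((zz - 20) ** 2 + (yy - 20) ** 2 + (xx - 20) ** 2 <= 100).astype(np.uint16)
axis_rays = np.array([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
print(ray_distances(sphere, (20, 20, 20), axis_rays))
```

For a sphere, all rays return the same distance (its radius); for an elongated nucleus, the distances differ by direction, and this set of distances is what encodes the shape.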

<font size = 4> **This particular notebook enables nuclei segmentation of 3D datasets. If you are interested in 2D datasets, you should use the StarDist 2D notebook instead.**

---
<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It was jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.

<font size = 4>This notebook is largely based on the paper:

<font size = 4>**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)

<font size = 4>and the 3D extension of the approach:

<font size = 4>**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)

<font size = 4>**The original code** is freely available on GitHub:
https://github.com/mpicbg-csbd/stardist

<font size = 4>**Please also cite these original papers when using or developing this notebook.**



# **How to use this notebook?**

---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
### **Structure of a notebook**

<font size = 4>The notebook contains two types of cells:

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears) and click it. When execution is done, the play button animation stops. You can create a new code cell by clicking `+ Code`.

---
### **Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs, which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here.

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
### **Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.


# **0. Before getting started**
---
<font size = 4> For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki

<font size = 4>**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.

<font size = 4>The data structure is important. All the input data must be in the same folder, and all the output data (masks) must be in a separate folder. The provided training dataset is already split in two folders called "Training - Images" (Training_source) and "Training - Masks" (Training_target).

<font size = 4>Additionally, the corresponding Training_source and Training_target files need to have **the same name**.

<font size = 4>Please note that you currently can **only use .tif files!**
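<font size = 4>Because images and masks are matched by file name, it can help to check the two folders before training. The hypothetical helper below (not part of the notebook) lists which `.tif` files are paired and which are missing a partner, using only the Python standard library.

```python
from pathlib import Path


def check_training_pairs(source_dir, target_dir):
    """Compare .tif file names in the source (images) and target (masks) folders.

    Returns (paired, missing_mask, missing_image) as sorted lists of file names.
    """
    def tif_names(folder):
        # Only .tif files are considered, as required by this notebook
        return {p.name for p in Path(folder).iterdir()
                if p.is_file() and p.suffix.lower() == ".tif"}

    src, tgt = tif_names(source_dir), tif_names(target_dir)
    return sorted(src & tgt), sorted(src - tgt), sorted(tgt - src)
```

For example, `check_training_pairs(Training_source, Training_target)` could be run once the paths are set in section 2; any name in the second or third list points to an image without a mask or a mask without an image.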

<font size = 4>You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.

<font size = 4>Here's a common data structure that can work:
*   Experiment A
    - **Training dataset**
      - Images of nuclei (Training_source)
        - img_1.tif, img_2.tif, ...
      - Masks (Training_target)
        - img_1.tif, img_2.tif, ...
    - **Quality control dataset**
      - Images of nuclei
        - img_1.tif, img_2.tif
      - Masks
        - img_1.tif, img_2.tif
    - **Data to be predicted**
    - **Results**
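<font size = 4>If you prefer to create this layout programmatically, a minimal sketch with `pathlib` is shown below. The helper name and the exact folder names are illustrative assumptions mirroring the example above; any names work as long as images and masks are kept in separate folders.

```python
from pathlib import Path


def make_experiment_folders(root):
    """Create the example folder layout under `root` (folder names are illustrative)."""
    layout = [
        "Training dataset/Images of nuclei",
        "Training dataset/Masks",
        "Quality control dataset/Images of nuclei",
        "Quality control dataset/Masks",
        "Data to be predicted",
        "Results",
    ]
    root = Path(root)
    for sub in layout:
        # parents=True creates intermediate folders; exist_ok=True makes reruns safe
        (root / sub).mkdir(parents=True, exist_ok=True)
    return [root / sub for sub in layout]
```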

---
<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.
---


# **1. Install StarDist and dependencies**
---



## **1.1. Load key dependencies**
---
<font size = 4> 


In [None]:
# Run this cell to execute the code
from __future__ import print_function, unicode_literals, absolute_import, division
from datetime import datetime
import ipywidgets as widgets
from IPython.display import Markdown, display, clear_output
from matplotlib import pyplot as plt
import yaml as yaml_library
import os

ipywidgets_edit_yaml_config_path = os.path.join(
    os.getcwd(), "results", "widget_prev_settings.yaml"
)


def ipywidgets_edit_yaml(yaml_path, key, value):
    if os.path.exists(yaml_path):
        with open(yaml_path, "r") as f:
            config_data = yaml_library.safe_load(f)
    else:
        config_data = {}
    config_data[key] = value
    with open(yaml_path, "w") as new_f:
        yaml_library.safe_dump(
            config_data,
            new_f,
            width=10e10,
            default_flow_style=False,
            allow_unicode=True,
        )


def ipywidgets_read_yaml(yaml_path, key):
    if os.path.exists(yaml_path):
        with open(yaml_path, "r") as f:
            config_data = yaml_library.safe_load(f)
        value = config_data.get(key, "")
        return value
    else:
        return ""


internal_aux_initial_time = datetime.now()
print("Runnning...")
print("--------------------------------------")
Notebook_version = "1.15.3"
Network = "StarDist 3D"

from builtins import any as b_any


def get_requirements_path():
    # Store requirements file in 'contents' directory
    current_dir = os.getcwd()
    dir_count = current_dir.count("/") - 1
    path = "../" * (dir_count) + "requirements.txt"
    return path


def filter_files(file_list, filter_list):
    filtered_list = []
    for fname in file_list:
        if b_any(fname.split("==")[0] in s for s in filter_list):
            filtered_list.append(fname)
    return filtered_list


def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file (pip freeze output has no header row)
    df = pd.read_csv(path, header=None)
    mod_list = [m.split(".")[0] for m in after if m not in before]
    req_list_temp = df.values.tolist()
    req_list = [x[0] for x in req_list_temp]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_map = {"sklearn": "scikit-learn", "skimage": "scikit-image"}
    mod_replace_list = [mod_name_map.get(s, s) for s in mod_list]
    filtered_list = filter_files(req_list, mod_replace_list)

    # Write one requirement per line
    with open(path, "w") as file:
        for item in filtered_list:
            file.write(item + "\n")


import sys

before = [str(m) for m in sys.modules]

# %tensorflow_version 1.x
import tensorflow

print(tensorflow.__version__)
print("Tensorflow enabled.")

# ------- Variable specific to Stardist -------
from stardist import (
    fill_label_holes,
    random_label_cmap,
    calculate_extents,
    gputools_available,
)
from stardist.models import Config3D, StarDist3D, StarDistData3D
from stardist import relabel_image_stardist3D, Rays_GoldenSpiral
from stardist.matching import matching_dataset
from csbdeep.utils import (
    Path,
    normalize,
    download_and_extract_zip_file,
    plot_history,
)  # for loss plot
from csbdeep.io import save_tiff_imagej_compatible
import numpy as np

np.random.seed(42)
lbl_cmap = random_label_cmap()
import cv2

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# ------- Common variable to all ZeroCostDL4Mic notebooks -------
import numpy as np
from matplotlib import pyplot as plt
import urllib
import os, random
import shutil
import zipfile
from tifffile import imread, imsave
from csbdeep.io import save_tiff_imagej_compatible
import time
import sys
import wget
from pathlib import Path
import pandas as pd
import csv
from glob import glob
from scipy import signal
from scipy import ndimage
from skimage import io
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
import matplotlib as mpl
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from astropy.visualization import simple_norm
from skimage import img_as_float32
from skimage.util import img_as_ubyte
from tqdm import tqdm
from fpdf import FPDF, HTMLMixin
from datetime import datetime
from pip._internal.operations.freeze import freeze
import subprocess

# For sliders and dropdown menu and progress bar
from ipywidgets import interact
import ipywidgets as widgets

# Create a variable to get and store relative base path
base_path = os.getcwd()


# Colors for the warning messages
class bcolors:
    WARNING = "\033[31m"


W = "\033[0m"  # white (normal)
R = "\033[31m"  # red

# Suppress Python warnings (this also hides some TensorFlow warnings)
import warnings

warnings.filterwarnings("ignore")

print("Libraries installed")

# Check if this is the latest version of the notebook
All_notebook_versions = pd.read_csv(
    "https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv",
    dtype=str,
)
print("Notebook version: " + Notebook_version)
Latest_Notebook_version = All_notebook_versions[
    All_notebook_versions["Notebook"] == Network
]["Version"].iloc[0]
print("Latest notebook version: " + Latest_Notebook_version)
if Notebook_version == Latest_Notebook_version:
    print("This notebook is up-to-date.")
else:
    print(
        bcolors.WARNING
        + "A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki"
    )


def pdf_export(trained=False, augmentation=False, pretrained_model=False):
    class MyFPDF(FPDF, HTMLMixin):
        pass

    pdf = MyFPDF()
    pdf.add_page()
    pdf.set_right_margin(-1)
    pdf.set_font("Arial", size=11, style="B")

    day = datetime.now()
    datetime_str = str(day)[0:10]

    Header = (
        "Training report for "
        + Network
        + " model ("
        + model_name
        + ")\nDate: "
        + datetime_str
    )
    pdf.multi_cell(180, 5, txt=Header, align="L")
    pdf.ln(1)

    # add another cell
    if trained:
        training_time = (
            "Training time: "
            + str(hour)
            + " hour(s) "
            + str(mins)
            + " min(s) "
            + str(round(sec))
            + " sec(s)"
        )
        pdf.cell(190, 5, txt=training_time, ln=1, align="L")
    pdf.ln(1)

    Header_2 = "Information for your materials and methods:"
    pdf.cell(190, 5, txt=Header_2, ln=1, align="L")

    all_packages = ""
    for requirement in freeze(local_only=True):
        all_packages = all_packages + requirement + ", "
    # print(all_packages)

    # Main Packages
    main_packages = ""
    version_numbers = []
    for name in ["tensorflow", "numpy", "keras", "csbdeep"]:
        find_name = all_packages.find(name)
        main_packages = (
            main_packages
            + all_packages[find_name : all_packages.find(",", find_name)]
            + ", "
        )
        # Version numbers only here:
        version_numbers.append(
            all_packages[find_name + len(name) + 2 : all_packages.find(",", find_name)]
        )

    try:
        cuda_version = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
        cuda_version = cuda_version.stdout.decode("utf-8")
        cuda_version = cuda_version[cuda_version.find(", V") + 3 : -1]
    except:
        cuda_version = " - No cuda found - "
    try:
        gpu_name = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE)
        gpu_name = gpu_name.stdout.decode("utf-8")
        gpu_name = gpu_name[gpu_name.find("Tesla") : gpu_name.find("Tesla") + 10]
    except:
        gpu_name = " - No GPU found - "
    # print(cuda_version[cuda_version.find(', V')+3:-1])
    # print(gpu_name)

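    # Use the first .tif image (works even if the folder holds a single image)
    first_tif = [f for f in sorted(os.listdir(Training_source)) if f.lower().endswith(".tif")][0]
    shape = io.imread(os.path.join(Training_source, first_tif)).shape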
    dataset_size = len(os.listdir(Training_source))

    text = (
        "The "
        + Network
        + " model was trained from scratch for "
        + str(number_of_epochs)
        + " epochs on "
        + str(int(len(X)))
        + " paired image patches (image dimensions: "
        + str(shape)
        + ", patch size: ("
        + str(patch_height)
        + ","
        + str(patch_size)
        + ","
        + str(patch_size)
        + ")) with a batch size of "
        + str(batch_size)
        + " and a "
        + conf.train_dist_loss
        + " loss function, using the "
        + Network
        + " ZeroCostDL4Mic notebook (v "
        + Notebook_version
        + ") (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v "
        + version_numbers[0]
        + "), keras (v "
        + version_numbers[2]
        + "), csbdeep (v "
        + version_numbers[3]
        + "), numpy (v "
        + version_numbers[1]
        + "), cuda (v "
        + cuda_version
        + "). The training was accelerated using a "
        + gpu_name
        + "GPU."
    )

    # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'

    if pretrained_model:
        text = (
            "The "
            + Network
            + " model was trained for "
            + str(number_of_epochs)
            + " epochs on "
            + str(int(len(X)))
            + " paired image patches (image dimensions: "
            + str(shape)
            + ", patch size: ("
            + str(patch_height)
            + ","
            + str(patch_size)
            + ","
            + str(patch_size)
            + ")) with a batch size of "
            + str(batch_size)
            + " and a "
            + conf.train_dist_loss
            + " loss function, using the "
            + Network
            + " ZeroCostDL4Mic notebook (v "
            + Notebook_version
            + ") (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v "
            + version_numbers[0]
            + "), keras (v "
            + version_numbers[2]
            + "), csbdeep (v "
            + version_numbers[3]
            + "), numpy (v "
            + version_numbers[1]
            + "), cuda (v "
            + cuda_version
            + "). The training was accelerated using a "
            + gpu_name
            + "GPU."
        )

    pdf.set_font("")
    pdf.set_font_size(10.0)
    pdf.multi_cell(190, 5, txt=text, align="L")
    pdf.ln(1)
    pdf.set_font("")
    pdf.set_font("Arial", size=10, style="B")
    pdf.cell(28, 5, txt="Augmentation: ", ln=0)
    pdf.set_font("")
    if augmentation:
        aug_text = "The dataset was augmented."
    else:
        aug_text = "No augmentation was used for training."
    pdf.multi_cell(190, 5, txt=aug_text, align="L")
    pdf.ln(1)
    pdf.set_font("Arial", size=11, style="B")
    pdf.cell(180, 5, txt="Parameters", align="L", ln=1)
    pdf.set_font("")
    pdf.set_font_size(10.0)
    if Use_Default_Advanced_Parameters:
        pdf.cell(200, 5, txt="Default Advanced Parameters were enabled")
    pdf.cell(200, 5, txt="The following parameters were used for training:")
    pdf.ln(1)
    html = """ 
  <table width=40% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>patch_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>batch_size</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>number_of_steps</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>percentage_validation</td>
      <td width = 50%>{4}</td>
    </tr>
      <tr>
      <td width = 50%>n_rays</td>
      <td width = 50%>{5}</td>
    </tr>
    <tr>
      <td width = 50%>initial_learning_rate</td>
      <td width = 50%>{6}</td>
    </tr>
  </table>
  """.format(
        number_of_epochs,
        str(patch_height) + "x" + str(patch_size) + "x" + str(patch_size),
        batch_size,
        number_of_steps,
        percentage_validation,
        n_rays,
        initial_learning_rate,
    )
    pdf.write_html(html)

    # pdf.multi_cell(190, 5, txt = text_2, align='L')
    pdf.set_font("Arial", size=11, style="B")
    pdf.ln(1)
    pdf.cell(190, 5, txt="Training Dataset", align="L", ln=1)
    pdf.set_font("")
    pdf.set_font("Arial", size=10, style="B")
    pdf.cell(30, 5, txt="Training_source:", align="L", ln=0)
    pdf.set_font("")
    pdf.multi_cell(170, 5, txt=Training_source, align="L")
    pdf.ln(1)
    pdf.set_font("")
    pdf.set_font("Arial", size=10, style="B")
    pdf.cell(28, 5, txt="Training_target:", align="L", ln=0)
    pdf.set_font("")
    pdf.multi_cell(170, 5, txt=Training_target, align="L")
    pdf.ln(1)
    pdf.set_font("")
    pdf.set_font("Arial", size=10, style="B")
    pdf.cell(21, 5, txt="Model Path:", align="L", ln=0)
    pdf.set_font("")
    pdf.multi_cell(170, 5, txt=model_path + "/" + model_name, align="L")
    pdf.ln(1)
    pdf.cell(60, 5, txt="Example Training pair", ln=1)
    pdf.ln(1)
    exp_size = io.imread(base_path + "/TrainingDataExample_StarDist3D.png").shape
    pdf.image(
        base_path + "/TrainingDataExample_StarDist3D.png",
        x=11,
        y=None,
        w=round(exp_size[1] / 8),
        h=round(exp_size[0] / 8),
    )
    pdf.ln(1)
    ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
    pdf.multi_cell(190, 5, txt=ref_1, align="L")
    pdf.ln(1)
    ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. "Cell detection with star-convex polygons." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'
    pdf.multi_cell(190, 5, txt=ref_2, align="L")
    pdf.ln(1)
    ref_3 = '- StarDist 3D: Weigert, Martin, et al. "Star-convex polyhedra for 3d object detection and segmentation in microscopy." The IEEE Winter Conference on Applications of Computer Vision. 2020.'
    pdf.multi_cell(190, 5, txt=ref_3, align="L")
    # if Use_Data_augmentation:
    #   ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. "Augmentor: an image augmentation library for machine learning." arXiv preprint arXiv:1708.04680 (2017).'
    #   pdf.multi_cell(190, 5, txt = ref_4, align='L')
    pdf.ln(3)
    reminder = "Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo"
    pdf.set_font("Arial", size=11, style="B")
    pdf.multi_cell(190, 5, txt=reminder, align="C")
    pdf.ln(1)

    pdf.output(
        model_path + "/" + model_name + "/" + model_name + "_training_report.pdf"
    )


def qc_pdf_export():
    class MyFPDF(FPDF, HTMLMixin):
        pass

    pdf = MyFPDF()
    pdf.add_page()
    pdf.set_right_margin(-1)
    pdf.set_font("Arial", size=11, style="B")

    Network = "Stardist 3D"

    day = datetime.now()
    datetime_str = str(day)[0:10]

    Header = (
        "Quality Control report for "
        + Network
        + " model ("
        + QC_model_name
        + ")\nDate: "
        + datetime_str
    )
    pdf.multi_cell(180, 5, txt=Header, align="L")
    pdf.ln(1)

    all_packages = ""
    for requirement in freeze(local_only=True):
        all_packages = all_packages + requirement + ", "

    pdf.set_font("")
    pdf.set_font("Arial", size=11, style="B")
    pdf.ln(2)
    pdf.cell(190, 5, txt="Development of Training Losses", ln=1, align="L")
    pdf.ln(1)
    loss_plot_path = full_QC_model_path + "/Quality Control/lossCurvePlots.png"
    if os.path.exists(loss_plot_path):
        # Only read the image once we know it exists
        exp_size = io.imread(loss_plot_path).shape
        pdf.image(
            loss_plot_path,
            x=11,
            y=None,
            w=round(exp_size[1] / 8),
            h=round(exp_size[0] / 8),
        )
    else:
        pdf.set_font("")
        pdf.set_font("Arial", size=10)
        pdf.multi_cell(
            190,
            5,
            txt="If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.",
        )
    pdf.ln(2)
    pdf.set_font("")
    pdf.set_font("Arial", size=10, style="B")
    pdf.ln(3)
    pdf.cell(80, 5, txt="Example Quality Control Visualisation", ln=1)
    pdf.ln(1)
    exp_size = io.imread(
        full_QC_model_path + "/Quality Control/QC_example_data.png"
    ).shape
    pdf.image(
        full_QC_model_path + "/Quality Control/QC_example_data.png",
        x=16,
        y=None,
        w=round(exp_size[1] / 10),
        h=round(exp_size[0] / 10),
    )
    pdf.ln(1)
    pdf.set_font("")
    pdf.set_font("Arial", size=11, style="B")
    pdf.ln(1)
    pdf.cell(180, 5, txt="Quality Control Metrics", align="L", ln=1)
    pdf.set_font("")
    pdf.set_font_size(10.0)

    pdf.ln(1)
    html = """
  <body>
  <font size="10" face="Courier" >
  <table width=50% style="margin-left:0px;">"""
    with open(
        full_QC_model_path
        + "/Quality Control/Quality_Control for "
        + QC_model_name
        + ".csv",
        "r",
    ) as csvfile:
        metrics = csv.reader(csvfile)
        header = next(metrics)
        image = header[0]
        PvGT_IoU = header[1]
        header = """
    <tr>
    <th width = 50% align="center">{0}</th>
    <th width = 50% align="center">{1}</th>
    </tr>""".format(image, PvGT_IoU)
        html = html + header
        for row in metrics:
            image = row[0]
            PvGT_IoU = row[1]
            cells = """
        <tr>
          <td width = 50% align="center">{0}</td>
          <td width = 50% align="center">{1}</td>
        </tr>""".format(image, str(round(float(PvGT_IoU), 3)))
            html = html + cells
        html = html + """</body></table>"""

    pdf.write_html(html)

    pdf.ln(1)
    pdf.set_font("")
    pdf.set_font_size(10.0)
    ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
    pdf.multi_cell(190, 5, txt=ref_1, align="L")
    pdf.ln(1)
    ref_2 = ' - Weigert, Martin, et al. "Star-convex polyhedra for 3d object detection and segmentation in microscopy." The IEEE Winter Conference on Applications of Computer Vision. 2020.'
    pdf.multi_cell(190, 5, txt=ref_2, align="L")
    pdf.ln(1)

    pdf.ln(3)
    reminder = "To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name."

    pdf.set_font("Arial", size=11, style="B")
    pdf.multi_cell(190, 5, txt=reminder, align="C")
    pdf.ln(1)

    pdf.output(
        full_QC_model_path + "/Quality Control/" + QC_model_name + "_QC_report.pdf"
    )


def random_fliprot(img, mask, axis=None):
    if axis is None:
        axis = tuple(range(mask.ndim))
    axis = tuple(axis)

    assert img.ndim >= mask.ndim
    perm = tuple(np.random.permutation(axis))
    transpose_axis = np.arange(mask.ndim)
    for a, p in zip(axis, perm):
        transpose_axis[a] = p
    transpose_axis = tuple(transpose_axis)
    img = img.transpose(transpose_axis + tuple(range(mask.ndim, img.ndim)))
    mask = mask.transpose(transpose_axis)
    for ax in axis:
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask


def random_intensity_change(img):
    img = img * np.random.uniform(0.6, 2) + np.random.uniform(-0.2, 0.2)
    return img


def augmenter(x, y):
    """Augmentation of a single input/label image pair.
    x is an input image
    y is the corresponding ground-truth label image
    """
    # Note that we only use fliprots along axis=(1,2), i.e. the yx axis
    # as 3D microscopy acquisitions are usually not axially symmetric
    x, y = random_fliprot(x, y, axis=(1, 2))
    x = random_intensity_change(x)
    return x, y


# Build requirements file for local run
after = [str(m) for m in sys.modules]
build_requirements_file(before, after)
print("--------------------------------------")
print(f"Finnished. Duration: {datetime.now() - internal_aux_initial_time}")

# **2. Select your parameters and paths**
---


## **2.1. Setting main training parameters**
---
<font size = 4> 


<font size = 5> **Paths for training, predictions and results**

<font size = 4>**`Training_source`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data, respectively. To find the path of each folder, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste it into the corresponding box below.

<font size = 4>**`model_name`:** Use underscores rather than hyphens (my_model, not my-model) and do not use spaces in the name. Avoid using the name of an existing model saved in the same folder, as it will be overwritten.

<font size = 4>**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).
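<font size = 4>The naming rules above can be checked before training with a small helper such as the hypothetical `check_model_name` below. It is a sketch, not part of the notebook, and its rule (letters, digits and `_` only) is slightly stricter than the text above, which only forbids `-` and spaces.

```python
import os
import re


def check_model_name(model_path, model_name):
    """Return a list of problems with `model_name` (an empty list means it looks fine)."""
    problems = []
    # Assumed rule: letters, digits and underscores only
    if not re.fullmatch(r"[A-Za-z0-9_]+", model_name):
        problems.append("use only letters, digits and '_' (no '-' or spaces)")
    # A folder with the same name in model_path would be overwritten
    if os.path.exists(os.path.join(model_path, model_name)):
        problems.append("a model with this name already exists and would be overwritten")
    return problems
```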

<font size = 5>**Training parameters**

<font size = 4>**`number_of_epochs`:** Input the number of epochs (rounds) for which the network will be trained. Preliminary results can already be observed after 200 epochs, but a full training should run for longer. Evaluate the performance after training (see section 5). **Default value: 200**

<font size = 5>**Advanced parameters - experienced users only**

<font size =4>**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** 

<font size = 4>**`number_of_steps`:** Define the number of training steps per epoch. By default (or if set to 0), this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: number of patches / batch_size**. This value is multiplied by 6 when augmentation is enabled.
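<font size = 4>A minimal sketch of the default described above, using a hypothetical helper `default_number_of_steps`. Rounding up with `math.ceil` is an assumption, chosen so that every patch is seen at least once when the number of patches is not a multiple of `batch_size`; the notebook's own code may round differently.

```python
import math


def default_number_of_steps(n_patches, batch_size, augmentation=False):
    """Steps per epoch so that every patch is seen at least once per epoch."""
    steps = math.ceil(n_patches / batch_size)
    if augmentation:
        steps *= 6  # factor stated in the parameter description
    return steps


print(default_number_of_steps(10, 4))                     # 3
print(default_number_of_steps(10, 4, augmentation=True))  # 18
```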

<font size = 4>**`patch_size`** and **`patch_height`:** Input the size of the patches used to train StarDist 3D: `patch_size` is the length of a side in XY and `patch_height` is the size along Z. The values should be smaller than or equal to the corresponding dimensions of the image. Make `patch_size` and `patch_height` as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**
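<font size = 4>The rule "as large as possible and divisible by 8 (XY) and 4 (Z)" amounts to rounding each image dimension down to the nearest valid multiple. The hypothetical helper below sketches this for an image of shape (Z, Y, X), using a square XY patch as the training report does (`patch_height` x `patch_size` x `patch_size`).

```python
def largest_valid_patch(image_shape, z_multiple=4, xy_multiple=8):
    """Largest (patch_height, patch_size) that fits in a (Z, Y, X) image.

    patch_height is rounded down to a multiple of `z_multiple`;
    patch_size is rounded down to a multiple of `xy_multiple` and used for both Y and X.
    """
    z, y, x = image_shape
    patch_height = (z // z_multiple) * z_multiple
    patch_size = (min(y, x) // xy_multiple) * xy_multiple
    if patch_height == 0 or patch_size == 0:
        raise ValueError("image too small for the required divisibility")
    return patch_height, patch_size


print(largest_valid_patch((30, 250, 300)))  # (28, 248)
```

If this patch causes an out-of-memory error, decrease the values further (see the note on OOM errors below).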

<font size = 4>**`percentage_validation`:**  Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** 
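<font size = 4>As an illustration of how `percentage_validation` could translate into a random training/validation split, here is a sketch using NumPy. The helper name, the fixed seed and the rule of keeping at least one validation image are assumptions, not the notebook's exact code.

```python
import numpy as np


def split_train_val(n_images, percentage_validation, seed=42):
    """Randomly split image indices into training and validation sets."""
    rng = np.random.default_rng(seed)
    ind = rng.permutation(n_images)
    # Assumed rule: keep at least one image for validation
    n_val = max(1, int(round(n_images * percentage_validation / 100)))
    return ind[n_val:], ind[:n_val]


train_idx, val_idx = split_train_val(20, 10)
print(len(train_idx), len(val_idx))  # 18 2
```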

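<font size = 4>**`n_rays`:** Set the number of rays used by StarDist. Each ray is a direction from the centre of a nucleus along which the distance to its boundary is predicted, and the end point of each ray becomes one vertex (corner) of the star-convex polyhedron (for instance, a cube has 8 corners). **Default value: 96**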
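<font size = 4>The notebook imports `Rays_GoldenSpiral` from StarDist, which spreads the rays over the sphere along a golden spiral. The standalone NumPy sketch below illustrates that idea (a Fibonacci/golden-spiral sampling of unit directions); its exact formula, and the absence of any anisotropy handling, may differ from StarDist's implementation.

```python
import numpy as np


def golden_spiral_rays(n_rays):
    """n_rays unit vectors in (z, y, x) order spread over the sphere on a golden spiral."""
    i = np.arange(n_rays)
    golden_angle = np.pi * (3 - np.sqrt(5))  # about 137.5 degrees
    z = 1 - (2 * i + 1) / n_rays             # heights evenly spaced in (-1, 1)
    r = np.sqrt(1 - z ** 2)                  # radius of the circle at height z
    phi = golden_angle * i                   # rotate by the golden angle at each step
    return np.stack([z, r * np.sin(phi), r * np.cos(phi)], axis=1)


rays = golden_spiral_rays(96)
print(rays.shape)  # (96, 3)
```

Spacing the rays evenly in height while rotating by the golden angle avoids clustering the directions near the poles, so each ray covers a similar share of the object's surface.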

<font size = 4>**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**

<font size = 4>**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappears.**
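<font size = 4>As a rough guide (an approximation, since memory use also depends on the network architecture), GPU memory grows with the number of voxels processed per training step, `batch_size * patch_height * patch_size * patch_size`. The hypothetical helper below uses the widget values of this notebook (batch_size 2, patch_height 16, patch_size 128) to show that halving `patch_size` reduces the voxel count four-fold, whereas halving `patch_height` only halves it.

```python
def voxels_per_batch(batch_size, patch_height, patch_size):
    """Number of input voxels the network processes per training step."""
    return batch_size * patch_height * patch_size * patch_size


print(voxels_per_batch(2, 16, 128))  # 524288
print(voxels_per_batch(2, 16, 64))   # 131072 (patch_size halved: 4x fewer voxels)
print(voxels_per_batch(2, 8, 128))   # 262144 (patch_height halved: 2x fewer voxels)
```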


In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

display(Markdown("### Path to training images: "))
widget_Training_source = widgets.Text(
    value="", style={"description_width": "initial"}, description="Training_source:"
)
display(widget_Training_source)
widget_Training_target = widgets.Text(
    value="", style={"description_width": "initial"}, description="Training_target:"
)
display(widget_Training_target)
display(Markdown("### Name of the model and path to model folder:"))
widget_model_name = widgets.Text(
    value="", style={"description_width": "initial"}, description="model_name:"
)
display(widget_model_name)
widget_model_path = widgets.Text(
    value="", style={"description_width": "initial"}, description="model_path:"
)
display(widget_model_path)
display(Markdown("### Other parameters for training:"))
widget_number_of_epochs = widgets.IntText(
    value=100, style={"description_width": "initial"}, description="number_of_epochs:"
)
display(widget_number_of_epochs)
display(Markdown("### Advanced Parameters"))
widget_Use_Default_Advanced_Parameters = widgets.Checkbox(
    value=True,
    style={"description_width": "initial"},
    description="Use_Default_Advanced_Parameters:",
)
display(widget_Use_Default_Advanced_Parameters)
display(Markdown("### If not, please input:"))
widget_GPU_limit = widgets.IntText(
    value=90, style={"description_width": "initial"}, description="GPU_limit:"
)
display(widget_GPU_limit)
widget_batch_size = widgets.IntText(
    value=2, style={"description_width": "initial"}, description="batch_size:"
)
display(widget_batch_size)
widget_number_of_steps = widgets.IntText(
    value=0, style={"description_width": "initial"}, description="number_of_steps:"
)
display(widget_number_of_steps)
widget_patch_size = widgets.Text(
    value="""128""", style={"description_width": "initial"}, description="patch_size:"
)
display(widget_patch_size)
widget_patch_height = widgets.IntText(
    value=16, style={"description_width": "initial"}, description="patch_height:"
)
display(widget_patch_height)
widget_percentage_validation = widgets.IntText(
    value=10,
    style={"description_width": "initial"},
    description="percentage_validation:",
)
display(widget_percentage_validation)
widget_n_rays = widgets.IntText(
    value=96, style={"description_width": "initial"}, description="n_rays:"
)
display(widget_n_rays)
widget_initial_learning_rate = widgets.FloatText(
    value=0.0003,
    style={"description_width": "initial"},
    description="initial_learning_rate:",
)
display(widget_initial_learning_rate)


def function_9(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Training_source
        global Training_target
        global model_name
        global model_path
        global number_of_epochs
        global Use_Default_Advanced_Parameters
        global GPU_limit
        global batch_size
        global number_of_steps
        global patch_size
        global patch_height
        global percentage_validation
        global n_rays
        global initial_learning_rate

        global training_images
        global mask_images
        global trained_model
        global percentage
        global random_choice
        global x
        global Image_Z
        global mid_plane
        global Image_Y
        global Image_X
        global Use_pretrained_model
        global Use_Data_augmentation
        global y
        global norm
        global f

        Training_source = widget_Training_source.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_Training_source",
            widget_Training_source.value,
        )
        training_images = Training_source

        Training_target = widget_Training_target.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_Training_target",
            widget_Training_target.value,
        )
        mask_images = Training_target

        model_name = widget_model_name.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_model_name",
            widget_model_name.value,
        )

        model_path = widget_model_path.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_model_path",
            widget_model_path.value,
        )
        trained_model = model_path

        number_of_epochs = widget_number_of_epochs.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_number_of_epochs",
            widget_number_of_epochs.value,
        )

        Use_Default_Advanced_Parameters = widget_Use_Default_Advanced_Parameters.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_Use_Default_Advanced_Parameters",
            widget_Use_Default_Advanced_Parameters.value,
        )

        GPU_limit = widget_GPU_limit.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_GPU_limit",
            widget_GPU_limit.value,
        )
        batch_size = widget_batch_size.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_batch_size",
            widget_batch_size.value,
        )
        number_of_steps = widget_number_of_steps.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_number_of_steps",
            widget_number_of_steps.value,
        )
        # Convert the text entry to an integer (safer than eval on user input)
        patch_size = int(widget_patch_size.value)
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_patch_size",
            patch_size,
        )
        patch_height = widget_patch_height.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_patch_height",
            widget_patch_height.value,
        )
        percentage_validation = widget_percentage_validation.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_percentage_validation",
            widget_percentage_validation.value,
        )
        n_rays = widget_n_rays.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path, "function_9_n_rays", widget_n_rays.value
        )
        initial_learning_rate = widget_initial_learning_rate.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_9_initial_learning_rate",
            widget_initial_learning_rate.value,
        )

        if Use_Default_Advanced_Parameters:
            print("Default advanced parameters enabled")
            batch_size = 2  # default from original author's notebook
            n_rays = 96
            percentage_validation = 10
            initial_learning_rate = 0.0003

            patch_size = 96  # default from original author's notebook
            patch_height = 48  # default from original author's notebook

        percentage = percentage_validation / 100

        # Here we check that no model with the same name already exists; if so, print a warning

        if os.path.exists(model_path + "/" + model_name):
            print(
                bcolors.WARNING
                + "!! WARNING: "
                + model_name
                + " already exists and will be deleted in the following cell !!"
            )
            print(
                bcolors.WARNING
                + "To continue training "
                + model_name
                + ", choose a new model_name here, and load "
                + model_name
                + " in section 3.3"
                + W
            )

        random_choice = random.choice(os.listdir(Training_source))
        x = imread(Training_source + "/" + random_choice)

        # Here we check that the input images are stacks
        if len(x.shape) == 3:
            print("Image dimensions (z,y,x)", x.shape)

        if not len(x.shape) == 3:
            print(
                bcolors.WARNING
                + "Your images appear to have the wrong dimensions. Image dimension",
                x.shape,
            )

        # Find image Z dimension and select the mid-plane
        Image_Z = x.shape[0]
        mid_plane = Image_Z // 2  # index of the middle Z slice (valid for any Z size)

        # Find image XY dimension
        Image_Y = x.shape[1]
        Image_X = x.shape[2]

        # If default parameters, patch size is the same as image size
        if Use_Default_Advanced_Parameters:
            patch_size = min(Image_Y, Image_X)
            patch_height = Image_Z

        # Hyperparameters failsafes

        # Here we check that patch_size is smaller than the smallest xy dimension of the image

        if patch_size > min(Image_Y, Image_X):
            patch_size = min(Image_Y, Image_X)
            print(
                bcolors.WARNING
                + " Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:",
                patch_size,
            )

        # Here we check that patch_size is divisible by 8 (round down, keeping at least 8)
        if not patch_size % 8 == 0:
            patch_size = max(8, (patch_size // 8) * 8)
            print(
                bcolors.WARNING
                + " Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:",
                patch_size,
            )

        # Here we check that patch_height is smaller than the z dimension of the image

        if patch_height > Image_Z:
            patch_height = Image_Z
            print(
                bcolors.WARNING
                + " Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_height chosen is now:",
                patch_height,
            )

        # Here we check that patch_height is divisible by 4 (round down, keeping at least 4)
        if not patch_height % 4 == 0:
            patch_height = max(4, (patch_height // 4) * 4)
            print(
                bcolors.WARNING
                + " Your chosen patch_height is not divisible by 4; therefore the patch_height chosen is now:",
                patch_height,
            )

        # Here we disable the pre-trained model by default (in case the next cell is not run)
        Use_pretrained_model = False

        # Here we disable data augmentation by default (in case the next cell is not run)

        Use_Data_augmentation = False

        print("Parameters initiated.")

        os.chdir(Training_target)
        y = imread(Training_target + "/" + random_choice)

        # Here we use a simple normalisation strategy to visualise the image
        from astropy.visualization import simple_norm

        norm = simple_norm(x, percent=99)

        mid_plane = Image_Z // 2  # index of the middle Z slice (valid for any Z size)

        f = plt.figure(figsize=(16, 8))
        plt.subplot(1, 2, 1)
        plt.imshow(x[mid_plane], interpolation="nearest", norm=norm, cmap="magma")
        plt.axis("off")
        plt.title("Training source (single Z plane)")
        plt.subplot(1, 2, 2)
        plt.imshow(y[mid_plane], interpolation="nearest", cmap=lbl_cmap)
        plt.axis("off")
        plt.title("Training target (single Z plane)")
        plt.savefig(
            base_path + "/TrainingDataExample_StarDist3D.png",
            bbox_inches="tight",
            pad_inches=0,
        )

        plt.show()


def function_9_cache(output_widget):
    global Training_source
    global Training_target
    global model_name
    global model_path
    global number_of_epochs
    global Use_Default_Advanced_Parameters
    global GPU_limit
    global batch_size
    global number_of_steps
    global patch_size
    global patch_height
    global percentage_validation
    global n_rays
    global initial_learning_rate

    global training_images
    global mask_images
    global trained_model
    global percentage
    global random_choice
    global x
    global Image_Z
    global mid_plane
    global Image_Y
    global Image_X
    global Use_pretrained_model
    global Use_Data_augmentation
    global y
    global norm
    global f

    cache_Training_source = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_Training_source"
    )
    if cache_Training_source != "":
        widget_Training_source.value = cache_Training_source

    cache_Training_target = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_Training_target"
    )
    if cache_Training_target != "":
        widget_Training_target.value = cache_Training_target

    cache_model_name = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_model_name"
    )
    if cache_model_name != "":
        widget_model_name.value = cache_model_name

    cache_model_path = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_model_path"
    )
    if cache_model_path != "":
        widget_model_path.value = cache_model_path

    cache_number_of_epochs = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_number_of_epochs"
    )
    if cache_number_of_epochs != "":
        widget_number_of_epochs.value = cache_number_of_epochs

    cache_Use_Default_Advanced_Parameters = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_Use_Default_Advanced_Parameters"
    )
    if cache_Use_Default_Advanced_Parameters != "":
        widget_Use_Default_Advanced_Parameters.value = (
            cache_Use_Default_Advanced_Parameters
        )

    cache_GPU_limit = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_GPU_limit"
    )
    if cache_GPU_limit != "":
        widget_GPU_limit.value = cache_GPU_limit

    cache_batch_size = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_batch_size"
    )
    if cache_batch_size != "":
        widget_batch_size.value = cache_batch_size

    cache_number_of_steps = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_number_of_steps"
    )
    if cache_number_of_steps != "":
        widget_number_of_steps.value = cache_number_of_steps

    cache_patch_size = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_patch_size"
    )
    if cache_patch_size != "":
        # The patch_size widget is a Text widget, so the cached number is converted back to a string
        widget_patch_size.value = str(cache_patch_size)

    cache_patch_height = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_patch_height"
    )
    if cache_patch_height != "":
        widget_patch_height.value = cache_patch_height

    cache_percentage_validation = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_percentage_validation"
    )
    if cache_percentage_validation != "":
        widget_percentage_validation.value = cache_percentage_validation

    cache_n_rays = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_n_rays"
    )
    if cache_n_rays != "":
        widget_n_rays.value = cache_n_rays

    cache_initial_learning_rate = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_9_initial_learning_rate"
    )
    if cache_initial_learning_rate != "":
        widget_initial_learning_rate.value = cache_initial_learning_rate


button_function_9 = widgets.Button(description="Load and run")
cache_button_function_9 = widgets.Button(description="Load prev. settings")
output_function_9 = widgets.Output()
display(widgets.HBox((button_function_9, cache_button_function_9)), output_function_9)


def aux_function_9(_):
    return function_9(output_function_9)


def aux_function_9_cache(_):
    return function_9_cache(output_function_9)


button_function_9.on_click(aux_function_9)
cache_button_function_9.on_click(aux_function_9_cache)
print("--------------------------------------------------------------")
print('^ Introduce the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')

## **2.2. Data augmentation**
---
<font size = 4>


<font size = 4>Data augmentation can improve training by artificially increasing the variety of the training dataset. This is useful when the available dataset is small since, without augmentation, the network could quickly memorize every example in the dataset (overfitting). Augmentation is not required for training, and if your training dataset is large you should disable it.

<font size = 4> **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performance.**  

<font size = 4>Data augmentation is performed here by flipping, rotating and modifying the intensity of the images.
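
<font size = 4>The kind of augmentation described above can be sketched in NumPy. This is an illustrative augmenter, not necessarily the one this notebook defines: it assumes `(z, y, x)` stacks with square XY patches (so that 90-degree rotations keep the shape), applies identical geometric transforms to image and mask, and changes only the image intensities. The intensity ranges are assumed values, similar to those used in the StarDist example notebooks.

```python
import numpy as np

# Module-level generator used when none is passed in (any seed would do)
_default_rng = np.random.default_rng()


def random_fliprot(x, y, rng):
    """Apply the same random 90-degree XY rotation and axis flips to image x and mask y.

    Rotations are restricted to the XY plane so that the (often coarser)
    Z sampling is never mixed with XY.
    """
    k = int(rng.integers(4))
    x = np.rot90(x, k, axes=(1, 2))
    y = np.rot90(y, k, axes=(1, 2))
    for axis in (0, 1, 2):
        if rng.random() < 0.5:
            x = np.flip(x, axis=axis)
            y = np.flip(y, axis=axis)
    return x, y


def random_intensity_change(x, rng):
    """Randomly scale and shift the image intensities; the mask is not modified."""
    return x * rng.uniform(0.6, 2.0) + rng.uniform(-0.2, 0.2)


def augmenter(x, y, rng=None):
    """Hypothetical augmenter taking and returning an (image, mask) pair."""
    rng = _default_rng if rng is None else rng
    x, y = random_fliprot(x, y, rng)
    x = random_intensity_change(x, rng)
    return x, y
```

<font size = 4>Because flips and 90-degree rotations only move voxels around, label values in the mask are preserved exactly, which is why no interpolation is needed for the mask.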




In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

widget_Use_Data_augmentation = widgets.Checkbox(
    value=False,
    style={"description_width": "initial"},
    description="Use_Data_augmentation:",
)
display(widget_Use_Data_augmentation)


def function_12(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Use_Data_augmentation

        global augmenter

        # Data augmentation

        Use_Data_augmentation = widget_Use_Data_augmentation.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_12_Use_Data_augmentation",
            widget_Use_Data_augmentation.value,
        )

        if Use_Data_augmentation:
            augmenter = augmenter
            print("Data augmentation enabled. Let's flip!")
        else:
            augmenter = None
            print("Data augmentation disabled.")

        plt.show()


def function_12_cache(output_widget):
    global Use_Data_augmentation

    global augmenter

    cache_Use_Data_augmentation = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_12_Use_Data_augmentation"
    )
    if cache_Use_Data_augmentation != "":
        widget_Use_Data_augmentation.value = cache_Use_Data_augmentation


button_function_12 = widgets.Button(description="Load and run")
cache_button_function_12 = widgets.Button(description="Load prev. settings")
output_function_12 = widgets.Output()
display(
    widgets.HBox((button_function_12, cache_button_function_12)), output_function_12
)


def aux_function_12(_):
    return function_12(output_function_12)


def aux_function_12_cache(_):
    return function_12_cache(output_function_12)


button_function_12.on_click(aux_function_12)
cache_button_function_12.on_click(aux_function_12_cache)
print("--------------------------------------------------------------")
print('^ Introduce the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')


## **2.3. Using weights from a pre-trained model as initial weights**
---
<font size = 4>  Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. 

<font size = 4> This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.

<font size = 4> In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
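
<font size = 4>The logic used to recover the learning rate can be sketched with the standard `csv` module (the notebook itself uses pandas). The sketch assumes a training log with `val_loss` and `learning rate` columns, as read by the code below; the example log and the helper name are hypothetical. The "last" rate comes from the final epoch, and the "best" rate from the epoch with the lowest validation loss.

```python
import csv
import io


def learning_rates_from_log(csv_text):
    """Return (last_lr, best_lr) from a training log, or (None, None) if no rate is logged.

    best_lr is the learning rate of the epoch with the lowest validation loss;
    if several epochs tie, the last of them is used.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    if not rows or "learning rate" not in rows[0]:
        return None, None
    last_lr = float(rows[-1]["learning rate"])
    min_loss = min(float(r["val_loss"]) for r in rows)
    best_lr = [
        float(r["learning rate"]) for r in rows if float(r["val_loss"]) == min_loss
    ][-1]
    return last_lr, best_lr


# Hypothetical three-epoch log
log = """epoch,loss,val_loss,learning rate
1,1.20,1.10,0.0003
2,0.90,0.85,0.0003
3,0.80,0.88,0.00015
"""
print(learning_rates_from_log(log))  # (0.00015, 0.0003)
```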


In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

display(Markdown("## Loading weights from a pre-trained network"))
widget_Use_pretrained_model = widgets.Checkbox(
    value=True,
    style={"description_width": "initial"},
    description="Use_pretrained_model:",
)
display(widget_Use_pretrained_model)
widget_pretrained_model_choice = widgets.Dropdown(
    options=["Model_from_file", "Demo_3D_Model_from_Stardist_3D_paper"],
    value="Demo_3D_Model_from_Stardist_3D_paper",
    style={"description_width": "initial"},
    description="pretrained_model_choice:",
)
display(widget_pretrained_model_choice)
widget_Weights_choice = widgets.Dropdown(
    options=["last", "best"],
    value="best",
    style={"description_width": "initial"},
    description="Weights_choice:",
)
display(widget_Weights_choice)
display(
    Markdown(
        "### If you chose 'Model_from_file', please provide the path to the model folder:"
    )
)
widget_pretrained_model_path = widgets.Text(
    value="",
    style={"description_width": "initial"},
    description="pretrained_model_path:",
)
display(widget_pretrained_model_path)


def function_14(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Use_pretrained_model
        global pretrained_model_choice
        global Weights_choice
        global pretrained_model_path

        global h5_file_path
        global pretrained_model_name
        global csvRead
        global lastLearningRate
        global min_val_loss
        global bestLearningRate

        Use_pretrained_model = widget_Use_pretrained_model.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_14_Use_pretrained_model",
            widget_Use_pretrained_model.value,
        )

        pretrained_model_choice = widget_pretrained_model_choice.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_14_pretrained_model_choice",
            widget_pretrained_model_choice.value,
        )

        Weights_choice = widget_Weights_choice.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_14_Weights_choice",
            widget_Weights_choice.value,
        )

        pretrained_model_path = widget_pretrained_model_path.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_14_pretrained_model_path",
            widget_pretrained_model_path.value,
        )

        # --------------------- Check if we load a previously trained model ------------------------
        if Use_pretrained_model:
            # --------------------- Load the model from the chosen path ------------------------
            if pretrained_model_choice == "Model_from_file":
                h5_file_path = os.path.join(
                    pretrained_model_path, "weights_" + Weights_choice + ".h5"
                )

            # --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------

            if pretrained_model_choice == "Demo_3D_Model_from_Stardist_3D_paper":
                pretrained_model_name = "Demo_3D"
                pretrained_model_path = base_path + "/" + pretrained_model_name
                print("Downloading the Demo 3D model from the Stardist_3D paper")
                if os.path.exists(pretrained_model_path):
                    shutil.rmtree(pretrained_model_path)
                os.makedirs(pretrained_model_path)
                wget.download(
                    "https://raw.githubusercontent.com/stardist/stardist/main/models/examples/3D_demo/config.json",
                    pretrained_model_path,
                )
                wget.download(
                    "https://github.com/stardist/stardist/raw/main/models/examples/3D_demo/thresholds.json",
                    pretrained_model_path,
                )
                wget.download(
                    "https://github.com/stardist/stardist/blob/main/models/examples/3D_demo/weights_best.h5?raw=true",
                    pretrained_model_path,
                )
                wget.download(
                    "https://github.com/stardist/stardist/blob/main/models/examples/3D_demo/weights_last.h5?raw=true",
                    pretrained_model_path,
                )
                h5_file_path = os.path.join(
                    pretrained_model_path, "weights_" + Weights_choice + ".h5"
                )

            # --------------------- Add additional pre-trained models here ------------------------

            # --------------------- Check the model exists ------------------------
            # If the chosen model path does not contain a pretrained model, then Use_pretrained_model is disabled
            if not os.path.exists(h5_file_path):
                print(
                    bcolors.WARNING
                    + "WARNING: "
                    + h5_file_path
                    + " pretrained model does not exist"
                    + W
                )
                Use_pretrained_model = False

            # If the model path contains a pretrained model, we load the learning rate
            if os.path.exists(h5_file_path):
                # Here we check if the learning rate can be loaded from the quality control folder
                if os.path.exists(
                    os.path.join(
                        pretrained_model_path,
                        "Quality Control",
                        "training_evaluation.csv",
                    )
                ):
                    with open(
                        os.path.join(
                            pretrained_model_path,
                            "Quality Control",
                            "training_evaluation.csv",
                        ),
                        "r",
                    ) as csvfile:
                        csvRead = pd.read_csv(csvfile, sep=",")
                        # print(csvRead)

                        if (
                            "learning rate" in csvRead.columns
                        ):  # Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)
                            print("pretrained network learning rate found")
                            # find the last learning rate
                            lastLearningRate = csvRead["learning rate"].iloc[-1]
                            # Find the learning rate corresponding to the lowest validation loss
                            min_val_loss = csvRead[
                                csvRead["val_loss"] == min(csvRead["val_loss"])
                            ]
                            # print(min_val_loss)
                            bestLearningRate = min_val_loss["learning rate"].iloc[-1]

                            if Weights_choice == "last":
                                print("Last learning rate: " + str(lastLearningRate))

                            if Weights_choice == "best":
                                print(
                                    "Learning rate of best validation loss: "
                                    + str(bestLearningRate)
                                )

                        if (
                            "learning rate" not in csvRead.columns
                        ):  # if the column does not exist, then the initial learning rate is used instead
                            bestLearningRate = initial_learning_rate
                            lastLearningRate = initial_learning_rate
                            print(
                                bcolors.WARNING
                                + "WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of "
                                + str(bestLearningRate)
                                + " will be used instead"
                                + W
                            )

                # Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used
                if not os.path.exists(
                    os.path.join(
                        pretrained_model_path,
                        "Quality Control",
                        "training_evaluation.csv",
                    )
                ):
                    print(
                        bcolors.WARNING
                        + "WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of "
                        + str(initial_learning_rate)
                        + " will be used instead"
                        + W
                    )
                    bestLearningRate = initial_learning_rate
                    lastLearningRate = initial_learning_rate

        # Display info about the pretrained model to be loaded (or not)
        if Use_pretrained_model:
            print(bcolors.WARNING + "Weights found in:")
            print(h5_file_path)
            print(bcolors.WARNING + "will be loaded prior to training.")

        else:
            print(bcolors.WARNING + "No pretrained network will be used.")

        plt.show()


def function_14_cache(output_widget):
    global Use_pretrained_model
    global pretrained_model_choice
    global Weights_choice
    global pretrained_model_path

    global h5_file_path
    global pretrained_model_name
    global csvRead
    global lastLearningRate
    global min_val_loss
    global bestLearningRate

    cache_Use_pretrained_model = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_14_Use_pretrained_model"
    )
    if cache_Use_pretrained_model != "":
        widget_Use_pretrained_model.value = cache_Use_pretrained_model

    cache_pretrained_model_choice = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_14_pretrained_model_choice"
    )
    if cache_pretrained_model_choice != "":
        widget_pretrained_model_choice.value = cache_pretrained_model_choice

    cache_Weights_choice = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_14_Weights_choice"
    )
    if cache_Weights_choice != "":
        widget_Weights_choice.value = cache_Weights_choice

    cache_pretrained_model_path = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_14_pretrained_model_path"
    )
    if cache_pretrained_model_path != "":
        widget_pretrained_model_path.value = cache_pretrained_model_path


button_function_14 = widgets.Button(description="Load and run")
cache_button_function_14 = widgets.Button(description="Load prev. settings")
output_function_14 = widgets.Output()
display(
    widgets.HBox((button_function_14, cache_button_function_14)), output_function_14
)


def aux_function_14(_):
    return function_14(output_function_14)


def aux_function_14_cache(_):
    return function_14_cache(output_function_14)


button_function_14.on_click(aux_function_14)
cache_button_function_14.on_click(aux_function_14_cache)
print("--------------------------------------------------------------")
print('^ Introduce the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')

# **3. Train the network**
---


## **3.1. Prepare the training data and model for training**
---
<font size = 4>Here, we use the information from section 2 to build the model and convert the training data into a suitable format for training.
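
<font size = 4>The cell below expects every training image to have a mask with exactly the same file name in `Training_target`. If that check fails, a sketch like the following (a hypothetical helper, not part of the notebook) can help find which files lack a partner.

```python
from pathlib import Path


def pair_by_filename(source_files, target_files):
    """Pair source images with target masks that share the same file name.

    Returns the matched pairs plus the file names found on only one side.
    """
    sources = {Path(f).name: f for f in source_files}
    targets = {Path(f).name: f for f in target_files}
    common = sorted(sources.keys() & targets.keys())
    pairs = [(sources[name], targets[name]) for name in common]
    only_source = sorted(sources.keys() - targets.keys())
    only_target = sorted(targets.keys() - sources.keys())
    return pairs, only_source, only_target


pairs, missing_mask, missing_image = pair_by_filename(
    ["src/a.tif", "src/b.tif", "src/c.tif"],
    ["tgt/b.tif", "tgt/a.tif", "tgt/d.tif"],
)
print(missing_mask, missing_image)  # ['c.tif'] ['d.tif']
```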


In [None]:
import keras

# Run this cell to execute the code
internal_aux_initial_time = datetime.now()
print("Running...")
print("--------------------------------------")

# --------------------- Here we delete the model folder if it already exists ------------------------

if os.path.exists(model_path + "/" + model_name):
    print(
        bcolors.WARNING
        + "!! WARNING: Model folder already exists and has been removed !!"
        + W
    )
    shutil.rmtree(model_path + "/" + model_name)

import warnings

warnings.simplefilter("ignore")

# --------------------- Here we load the augmented data or the raw data ------------------------

# if Use_Data_augmentation:
#   Training_source_dir = Training_source_augmented
#   Training_target_dir = Training_target_augmented

# if not Use_Data_augmentation:
#   Training_source_dir = Training_source
#   Training_target_dir = Training_target
# --------------------- ------------------------------------------------

Training_source_dir = Training_source
Training_target_dir = Training_target
training_images_tiff = Training_source_dir + "/*.tif"
mask_images_tiff = Training_target_dir + "/*.tif"

# Collect the training image and mask file paths, sorted so that images and masks pair up by name
X_imgs = sorted(glob(training_images_tiff))
Y_lbls = sorted(glob(mask_images_tiff))

X = X_imgs

# Check that every training image has a mask with the same file name; raises an AssertionError otherwise
assert len(X_imgs) == len(Y_lbls), "number of training images and masks differ"
assert all(Path(x).name == Path(y).name for x, y in zip(X_imgs, Y_lbls))

# Here we split your training dataset into training and validation images; the validation fraction is set by 'percentage'.

assert len(X_imgs) > 1, "not enough training data"
rng = np.random.RandomState(42)
ind = rng.permutation(len(X_imgs))
n_val = max(1, int(round(percentage * len(ind))))
# init the lists for the indexes of training and validation data

ind_train = []
ind_val = []

Y_extents = list(map(imread, Y_lbls))

for idx in ind:
    # Only masks containing at least one labelled object are used for validation
    if Y_extents[idx].any() and len(ind_val) < n_val:
        ind_val.append(idx)
    else:
        ind_train.append(idx)

print("number of images: %3d" % len(X_imgs))
print("- training:       %3d" % len(ind_train))
print("- validation:     %3d" % len(ind_val))

# Load the validation images and masks into memory
X_val, Y_val = (
    [imread(X_imgs[i]) for i in ind_val],
    [imread(Y_lbls[i]) for i in ind_val],
)

extents = calculate_extents(Y_extents)
anisotropy = tuple(np.max(extents) / extents)
print("empirical anisotropy of labeled objects = %s" % str(anisotropy))

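# Number of input channels, inferred from the first training image (not from its label mask)
first_image = imread(X[0])
n_channel = 1 if first_image.ndim == 3 else first_image.shape[-1]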

# Use OpenCL-based computations for data generator during training (requires 'gputools')
use_gpu = False and gputools_available()

# Here we ensure that our network has a minimal number of steps
if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):
    number_of_steps = (Image_X // patch_size) * (Image_Y // patch_size) * (
        Image_Z // patch_height
    ) * int(len(X_imgs) / batch_size) + 1
    if Use_Data_augmentation:
        number_of_steps = number_of_steps * 6

print("Number of steps: " + str(number_of_steps))

# remove validation pairs from list
X_imgs = [img for n, img in enumerate(X_imgs) if n not in ind_val]
Y_lbls = [img for n, img in enumerate(Y_lbls) if n not in ind_val]

# --------------------- Using pretrained model ------------------------
# Here we ensure that the learning rate is set correctly when using pre-trained models
if Use_pretrained_model:
    if Weights_choice == "last":
        initial_learning_rate = lastLearningRate

    if Weights_choice == "best":
        initial_learning_rate = bestLearningRate
# --------------------- ---------------------- ------------------------

# Predict on subsampled grid for increased efficiency and larger field of view
# (note: this computed grid is currently not used; Config3D below is set to grid=(1, 1, 1))
grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)

# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

conf = Config3D(
    rays=rays,
    grid=(1, 1, 1),
    anisotropy=anisotropy,
    use_gpu=use_gpu,
    n_channel_in=n_channel,
    train_learning_rate=initial_learning_rate,
    train_patch_size=(patch_height, patch_size, patch_size),
    train_batch_size=batch_size,
)
print(conf)
vars(conf)

# --------------------- This is currently disabled (use_gpu is always False above) as it gives an error ------------------------
# Here we limit TensorFlow to 80% of the GPU memory
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory

    # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
    limit_gpu_memory(0.8)
# --------------------- ---------------------- ------------------------

# Here we create the model from the configuration defined above.
model = StarDist3D(conf, name=model_name, basedir=trained_model)

# --------------------- Using pretrained model ------------------------
# Load the pretrained weights
if Use_pretrained_model:
    model.load_weights(h5_file_path)
# --------------------- ---------------------- ------------------------

# Here we check the FOV of the network.
median_size = calculate_extents(Y_extents, np.median)
fov = np.array(model._axes_tile_overlap("ZYX"))
if any(median_size > fov):
    print(
        "WARNING: median object size larger than field of view of the neural network."
    )

del Y_extents

# --------------------- Here we create the data generators ------------------------
# Create data generators for training data
# Validation data has already been loaded to memory above


# Data generator class for training images
class StardistSequence_X(keras.utils.Sequence):
    def __init__(self, X_imgs):
        self.training_images_tiff = X_imgs

    def __len__(self):
        return int(np.ceil(len(self.training_images_tiff)))

    def __getitem__(self, idx):
        k = 0
        X = []

        while len(X) < 1 and (idx + k) < len(self.training_images_tiff):
            X = imread(self.training_images_tiff[idx + k])
            k += 1
        return X


# Data generator class for training labels
class StardistSequence_Y(keras.utils.Sequence):
    def __init__(self, Y_lbls):
        self.mask_images_tiff = Y_lbls

    def __len__(self):
        return int(np.ceil(len(self.mask_images_tiff)))

    def __getitem__(self, idx):
        k = 0
        Y = []

        while len(Y) < 1 and (idx + k) < len(self.mask_images_tiff):
            Y = imread(self.mask_images_tiff[idx + k])
            Y = fill_label_holes(Y)
            k += 1
        return Y


# Create the data generators to be used in the training
generator_x = StardistSequence_X(X_imgs)
generator_y = StardistSequence_Y(Y_lbls)

pdf_export(augmentation=Use_Data_augmentation, pretrained_model=Use_pretrained_model)
print("--------------------------------------")
print(f"Finished. Duration: {datetime.now() - internal_aux_initial_time}")


## **3.2. Start Training**
---

<font size = 4>When running the cell below, you should see updates after each epoch (round). Network training can take some time.

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or the number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected earlier in this notebook. It is, however, wise to download the folder, as all data can be erased at the next training if the same folder is used.
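<font size = 4>The 12-hour limit can be turned into a rough upper bound on `number_of_epochs` once the duration of one epoch is known, for example from the progress output of the first epoch. A minimal sketch, assuming every epoch takes roughly the same time; the helper name and the 10 % safety margin (left for threshold optimisation and saving) are illustrative assumptions, not notebook settings.

```python
import math


def max_epochs_within_budget(seconds_per_epoch, budget_hours=12.0, safety_margin=0.9):
    """Largest whole number of epochs that fits into a fraction of the runtime budget."""
    if seconds_per_epoch <= 0:
        raise ValueError("seconds_per_epoch must be positive")
    usable_seconds = budget_hours * 3600 * safety_margin
    return math.floor(usable_seconds / seconds_per_epoch)


# Hypothetical measurement: one epoch took about 5 minutes (300 s)
print(max_epochs_within_budget(300))  # 129
```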


In [None]:
# Run this cell to execute the code
internal_aux_initial_time = datetime.now()
print("Running...")
print("--------------------------------------")
import time

start = time.time()

import warnings

warnings.filterwarnings("ignore")

# augmenter = None

# def augmenter(X_batch, Y_batch):
#     """Augmentation for data batch.
#     X_batch is a list of input images (length at most batch_size)
#     Y_batch is the corresponding list of ground-truth label images
#     """
#     # ...
#     return X_batch, Y_batch

# Training the model.
# 'number_of_epochs' and 'number_of_steps' are the training parameters set earlier in the notebook
history = model.train(
    generator_x,
    generator_y,
    validation_data=(X_val, Y_val),
    augmenter=augmenter,
    epochs=number_of_epochs,
    steps_per_epoch=number_of_steps,
)
print("Training done")

# convert the history.history dict to a pandas DataFrame:
lossData = pd.DataFrame(history.history)

if os.path.exists(model_path + "/" + model_name + "/Quality Control"):
    shutil.rmtree(model_path + "/" + model_name + "/Quality Control")

os.makedirs(model_path + "/" + model_name + "/Quality Control")

# The training_evaluation.csv file is saved (overwriting any existing file).
lossDataCSVpath = (
    model_path + "/" + model_name + "/Quality Control/training_evaluation.csv"
)
with open(lossDataCSVpath, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["loss", "val_loss", "learning rate"])
    for i in range(len(history.history["loss"])):
        writer.writerow(
            [
                history.history["loss"][i],
                history.history["val_loss"][i],
                history.history["lr"][i],
            ]
        )

print("Network optimization in progress")

# Here we optimize the probability and NMS thresholds of the model on the validation data.
model.optimize_thresholds(X_val, Y_val)
print("Done")

# Displaying the time elapsed for training
dt = time.time() - start
mins, sec = divmod(dt, 60)
hour, mins = divmod(mins, 60)
print("Time elapsed:", int(hour), "hour(s)", int(mins), "min(s)", round(sec), "sec(s)")

# Create a pdf document with training summary
pdf_export(
    trained=True,
    augmentation=Use_Data_augmentation,
    pretrained_model=Use_pretrained_model,
)

print("--------------------------------------")
print(f"Finished. Duration: {datetime.now() - internal_aux_initial_time}")

# **4. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 

<font size = 4>**We highly recommend performing quality control on all newly trained models.**




In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

display(Markdown("### Do you want to assess the model you just trained?"))
widget_Use_the_current_trained_model = widgets.Checkbox(
    value=True,
    style={"description_width": "initial"},
    description="Use_the_current_trained_model:",
)
display(widget_Use_the_current_trained_model)
display(Markdown("### If not, please provide the path to the model folder:"))
widget_QC_model_folder = widgets.Text(
    value="", style={"description_width": "initial"}, description="QC_model_folder:"
)
display(widget_QC_model_folder)


def function_21(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Use_the_current_trained_model
        global QC_model_folder

        global QC_model_name
        global QC_model_path
        global full_QC_model_path
        global W
        global R

        # model name and path
        Use_the_current_trained_model = widget_Use_the_current_trained_model.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_21_Use_the_current_trained_model",
            widget_Use_the_current_trained_model.value,
        )

        QC_model_folder = widget_QC_model_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_21_QC_model_folder",
            widget_QC_model_folder.value,
        )

        # Here we define the loaded model name and path
        QC_model_name = os.path.basename(QC_model_folder)
        QC_model_path = os.path.dirname(QC_model_folder)

        if Use_the_current_trained_model:
            QC_model_name = model_name
            QC_model_path = model_path

        full_QC_model_path = QC_model_path + "/" + QC_model_name + "/"
        if os.path.exists(full_QC_model_path):
            print("The " + QC_model_name + " network will be evaluated")
        else:
            W = "\033[0m"  # white (normal)
            R = "\033[31m"  # red
            print(R + "!! WARNING: The chosen model does not exist !!" + W)
            print(
                "Please make sure you provide a valid model path and model name before proceeding further."
            )

        plt.show()


def function_21_cache(output_widget):
    global Use_the_current_trained_model
    global QC_model_folder

    global QC_model_name
    global QC_model_path
    global full_QC_model_path
    global W
    global R

    cache_Use_the_current_trained_model = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_21_Use_the_current_trained_model"
    )
    if cache_Use_the_current_trained_model != "":
        widget_Use_the_current_trained_model.value = cache_Use_the_current_trained_model

    cache_QC_model_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_21_QC_model_folder"
    )
    if cache_QC_model_folder != "":
        widget_QC_model_folder.value = cache_QC_model_folder


button_function_21 = widgets.Button(description="Load and run")
cache_button_function_21 = widgets.Button(description="Load prev. settings")
output_function_21 = widgets.Output()
display(
    widgets.HBox((button_function_21, cache_button_function_21)), output_function_21
)


def aux_function_21(_):
    return function_21(output_function_21)


def aux_function_21_cache(_):
    return function_21_cache(output_function_21)


button_function_21.on_click(aux_function_21)
cache_button_function_21.on_click(aux_function_21_cache)
print("--------------------------------------------------------------")
print('^ Introduce the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')

If you are assessing the model you just trained, you can skip the cell below and go directly to the next section.
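<font size = 4>The `patch_size` and `patch_height` values are reused in Section 4.2, where each image dimension is divided by the matching patch dimension to set the number of prediction tiles. A minimal sketch of that arithmetic, for illustration only: the helper `tiles_from_patch` is hypothetical, its defaults mirror the fallback values of the parameter cell, and the clamp to at least one tile per axis is an assumption added here so that small stacks still yield a valid tiling.

```python
def tiles_from_patch(shape_zyx, patch_height=16, patch_size=128):
    """Tiles per axis (Z, Y, X): image size divided by patch size, at least one tile."""
    z, y, x = shape_zyx
    return (
        max(1, z // patch_height),
        max(1, y // patch_size),
        max(1, x // patch_size),
    )


# Hypothetical stack of 40 slices of 512 x 300 pixels
print(tiles_from_patch((40, 512, 300)))  # (2, 4, 2)
```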

In [None]:
# Check and set default values for patch parameters
try:
    patch_size
except NameError:
    patch_size = 128
    print(f"patch_size not found, setting default value: {patch_size}")

try:
    patch_height
except NameError:
    patch_height = 16
    print(f"patch_height not found, setting default value: {patch_height}")

## **4.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>**Validation loss** describes the same error value between the model's prediction on a validation image and its target.

<font size = 4>During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.

<font size = 4>Decreasing **Training loss** and **Validation loss** indicate that training is still necessary, and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side just because of the y-axis scaling. The network has reached convergence once the curves flatten out; after this point, no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case, the training dataset has to be increased.
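<font size = 4>The reading of the curves described above can be expressed as a rough heuristic that compares the most recent epochs with the ones just before. This is not part of the notebook: the helper name, the window of 3 epochs and the 1 % tolerance are arbitrary assumptions, and the loss lists could, for instance, be the values read from `training_evaluation.csv` in the cell below.

```python
from statistics import mean


def diagnose_loss_curves(train_loss, val_loss, window=3, tol=0.01):
    """Heuristic reading of training and validation loss curves.

    Compares the mean loss of the last `window` epochs with that of the
    `window` epochs before them and returns (verdict, best_val_epoch).
    """
    if len(train_loss) != len(val_loss) or len(val_loss) < 2 * window:
        raise ValueError("need two equally long loss lists of at least 2 * window epochs")

    def relative_change(values):
        earlier = mean(values[-2 * window:-window])
        recent = mean(values[-window:])
        return (recent - earlier) / earlier

    # Epoch (0-based) with the lowest validation loss
    best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i])
    val_change = relative_change(val_loss)
    train_change = relative_change(train_loss)

    if val_change > tol and train_change < 0:
        verdict = "possible overfitting"  # validation loss rises while training loss falls
    elif val_change < -tol and train_change < -tol:
        verdict = "still improving"  # both losses still decreasing: more epochs may help
    else:
        verdict = "plateau"  # no clear trend: training is likely close to convergence
    return verdict, best_epoch


train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3, 0.2, 0.15, 0.1, 0.05]
val = [1.0, 0.85, 0.7, 0.65, 0.6, 0.62, 0.66, 0.7, 0.75, 0.8]
print(diagnose_loss_curves(train, val))  # ('possible overfitting', 4)
```

<font size = 4>Such a heuristic only complements the visual inspection of the plots; noisy curves can easily trigger the wrong verdict with a small window.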


In [None]:
# Run this cell to execute the code
internal_aux_initial_time = datetime.now()
print("Running...")
print("--------------------------------------")
import csv
from matplotlib import pyplot as plt

lossDataFromCSV = []
vallossDataFromCSV = []

with open(
    QC_model_path + "/" + QC_model_name + "/Quality Control/training_evaluation.csv",
    "r",
) as csvfile:
    csvRead = csv.reader(csvfile, delimiter=",")
    next(csvRead)
    for row in csvRead:
        lossDataFromCSV.append(float(row[0]))
        vallossDataFromCSV.append(float(row[1]))

epochNumber = range(len(lossDataFromCSV))
plt.figure(figsize=(15, 10))

plt.subplot(2, 1, 1)
plt.plot(epochNumber, lossDataFromCSV, label="Training loss")
plt.plot(epochNumber, vallossDataFromCSV, label="Validation loss")
plt.title("Training loss and validation loss vs. epoch number (linear scale)")
plt.ylabel("Loss")
plt.xlabel("Epoch number")
plt.legend()

plt.subplot(2, 1, 2)
plt.semilogy(epochNumber, lossDataFromCSV, label="Training loss")
plt.semilogy(epochNumber, vallossDataFromCSV, label="Validation loss")
plt.title("Training loss and validation loss vs. epoch number (log scale)")
plt.ylabel("Loss")
plt.xlabel("Epoch number")
plt.legend()
plt.savefig(
    QC_model_path + "/" + QC_model_name + "/Quality Control/lossCurvePlots.png",
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

print("--------------------------------------")
print(f"Finished. Duration: {datetime.now() - internal_aux_initial_time}")

## **4.2. Error mapping and quality metrics estimation**
---
<font size = 4>This section calculates the Intersection over Union (IoU) score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for a selected image is also displayed.

<font size = 4>The **Intersection over Union** metric quantifies the overlap between the target mask and your prediction output as the size of their intersection divided by the size of their union. **Therefore, the closer to 1, the better the performance.** Here it is computed on binarized masks (every labelled voxel counts as foreground), so it measures how well the predicted nuclei cover the annotated nuclei overall rather than matching individual nuclei.
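<font size = 4>As a small, self-contained illustration of this foreground IoU, IoU = |GT ∩ P| / |GT ∪ P| on binarized masks, the sketch below uses two toy label volumes. The helper name is hypothetical, and returning 1.0 when both masks are empty is a convention chosen here, not something the notebook defines. Because instance IDs are discarded (label 1 in the ground truth and label 7 in the prediction both become foreground), the score does not depend on how the nuclei are numbered.

```python
import numpy as np


def binary_iou(ground_truth, prediction):
    """Intersection over Union of the foreground (label > 0) of two label volumes."""
    gt = np.asarray(ground_truth) > 0
    pred = np.asarray(prediction) > 0
    union = np.logical_or(gt, pred).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return float(np.logical_and(gt, pred).sum() / union)


# Toy (Z, Y, X) volumes: ground truth covers rows 0-1, prediction covers rows 1-2
gt = np.zeros((1, 4, 4), dtype=np.uint16)
gt[0, :2, :] = 1
pred = np.zeros((1, 4, 4), dtype=np.uint16)
pred[0, 1:3, :] = 7
print(round(binary_iou(gt, pred), 3))  # 0.333 (4 shared voxels out of 12)
```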

<font size = 4> The results can be found in the "*Quality Control*" folder which is located inside your "model_folder".


### WARNING:
This section is designed to be used with **NON-NORMALIZED** images. If you have normalized your images, you should disable the normalization step (the `normalize(...)` line) in the cell below.
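For reference, the normalization step in the cell below rescales intensities with percentiles: it calls `normalize(img, 1, 99.8, axis=axis_norm)`, presumably `csbdeep.utils.normalize`. The NumPy function below is a simplified, hypothetical stand-in that illustrates the idea, not csbdeep's implementation. It maps the 1st percentile to 0 and the 99.8th percentile to 1, without clipping, over the whole stack jointly.

```python
import numpy as np


def percentile_normalize(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile of `img` to 0 and the pmax-th percentile to 1.

    Values outside that range are not clipped, so a few voxels end up below 0 or above 1.
    """
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.percentile(img, [pmin, pmax])
    return ((img - lo) / (hi - lo + eps)).astype(np.float32)


# Hypothetical raw 16-bit stack (Z, Y, X) with an arbitrary intensity offset
raw = (np.arange(1000, dtype=np.uint16) + 100).reshape(10, 10, 10)
normed = percentile_normalize(raw)
print(float(normed.min()), float(normed.max()))
```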

In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

display(
    Markdown("## Give the paths to an image to test the performance of the model with.")
)
widget_Source_QC_folder = widgets.Text(
    value="", style={"description_width": "initial"}, description="Source_QC_folder:"
)
display(widget_Source_QC_folder)
widget_Target_QC_folder = widgets.Text(
    value="", style={"description_width": "initial"}, description="Target_QC_folder:"
)
display(widget_Target_QC_folder)
display(
    Markdown(
        "##### To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. 'Automatic_number_of_tiles' will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. "
    )
)
widget_Automatic_number_of_tiles = widgets.Checkbox(
    value=False,
    style={"description_width": "initial"},
    description="Automatic_number_of_tiles:",
)
display(widget_Automatic_number_of_tiles)
display(
    Markdown(
        "##### If you get an Out of memory (OOM) error when using the 'Automatic_number_of_tiles' option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears."
    )
)
widget_n_tiles_Z = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_Z:"
)
display(widget_n_tiles_Z)
widget_n_tiles_Y = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_Y:"
)
display(widget_n_tiles_Y)
widget_n_tiles_X = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_X:"
)
display(widget_n_tiles_X)


def function_25(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Source_QC_folder
        global Target_QC_folder
        global Automatic_number_of_tiles
        global n_tiles_Z
        global n_tiles_Y
        global n_tiles_X

        global n_tilesZYX
        global Source_QC_folder_tif
        global lbl_cmap
        global Z
        global n_channel
        global axis_norm
        global model
        global names
        global lenght_of_Z
        global img
        global labels
        global polygons
        global writer
        global filename_list
        global IoU_score_list
        global test_input
        global test_prediction
        global test_ground_truth_image
        global test_prediction_0_to_255
        global test_ground_truth_0_to_255
        global intersection
        global union
        global iou_score
        global pdResults
        global f
        global norm
        global Image_Z
        global mid_plane

        global show_QC_results

        import warnings

        warnings.filterwarnings("ignore")

        Source_QC_folder = widget_Source_QC_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_Source_QC_folder",
            widget_Source_QC_folder.value,
        )
        Target_QC_folder = widget_Target_QC_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_Target_QC_folder",
            widget_Target_QC_folder.value,
        )

        # Here we allow the user to choose the number of tile to be used when predicting the images

        Automatic_number_of_tiles = widget_Automatic_number_of_tiles.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_Automatic_number_of_tiles",
            widget_Automatic_number_of_tiles.value,
        )
        n_tiles_Z = widget_n_tiles_Z.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_n_tiles_Z",
            widget_n_tiles_Z.value,
        )
        n_tiles_Y = widget_n_tiles_Y.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_n_tiles_Y",
            widget_n_tiles_Y.value,
        )
        n_tiles_X = widget_n_tiles_X.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_25_n_tiles_X",
            widget_n_tiles_X.value,
        )

        if Automatic_number_of_tiles:
            n_tilesZYX = None

        if not (Automatic_number_of_tiles):
            n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)

        # Create a quality control Folder and check if the folder already exist
        if (
            os.path.exists(QC_model_path + "/" + QC_model_name + "/Quality Control")
            == False
        ):
            os.makedirs(QC_model_path + "/" + QC_model_name + "/Quality Control")

        if os.path.exists(
            QC_model_path + "/" + QC_model_name + "/Quality Control/Prediction"
        ):
            shutil.rmtree(
                QC_model_path + "/" + QC_model_name + "/Quality Control/Prediction"
            )

        os.makedirs(QC_model_path + "/" + QC_model_name + "/Quality Control/Prediction")

        # Generate predictions from the Source_QC_folder and save them in the QC folder

        Source_QC_folder_tif = Source_QC_folder + "/*.tif"

        np.random.seed(16)
        lbl_cmap = random_label_cmap()
        Z = sorted(glob(Source_QC_folder_tif))
        axis_norm = (0, 1, 2)  # normalize whole image stack jointly
        n_channel = 1

        print("Number of test dataset found in the folder: " + str(len(Z)))

        model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)

        names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]

        # modify the names to suitable form: path_images/image_numberX.tif

        lenght_of_Z = len(Z)

        for i in range(lenght_of_Z):
            img = imread(Z[i])

            # ---------Normalize the image-----------------
            # Comment this line if you want to use the original image
            img = normalize(img, 1, 99.8, axis=axis_norm)
            # ----------------------------------------------

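            # The number of tiles is derived from the training patch size and overrides the
            # values chosen above; max(1, ...) avoids a zero tile count for small images
            n_tilesZYX = (
                max(1, int(img.shape[0] / patch_height)),
                max(1, int(img.shape[1] / patch_size)),
                max(1, int(img.shape[2] / patch_size)),
            )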

            labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)
            os.chdir(
                QC_model_path + "/" + QC_model_name + "/Quality Control/Prediction"
            )
            save_tiff_imagej_compatible(names[i], labels, axes="ZYX")

        # Here we start testing the differences between GT and predicted masks

        with open(
            QC_model_path
            + "/"
            + QC_model_name
            + "/Quality Control/Quality_Control for "
            + QC_model_name
            + ".csv",
            "w",
            newline="",
        ) as file:
            writer = csv.writer(file)
            writer.writerow(["image", "Prediction v. GT Intersection over Union"])

            # Initialise the lists
            filename_list = []
            IoU_score_list = []

            for n in os.listdir(Source_QC_folder):
                if not os.path.isdir(os.path.join(Source_QC_folder, n)):
                    if n in os.listdir(
                        QC_model_path
                        + "/"
                        + QC_model_name
                        + "/Quality Control/Prediction"
                    ):
                        print("Running QC on: " + n)

                        test_input = io.imread(os.path.join(Source_QC_folder, n))
                        test_prediction = io.imread(
                            os.path.join(
                                QC_model_path
                                + "/"
                                + QC_model_name
                                + "/Quality Control/Prediction",
                                n,
                            )
                        )
                        test_ground_truth_image = io.imread(
                            os.path.join(Target_QC_folder, n)
                        )

                        # Convert pixel values to 0 or 255
                        test_prediction_0_to_255 = test_prediction
                        test_prediction_0_to_255[test_prediction_0_to_255 > 0] = 255

                        # Convert pixel values to 0 or 255
                        test_ground_truth_0_to_255 = test_ground_truth_image
                        test_ground_truth_0_to_255[test_ground_truth_0_to_255 > 0] = 255

                        # Intersection over Union metric

                        intersection = np.logical_and(
                            test_ground_truth_0_to_255, test_prediction_0_to_255
                        )
                        union = np.logical_or(
                            test_ground_truth_0_to_255, test_prediction_0_to_255
                        )
                        iou_score = np.sum(intersection) / np.sum(union)
                        writer.writerow([n, str(iou_score)])

                        print("IoU: " + str(round(iou_score, 3)))

                        filename_list.append(n)
                        IoU_score_list.append(iou_score)

        # Table with metrics as dataframe output
        pdResults = pd.DataFrame(index=filename_list)
        pdResults["IoU"] = IoU_score_list

        # Display results
        display(pdResults.head())

        # ------------- For display ------------
        print("--------------------------------------------------------------")

        @interact
        def show_QC_results(file=os.listdir(Source_QC_folder)):
            f = plt.figure(figsize=(32, 8))

            test_input = io.imread(os.path.join(Source_QC_folder, file))
            test_prediction = io.imread(
                os.path.join(
                    QC_model_path + "/" + QC_model_name + "/Quality Control/Prediction",
                    file,
                )
            )
            test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file))

            norm = simple_norm(test_input, percent=99)
            Image_Z = test_input.shape[0]
            mid_plane = Image_Z // 2  # middle Z slice, always a valid index

            # Convert pixel values to 0 or 255
            test_prediction_0_to_255 = test_prediction
            test_prediction_0_to_255[test_prediction_0_to_255 > 0] = 255

            # Convert pixel values to 0 or 255
            test_ground_truth_0_to_255 = test_ground_truth_image
            test_ground_truth_0_to_255[test_ground_truth_0_to_255 > 0] = 255

            # Input
            plt.subplot(1, 4, 1)
            plt.axis("off")
            plt.imshow(
                test_input[mid_plane],
                aspect="equal",
                norm=norm,
                cmap="magma",
                interpolation="nearest",
            )
            plt.title("Input")

            # Ground-truth
            plt.subplot(1, 4, 2)
            plt.axis("off")
            plt.imshow(
                test_ground_truth_0_to_255[mid_plane], aspect="equal", cmap="Greens"
            )
            plt.title("Ground Truth")

            # Prediction
            plt.subplot(1, 4, 3)
            plt.axis("off")
            plt.imshow(
                test_prediction_0_to_255[mid_plane], aspect="equal", cmap="Purples"
            )
            plt.title("Prediction")

            # Overlay
            plt.subplot(1, 4, 4)
            plt.axis("off")
            plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap="Greens")
            plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap="Purples")
            plt.title(
                "Ground Truth and Prediction, Intersection over Union: "
                + str(round(pdResults.loc[file]["IoU"], 3))
            )
            plt.savefig(
                full_QC_model_path + "/Quality Control/QC_example_data.png",
                bbox_inches="tight",
                pad_inches=0,
            )

        # Make a pdf summary of the QC results
        qc_pdf_export()

        plt.show()


def function_25_cache(output_widget):
    global Source_QC_folder
    global Target_QC_folder
    global Automatic_number_of_tiles
    global n_tiles_Z
    global n_tiles_Y
    global n_tiles_X

    global n_tilesZYX
    global Source_QC_folder_tif
    global lbl_cmap
    global Z
    global n_channel
    global axis_norm
    global model
    global names
    global lenght_of_Z
    global img
    global labels
    global polygons
    global writer
    global filename_list
    global IoU_score_list
    global test_input
    global test_prediction
    global test_ground_truth_image
    global test_prediction_0_to_255
    global test_ground_truth_0_to_255
    global intersection
    global union
    global iou_score
    global pdResults
    global f
    global norm
    global Image_Z
    global mid_plane

    global show_QC_results

    cache_Source_QC_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_Source_QC_folder"
    )
    if cache_Source_QC_folder != "":
        widget_Source_QC_folder.value = cache_Source_QC_folder

    cache_Target_QC_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_Target_QC_folder"
    )
    if cache_Target_QC_folder != "":
        widget_Target_QC_folder.value = cache_Target_QC_folder

    cache_Automatic_number_of_tiles = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_Automatic_number_of_tiles"
    )
    if cache_Automatic_number_of_tiles != "":
        widget_Automatic_number_of_tiles.value = cache_Automatic_number_of_tiles

    cache_n_tiles_Z = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_n_tiles_Z"
    )
    if cache_n_tiles_Z != "":
        widget_n_tiles_Z.value = cache_n_tiles_Z

    cache_n_tiles_Y = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_n_tiles_Y"
    )
    if cache_n_tiles_Y != "":
        widget_n_tiles_Y.value = cache_n_tiles_Y

    cache_n_tiles_X = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_25_n_tiles_X"
    )
    if cache_n_tiles_X != "":
        widget_n_tiles_X.value = cache_n_tiles_X


button_function_25 = widgets.Button(description="Load and run")
cache_button_function_25 = widgets.Button(description="Load prev. settings")
output_function_25 = widgets.Output()
display(
    widgets.HBox((button_function_25, cache_button_function_25)), output_function_25
)


def aux_function_25(_):
    return function_25(output_function_25)


def aux_function_25_cache(_):
    return function_25_cache(output_function_25)


button_function_25.on_click(aux_function_25)
cache_button_function_25.on_click(aux_function_25_cache)
print("--------------------------------------------------------------")
print('^ Enter the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')
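The quality-control figure above reports an Intersection over Union (IoU) score for each file. The code that fills `pdResults` lies outside this section. The sketch below therefore assumes that IoU is computed over the binarized foreground of the ground-truth and predicted label volumes, so any non-zero label counts as foreground and instance ids do not need to match. `binary_iou` is a hypothetical helper name. Returning 1.0 when both volumes are empty is a choice made for this sketch.

```python
import numpy as np


def binary_iou(ground_truth: np.ndarray, prediction: np.ndarray) -> float:
    """IoU between the foreground (non-zero) voxels of two label volumes."""
    gt = ground_truth > 0
    pred = prediction > 0
    intersection = np.logical_and(gt, pred).sum()
    union = np.logical_or(gt, pred).sum()
    # Two empty volumes agree perfectly; this also avoids dividing by zero
    return float(intersection / union) if union > 0 else 1.0


# Toy 3D example: two 4x4x4 label volumes whose foregrounds partially overlap
gt = np.zeros((4, 4, 4), dtype=np.uint16)
pred = np.zeros((4, 4, 4), dtype=np.uint16)
gt[:, :2, :] = 1     # 32 foreground voxels
pred[:, 1:3, :] = 2  # 32 foreground voxels, 16 of them shared with gt
print(round(binary_iou(gt, pred), 3))  # 0.333
```

The two toy volumes use different label ids (1 and 2). Only the overlap of their foregrounds affects the score.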

# **5. Using the trained model**

---

<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. Then the trained model is loaded, and the predictions are saved to your results folder.


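If you are using the model you have just trained in this session, you can skip the cell below and go directly to the next section. Otherwise, run it to set default values for the patch parameters (`patch_size`, `patch_height`) used during prediction.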

In [None]:
# Check and set default values for patch parameters
try:
    patch_size
except NameError:
    patch_size = 128
    print(f"patch_size not found, setting default value: {patch_size}")

try:
    patch_height
except NameError:
    patch_height = 16
    print(f"patch_height not found, setting default value: {patch_height}")
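The prediction cell in section 5.1 can derive a number of tiles per axis by dividing each image dimension by these patch parameters: `patch_height` for Z, and `patch_size` for Y and X. The sketch below reproduces that arithmetic with integer division. It also clamps each count to at least one, because an image smaller than one patch would otherwise yield zero tiles. `tiles_from_patch` is a hypothetical helper name, and its defaults simply mirror the fallback values set above.

```python
def tiles_from_patch(shape_zyx, patch_height=16, patch_size=128):
    """Number of tiles per axis (Z, Y, X), with at least one tile per axis."""
    z, y, x = shape_zyx
    return (
        max(1, z // patch_height),
        max(1, y // patch_size),
        max(1, x // patch_size),
    )


print(tiles_from_patch((64, 512, 512)))  # (4, 4, 4)
print(tiles_from_patch((10, 100, 300)))  # (1, 1, 2)
```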

## **5.1. Generate prediction(s) from unseen dataset**
---

<font size = 4>The model trained in section 4.3 can now be used to process images. To use a previously trained model instead, untick the **Use_the_current_trained_model** box and enter the path to the model folder in **Prediction_model_folder**. Predicted label images are saved in your **Results_folder** as image stacks (ImageJ-compatible TIFF files).

<font size = 4>**`Data_folder`:** This folder should contain the images (`.tif` stacks) that you want to segment using the network that you trained.

<font size = 4>**`Results_folder`:** This folder will contain the predicted label images, saved with the same file names as the input images.

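<font size = 4>**`Use_the_current_trained_model`:** Tick this box to use the model trained in this session. Untick it to use the model located in **`Prediction_model_folder`** instead.

<font size = 4>**`Prediction_model_folder`:** Full path to the folder of a previously trained model. This is only used when **`Use_the_current_trained_model`** is unticked.

<font size = 4>**`Automatic_number_of_tiles`**, **`n_tiles_Z`**, **`n_tiles_Y`**, **`n_tiles_X`:** These control how each image is divided into tiles for prediction (see the notes displayed next to these parameters).
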
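StarDist saves its predictions as label images. Background voxels are 0, and each detected nucleus is filled with its own positive integer id. The sketch below shows how such a stack can be summarized, for example after reading a result file back with `imread`. It reports the number of nuclei and the volume of each one in voxels. `summarize_labels` is a hypothetical helper, and the toy array stands in for a real result file.

```python
import numpy as np


def summarize_labels(labels: np.ndarray) -> dict:
    """Map each instance id (excluding background 0) to its voxel count."""
    ids, counts = np.unique(labels, return_counts=True)
    return {int(i): int(c) for i, c in zip(ids, counts) if i != 0}


# Toy ZYX label stack containing two 2x2x2 "nuclei"
labels = np.zeros((3, 5, 5), dtype=np.uint16)
labels[0:2, 0:2, 0:2] = 1
labels[1:3, 3:5, 3:5] = 2
stats = summarize_labels(labels)
print(len(stats), stats)  # 2 {1: 8, 2: 8}
```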
### WARNING:
This section is designed to be used with **NON-NORMALIZED** images. If your images are already normalized, disable the normalization step in the cell below by commenting out the `normalize(...)` line.
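The cell below calls `normalize(img, 1, 99.8, axis=axis_norm)`. This assumes `normalize` is csbdeep's `csbdeep.utils.normalize`, as in the StarDist examples. That function rescales intensities so that the 1st percentile maps to about 0 and the 99.8th percentile to about 1, here computed over the whole stack. The sketch below is a simplified NumPy re-implementation for illustration only. It uses whole-volume percentiles, applies no clipping, and the name `percentile_normalize` is hypothetical.

```python
import numpy as np


def percentile_normalize(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1."""
    lo = np.percentile(img, pmin)
    hi = np.percentile(img, pmax)
    return ((img.astype(np.float32) - lo) / (hi - lo + eps)).astype(np.float32)


rng = np.random.default_rng(0)
stack = rng.integers(100, 4000, size=(8, 32, 32)).astype(np.uint16)
norm = percentile_normalize(stack)
print(norm.dtype, float(np.percentile(norm, 1)), float(np.percentile(norm, 99.8)))
```

Because the mapping is affine, normalizing an already normalized stack changes its values only slightly. Skipping the step mainly matters when your own preprocessing used a different scheme.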

In [None]:
# Run this cell to visualize the parameters and click the button to execute the code
internal_aux_initial_time = datetime.now()
clear_output()

display(
    Markdown(
        "### Provide the path to your dataset and to the folder where the predictions will be saved (Results folder), then click 'Load and run' to run the prediction on your unseen images."
    )
)
widget_Data_folder = widgets.Text(
    value="", style={"description_width": "initial"}, description="Data_folder:"
)
display(widget_Data_folder)
widget_Results_folder = widgets.Text(
    value="", style={"description_width": "initial"}, description="Results_folder:"
)
display(widget_Results_folder)
display(Markdown("### Do you want to use the current trained model?"))
widget_Use_the_current_trained_model = widgets.Checkbox(
    value=True,
    style={"description_width": "initial"},
    description="Use_the_current_trained_model:",
)
display(widget_Use_the_current_trained_model)
display(Markdown("### If not, please provide the path to the model folder:"))
widget_Prediction_model_folder = widgets.Text(
    value="",
    style={"description_width": "initial"},
    description="Prediction_model_folder:",
)
display(widget_Prediction_model_folder)
display(
    Markdown(
        "##### To analyse large images, they need to be divided into tiles. Each tile is then processed independently and re-assembled to generate the final image. 'Automatic_number_of_tiles' derives the number of tiles in each dimension from the patch size used for training (patch_height for Z, patch_size for Y and X). Alternatively, manually input the number of tiles to be used in each dimension to process your images."
    )
)
widget_Automatic_number_of_tiles = widgets.Checkbox(
    value=False,
    style={"description_width": "initial"},
    description="Automatic_number_of_tiles:",
)
display(widget_Automatic_number_of_tiles)
display(
    Markdown(
        "##### If you get an out-of-memory (OOM) error when using the 'Automatic_number_of_tiles' option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears."
    )
)
widget_n_tiles_Z = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_Z:"
)
display(widget_n_tiles_Z)
widget_n_tiles_Y = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_Y:"
)
display(widget_n_tiles_Y)
widget_n_tiles_X = widgets.IntText(
    value=1, style={"description_width": "initial"}, description="n_tiles_X:"
)
display(widget_n_tiles_X)


def function_28(output_widget):
    output_widget.clear_output()
    with output_widget:
        global Data_folder
        global Results_folder
        global Use_the_current_trained_model
        global Prediction_model_folder
        global Automatic_number_of_tiles
        global n_tiles_Z
        global n_tiles_Y
        global n_tiles_X

        global Prediction_model_name
        global Prediction_model_path
        global n_tilesZYX
        global full_Prediction_model_path
        global W
        global R
        global Dataset
        global lbl_cmap
        global X
        global n_channel
        global axis_norm
        global model
        global names
        global FILEnames
        global m
        global lenght_of_X
        global img
        global labels
        global polygons
        global z

        global show_QC_results

        from PIL import Image

        Data_folder = widget_Data_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_Data_folder",
            widget_Data_folder.value,
        )
        # test_dataset = Data_folder

        Results_folder = widget_Results_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_Results_folder",
            widget_Results_folder.value,
        )
        # results = results_folder

        # model name and path
        Use_the_current_trained_model = widget_Use_the_current_trained_model.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_Use_the_current_trained_model",
            widget_Use_the_current_trained_model.value,
        )

        Prediction_model_folder = widget_Prediction_model_folder.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_Prediction_model_folder",
            widget_Prediction_model_folder.value,
        )

        # Find the loaded model name and its parent path
        # (normpath drops a trailing "/" so that basename is not empty)
        Prediction_model_name = os.path.basename(os.path.normpath(Prediction_model_folder))
        Prediction_model_path = os.path.dirname(os.path.normpath(Prediction_model_folder))

        # Let the user choose the number of tiles used when predicting the images

        Automatic_number_of_tiles = widget_Automatic_number_of_tiles.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_Automatic_number_of_tiles",
            widget_Automatic_number_of_tiles.value,
        )
        n_tiles_Z = widget_n_tiles_Z.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_n_tiles_Z",
            widget_n_tiles_Z.value,
        )
        n_tiles_Y = widget_n_tiles_Y.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_n_tiles_Y",
            widget_n_tiles_Y.value,
        )
        n_tiles_X = widget_n_tiles_X.value
        ipywidgets_edit_yaml(
            ipywidgets_edit_yaml_config_path,
            "function_28_n_tiles_X",
            widget_n_tiles_X.value,
        )

        if Automatic_number_of_tiles:
            n_tilesZYX = None
        else:
            n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)

        if Use_the_current_trained_model:
            print("Using current trained network")
            Prediction_model_name = model_name
            Prediction_model_path = model_path

        full_Prediction_model_path = (
            Prediction_model_path + "/" + Prediction_model_name + "/"
        )
        if os.path.exists(full_Prediction_model_path):
            print("The " + Prediction_model_name + " network will be used.")
        else:
            W = "\033[0m"  # white (normal)
            R = "\033[31m"  # red
            print(R + "!! WARNING: The chosen model does not exist !!" + W)
            print(
                "Please make sure you provide a valid model path and model name before proceeding further."
            )

        # single images
        # testDATA = test_dataset
        Dataset = Data_folder + "/*.tif"

        np.random.seed(16)
        lbl_cmap = random_label_cmap()
        X = sorted(glob(Dataset))

        n_channel = 1
        axis_norm = (0, 1, 2)  # normalize whole image stack jointly

        model = StarDist3D(
            None, name=Prediction_model_name, basedir=Prediction_model_path
        )

        names = [os.path.basename(f) for f in sorted(glob(Dataset))]

        # Build the output paths: <Results_folder>/<input file name>
        FILEnames = []
        for m in names:
            m = os.path.join(Results_folder, m)
            FILEnames.append(m)

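        # Create the results folder if it does not exist yet
        os.makedirs(Results_folder, exist_ok=True)

        # Number of images to predict
        lenght_of_X = len(X)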

        for i in range(lenght_of_X):
            img = imread(X[i])

            # ---------Normalize the image-----------------
            # Comment out the line below to use the non-normalized image
            img = normalize(img, 1, 99.8, axis=axis_norm)
            # ----------------------------------------------

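            if Automatic_number_of_tiles:
                # Derive the number of tiles from the training patch size,
                # using at least one tile per axis
                n_tilesZYX = (
                    max(1, img.shape[0] // patch_height),
                    max(1, img.shape[1] // patch_size),
                    max(1, img.shape[2] // patch_size),
                )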

            labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)
            # Save the predicted label stack in the results folder
            save_tiff_imagej_compatible(FILEnames[i], labels, axes="ZYX")

        print("The mid-plane image is displayed below.")
        # ------------- For display ------------
        print("--------------------------------------------------------------")

        # Offer only the .tif files that were predicted above
        @interact
        def show_QC_results(file=names):
            plt.figure(figsize=(13, 10))

            img = imread(os.path.join(Data_folder, file))
            img = normalize(img, 1, 99.8, axis=axis_norm)
            labels = imread(os.path.join(Results_folder, file))
            z = max(0, img.shape[0] // 2 - 5)

            plt.subplot(121)
            plt.imshow(
                (img if img.ndim == 3 else img[..., :3])[z], clim=(0, 1), cmap="gray"
            )
            plt.title("Raw image (XY slice)")
            plt.axis("off")
            plt.subplot(122)
            plt.imshow(
                (img if img.ndim == 3 else img[..., :3])[z], clim=(0, 1), cmap="gray"
            )
            plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)
            plt.title("Image and predicted labels (XY slice)")
            plt.axis("off")

        plt.show()


def function_28_cache(output_widget):
    global Data_folder
    global Results_folder
    global Use_the_current_trained_model
    global Prediction_model_folder
    global Automatic_number_of_tiles
    global n_tiles_Z
    global n_tiles_Y
    global n_tiles_X

    global Prediction_model_name
    global Prediction_model_path
    global n_tilesZYX
    global full_Prediction_model_path
    global W
    global R
    global Dataset
    global lbl_cmap
    global X
    global n_channel
    global axis_norm
    global model
    global names
    global FILEnames
    global m
    global lenght_of_X
    global img
    global labels
    global polygons
    global z

    global show_QC_results

    cache_Data_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_Data_folder"
    )
    if cache_Data_folder != "":
        widget_Data_folder.value = cache_Data_folder

    cache_Results_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_Results_folder"
    )
    if cache_Results_folder != "":
        widget_Results_folder.value = cache_Results_folder

    cache_Use_the_current_trained_model = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_Use_the_current_trained_model"
    )
    if cache_Use_the_current_trained_model != "":
        widget_Use_the_current_trained_model.value = cache_Use_the_current_trained_model

    cache_Prediction_model_folder = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_Prediction_model_folder"
    )
    if cache_Prediction_model_folder != "":
        widget_Prediction_model_folder.value = cache_Prediction_model_folder

    cache_Automatic_number_of_tiles = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_Automatic_number_of_tiles"
    )
    if cache_Automatic_number_of_tiles != "":
        widget_Automatic_number_of_tiles.value = cache_Automatic_number_of_tiles

    cache_n_tiles_Z = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_n_tiles_Z"
    )
    if cache_n_tiles_Z != "":
        widget_n_tiles_Z.value = cache_n_tiles_Z

    cache_n_tiles_Y = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_n_tiles_Y"
    )
    if cache_n_tiles_Y != "":
        widget_n_tiles_Y.value = cache_n_tiles_Y

    cache_n_tiles_X = ipywidgets_read_yaml(
        ipywidgets_edit_yaml_config_path, "function_28_n_tiles_X"
    )
    if cache_n_tiles_X != "":
        widget_n_tiles_X.value = cache_n_tiles_X


button_function_28 = widgets.Button(description="Load and run")
cache_button_function_28 = widgets.Button(description="Load prev. settings")
output_function_28 = widgets.Output()
display(
    widgets.HBox((button_function_28, cache_button_function_28)), output_function_28
)


def aux_function_28(_):
    return function_28(output_function_28)


def aux_function_28_cache(_):
    return function_28_cache(output_function_28)


button_function_28.on_click(aux_function_28)
cache_button_function_28.on_click(aux_function_28_cache)
print("--------------------------------------------------------------")
print('^ Enter the arguments and click "Load and run". ^')
print('^ Or first click "Load prev. settings" if any previous ^')
print('^ settings have been saved and then click "Load and run". ^')
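The tiling option in the cell above splits each volume into a grid of `n_tiles_Z` by `n_tiles_Y` by `n_tiles_X` pieces. Each piece is predicted independently and the results are then re-assembled. The sketch below illustrates only this split-and-reassemble bookkeeping. It uses a hypothetical `predict_tiled` helper and a `predict_fn` that acts voxel by voxel. StarDist's own tiling also uses overlapping tiles with extra context at the borders, which this simplified version omits.

```python
import numpy as np


def predict_tiled(volume, n_tiles, predict_fn):
    """Split a ZYX volume into a grid of tiles, apply predict_fn to each tile,
    and write each result back into the same region of the output."""
    out = np.empty_like(volume)
    # Tile boundaries along each axis, as evenly sized as possible
    bounds = [np.linspace(0, s, n + 1).astype(int) for s, n in zip(volume.shape, n_tiles)]
    for z0, z1 in zip(bounds[0][:-1], bounds[0][1:]):
        for y0, y1 in zip(bounds[1][:-1], bounds[1][1:]):
            for x0, x1 in zip(bounds[2][:-1], bounds[2][1:]):
                tile = volume[z0:z1, y0:y1, x0:x1]
                out[z0:z1, y0:y1, x0:x1] = predict_fn(tile)
    return out


vol = np.arange(6 * 10 * 10, dtype=np.float32).reshape(6, 10, 10)
result = predict_tiled(vol, (2, 3, 3), lambda t: t * 2)
print(np.array_equal(result, vol * 2))  # True
```

For a voxel-wise operation, the tiled result matches the untiled one exactly. Each tile only needs enough memory for its own region, which is why increasing the number of tiles helps avoid OOM errors.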

# **6. Version log**
---
<font size = 4>**v1.15.3**:  

*    Replaced all absolute paths with relative paths

<font size = 4>**v1.15.2**:  

*  The way that the labels are saved has been updated

<font size = 4>**v1.15.1**:  

*  TensorFlow is now v2.8
*  The `__future__` dependency is now imported at the beginning

<font size = 4>**v1.13**:  

*  StarDist is now downgraded to v0.6.2 to ensure compatibility with previously trained models.

*   This version now includes an automatic restart that allows the h5py library to be set to v2.10.
*  Sections 1 and 2 are now swapped for a better export of *requirements.txt*.

*   This version also includes a built-in version check and the version log that you are reading now.








# **Thank you for using StarDist 3D!**
