# **Interactive segmentation**
---

<font size = 4>**Interactive segmentation** is a segmentation tool powered by deep learning and ImJoy that can be used to segment bioimages and was first published by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).

<font size = 4>**The Original code** is freely available in GitHub:
https://github.com/imjoy-team/imjoy-interactive-segmentation

<font size = 4>**Please also cite this original paper when using or developing this notebook.**

<font size = 12>**!!Currently, this notebook only works with Google Chrome or Firefox!!**



# **1. Install Interactive segmentation**
---


In [None]:
Notebook_version = '1.13.1'
Network = 'Kaibu'


from builtins import any as b_any

def get_requirements_path():
    """Return a relative path to ``requirements.txt`` derived from the working directory depth.

    One ``'../'`` is emitted per ``'/'`` found in the current working
    directory, minus one, and ``'requirements.txt'`` is appended to it.
    """
    # Store requirements file in 'contents' directory
    levels_up = os.getcwd().count('/') - 1
    return '../' * levels_up + 'requirements.txt'

def filter_files(file_list, filter_list):
    """Keep the entries of ``file_list`` whose package name matches ``filter_list``.

    The package name is the text before the first ``'=='`` of an entry. An
    entry is kept when that name is contained (``in``) in at least one element
    of ``filter_list``. The input order is preserved.
    """
    return [
        entry
        for entry in file_list
        if b_any(entry.split('==')[0] in candidate for candidate in filter_list)
    ]

def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file
    df = pd.read_csv(path, delimiter = "\n")
    mod_list = [m.split('.')[0] for m in after if not m in before]
    req_list_temp = df.values.tolist()
    req_list = [x[0] for x in req_list_temp]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]
    mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]
    filtered_list = filter_files(req_list, mod_replace_list)

    file=open(path,'w')
    for item in filtered_list:
        file.writelines(item + '\n')

    file.close()

import sys
before = [str(m) for m in sys.modules]

#@markdown ##Install Interactive segmentation
import time

# Start the clock to measure how long it takes
start = time.time()

# !pip install -U Werkzeug==1.0.1
# !pip install git+https://github.com/imjoy-team/imjoy-interactive-segmentation@support-export-geojson#egg=imjoy-interactive-segmentation
!pip install git+https://github.com/imjoy-team/imjoy-interactive-segmentation@master#egg=imjoy-interactive-trainer
!python3 -m ipykernel install --user --name imjoy-interactive-ml --display-name "ImJoy Interactive ML"

from imjoy_interactive_trainer.imjoy_plugin import start_interactive_segmentation
from imjoy_interactive_trainer.interactive_trainer import InteractiveTrainer
from imjoy_interactive_trainer.data_utils import download_example_dataset, mask_to_geojson
from imjoy_interactive_trainer.imgseg.geojson_utils import geojson_to_masks

import os
import glob
from shutil import copyfile, rmtree
from tifffile import imread, imsave
from ipywidgets import interact
import ipywidgets as widgets
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
from skimage.util.shape import view_as_windows
from skimage import io

!pip install cellpose 
from cellpose import models
import numpy as np


import random
from zipfile import ZIP_DEFLATED
import csv
import pandas as pd

from numba import jit
from scipy.optimize import linear_sum_assignment
from collections import namedtuple

from tabulate import tabulate
from astropy.visualization import simple_norm
import matplotlib.pyplot as plt

from concurrent.futures import ThreadPoolExecutor


def PrepareDataAsPatches(Training_source, patch_width, patch_height, Data_tag, Training_target = None):
  """Cut every image of a folder into non-overlapping patches saved as individual samples.

  Each file found directly in `Training_source` is tiled, and every tile is
  written to its own folder `.../test/Patch_<n> - <image name>/` (the parent
  path is built from `os.path.splitext(Training_source)[0]`) containing:
    * `Data_tag`: the tile rescaled to 8-bit (2D tiles are repeated over 3 channels),
    * `Raw_data<ext>`: the untouched tile, saved with the source file extension,
    * `annotation.json`: only when a file of the same name exists in `Training_target`.

  Parameters
  ----------
  Training_source: folder holding the images; its `test` sub-folder must already exist.
  patch_width, patch_height: tile size along the first and second image axes.
  Data_tag: file name given to the 8-bit copy of each tile.
  Training_target: optional folder of label masks; `None` or an empty string means no masks.
  """

  # Here we assume that the Train and Test folders are already created
  patch_num = 0

  for file in tqdm(os.listdir(Training_source)):

    source_path = os.path.join(Training_source, file)
    if not os.path.isfile(source_path):
      continue

    img = io.imread(source_path)
    base_name, this_ext = os.path.splitext(file)

    if len(img.shape) == 2:
      # Using view_as_windows with step size equal to the patch size to ensure there is no overlap
      patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))
      patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height)

    elif len(img.shape) == 3:
      # Using view_as_windows with step size equal to the patch size to ensure there is no overlap
      patches_img = view_as_windows(img, (patch_width, patch_height, img.shape[2]), (patch_width, patch_height, img.shape[2]))
      patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height, img.shape[2])

    else:
      # Nothing can be tiled for this file: skip it instead of failing further down
      print('Data format currently unsupported.')
      continue

    # Masks are optional: without a target folder no lookup is attempted
    mask_path = os.path.join(Training_target, file) if Training_target else None
    mask_exists = mask_path is not None and os.path.isfile(mask_path)
    if mask_exists:
      mask = io.imread(mask_path)
      patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))
      patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width, patch_height)

    for i in range(patches_img.shape[0]):
      save_path = os.path.join(os.path.splitext(Training_source)[0], 'test','Patch_'+str(patch_num)+' - ' + base_name)
      os.mkdir(save_path)
      img_save_path = os.path.join(save_path, Data_tag)

      if (len(patches_img[i].shape) == 2):
        this_image = np.repeat(patches_img[i][:,:,np.newaxis], repeats=3, axis=2)
      else:
        this_image = patches_img[i]

      # Convert to 8-bit to save as png; an all-zero patch is kept black
      # rather than being divided by zero
      peak = this_image.max()
      if peak != 0:
        this_image = (this_image/peak*255).astype('uint8')
      else:
        this_image = np.zeros(this_image.shape, dtype='uint8')
      io.imsave(img_save_path, this_image, check_contrast = False)

      # Save raw images patches, preserving format and bit depth
      img_save_path_raw = os.path.join(save_path, 'Raw_data'+this_ext)
      io.imsave(img_save_path_raw, patches_img[i], check_contrast = False)

      if mask_exists:
        with open(os.path.join(save_path, "annotation.json"), "w", encoding="utf-8") as f:
          f.write(mask_to_geojson(patches_mask[i],simplify_tol=None))

      patch_num += 1

def get_image_list(Folder_path, extension_list = ['*.jpg', '*.tif', '*.png']):
  """List the files of `Folder_path` matching the given glob patterns.

  The number of matches is printed. Returns a tuple
  `(image_list, filenames_list)`: the matched paths, grouped in the order of
  `extension_list`, and their base names in the same order.
  """
  image_list = []
  for pattern in extension_list:
    image_list.extend(glob.glob(Folder_path+"/"+pattern))

  print('Number of files: '+str(len(image_list)))

  filenames_list = [os.path.basename(img_path) for img_path in image_list]
  return image_list, filenames_list

# Colors for the warning messages
class bcolors:
  """ANSI escape sequences used to colour the messages printed by the notebook."""
  # Start of a warning message (red foreground)
  WARNING = '\033[31m'
  # Back to the default rendering
  NORMAL = '\033[0m'  # white (normal)

# Check if this is the latest version of the notebook
# Latest_notebook_version = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv")

# print('Notebook version: '+Notebook_version[0])

# strlist = Notebook_version[0].split('.')
# Notebook_version_main = strlist[0]+'.'+strlist[1]

# if Notebook_version_main == Latest_notebook_version.columns:
#   print("This notebook is up-to-date.")
# else:
#   print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki"+ bcolors.NORMAL)

# Compare the version of this notebook with the latest one listed in the ZeroCostDL4Mic repository
All_notebook_versions = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv", dtype=str)
print('Notebook version: '+Notebook_version)
# Row of the table whose "Notebook" column matches this notebook's Network name
Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Network]['Version'].iloc[0]
print('Latest notebook version: '+Latest_Notebook_version)
if Notebook_version == Latest_Notebook_version:
  print("This notebook is up-to-date.")
else:
  # Reset the colour after the warning so that the following output is not printed in red
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki"+ bcolors.NORMAL)

# Build requirements file for local run
# (snapshot of the modules loaded now, compared with the one taken before installing)
after = list(map(str, sys.modules))
build_requirements_file(before, after)

# ---------------------------- Display ----------------------------
# Displaying the time elapsed for installation
dt = time.time() - start
total_minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(total_minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds,1),"sec(s)")


print("-----------")
print("Interactive segmentation installed.")

# **2. Complete the Colab session**
---


## **2.1. Check for GPU access**
---

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi

## **2.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of notebook.

In [None]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL. 

#@markdown * Sign in your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Enter the authorization code. 

#@markdown * Click on "Files" site on the right. Refresh the site. Your Google Drive folder should now be available here as "drive". 

# Mount the user's Google Drive into the Google Colab session.
from google.colab import drive
# '/content/gdrive' is the mount point handed to Colab for the Drive contents
drive.mount('/content/gdrive')

# **3. Select your parameters and paths**
---

## **3.1. Set and prepare dataset**
---

<font size = 4>**WARNING: Currently this notebook only builds 'Grayscale' Cellpose models. So please provide only grayscale equivalent dataset. WARNING.**

<font size = 4>**`Data_folder:`:** This is the path to the data to use for interactive annotation. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below. 

<font size = 4>**`Mask_folder:`:** **(OPTIONAL)** This is the path to the corresponding masks in case some of the files contain pre-annotated labelled masks. The masks need to have the same name as their corresponding source files. Not all files need to have a mask associated with it. If no masks are available, leave path empty.

<font size = 4>**`Use_patches:`:** This option splits the available data into non-overlapping patches of **`patch_size`** x **`patch_size`**. This makes all data consistent in size and format. We recommend always using this option for stability. 

<font size = 4>**`Reset_data:`:** Resetting the data will empty the training data folder and remove all the annotations available from previous uses.

<font size = 4>**`Use_example_dataset:`:** This will download and use the example data provided by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).


In [None]:
#@markdown ###**Prepare data**

Data_folder = "" #@param {type:"string"}
Mask_folder = "" #@param {type:"string"}

#@markdown ###Split the data in small non-overlapping patches?
Use_patches = True #@param {type:"boolean"}

patch_size =  256#@param {type:"integer"}

#@markdown ###Reset the data? (**!annotations will be lost!**)
Reset_data = False #@param {type:"boolean"}
#@markdown ###Otherwise, use example data
Use_example_dataset = False #@param {type:"boolean"}


if Use_example_dataset:
  Data_folder = "/content/data/hpa_dataset_v2"
  download_example_dataset()
  Data_split = ["microtubules.png", "er.png", "nuclei.png"]
  # Data_tag is read again by the display code of section 4.2; define it in this
  # branch as well so that cell does not stop on a NameError.
  # NOTE(review): showing the first channel file is an assumption - confirm.
  Data_tag = Data_split[0]
  channels=[2, 3]


else:
  Data_tag = "data.png" #Kaibu works best with PNGs!
  Data_split = [Data_tag]
  channels=[0, 0] # grayscale images without nuclei channel

  train_folder = os.path.join(Data_folder, "train")
  test_folder = os.path.join(Data_folder, "test")

  # Resetting removes both sample folders together with their annotations
  if Reset_data:
    for sample_folder in (train_folder, test_folder):
      if os.path.exists(sample_folder):
        rmtree(sample_folder)

  if os.path.exists(train_folder) and os.path.exists(test_folder):
    print("Kaibu data already exist. Starting from these annotations!")
  else:
    print("Creating new folders!")

    # exist_ok: one of the two folders may be left over from an interrupted run
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    if Use_patches:
      PrepareDataAsPatches(Data_folder, patch_size,patch_size, Data_tag, Mask_folder)
    else:
      image_list, _ = get_image_list(Data_folder)
      # jpeg_image_list = glob.glob(Data_folder+"/*.jpg")
      n_files = len(image_list)
      print("Total number of files: "+str(n_files))

      for image in image_list:
        image_name = os.path.basename(image)
        save_path = os.path.join(test_folder, os.path.splitext(image_name)[0])
        os.mkdir(save_path)
        copyfile(image, save_path+"/"+Data_tag)

        # An empty Mask_folder means that no masks were provided: do not look
        # for a same-named file in the working directory
        if Mask_folder and os.path.isfile(os.path.join(Mask_folder, image_name)):
          mask = io.imread(os.path.join(Mask_folder, image_name))
          with open(os.path.join(save_path, "annotation.json"), "w", encoding="utf-8") as f:
            f.write(mask_to_geojson(mask,simplify_tol=None))

# Collect the images lying directly in the data folder
extension_list = ['*.jpg', '*.tif', '*.png']
image_list, filenames_list = get_image_list(Data_folder, extension_list)


if filenames_list:
  # ------------- For display ------------
  print('--------------------------------------------------------------')
  @interact
  def show_example_data(name = filenames_list):
    """Display the image selected in the widget with its size in the title."""
    plt.figure(figsize=(13,10))
    img = io.imread(os.path.join(Data_folder, name))
    n_rows, n_cols = img.shape[0], img.shape[1]

    plt.imshow(img, cmap='gray')
    plt.title('Source image ('+str(n_rows)+'x'+str(n_cols)+')')
    plt.axis('off')




print("------")
print("Data prepared for interactive segmentation.")








## **3.2. Prepare the Cellpose model**
---

<font size = 4>**`Model_folder:`:** Enter the path where your model will be saved once trained (for instance your result folder).

<font size = 4>**`Model_name:`:** Name of the model, the notebook will create a folder with this name within which the model will be saved.

<font size = 4>**`default_diameter:`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment. If you input "0", this parameter will be estimated automatically for each of your images.

<font size = 4>**`model_type:`:** This is the Cellpose model that will be loaded initially and from which training will start. This allows reasonable predictions to be made even with no additional training data.



In [None]:
#@markdown ###**Prepare model**

Model_folder = "" #@param {type:"string"}
Model_name = "" #@param {type:"string"}

if os.path.exists(os.path.join(Model_folder, Model_name)):
  print(bcolors.WARNING +"Model folder already exists and will be deleted at next step."+bcolors.NORMAL)


default_diameter =  0 #@param {type:"number"}
model_type = "default" #@param ["cyto", "nuclei", "default", "none","Own model"]

#@markdown ###**If using your own model, please select the path to the model**
own_model_path = "" #@param {type:"string"}



# Settings used unless the selected model type overrides them below
resume = True
pretrained_model = None


# Translate the form choice into the values stored in the model configuration
if model_type == "default":
  model_type = 'cyto'
elif model_type == "none":
  model_type = 'cyto'
  resume = False
elif model_type == "Own model":
  model_type = None
  pretrained_model = own_model_path



# A diameter of 0 in the form is passed on as None
if default_diameter == 0:
  default_diameter = None


print("------")
# print(model_type)
# print(pretrained_model)
print("Model prepared.")


# **4. Interactive segmentation**
---


## **4.1. Run interactive segmentation interface**
---

<font size = 4> This will start the interactive segmentation interface using ImJoy and Kaibu. 

*   Get an image
*   Predict on that image (using the pretrained model selected above)
*   Edit the segmentations using the Selection and Draw tools
*   You can save the annotations at any time
*   When you're happy with the annotations, you can send the data for training
*   You can then start the training
*   Get a new image and annotate as above, send for training and repeat

<font size = 4> The training will be running in the background as you annotate and send more training data. This will both generate high quality training data while building an increasingly good model.

<font size = 4> The Kaibu interface can be made fullscreen by clicking on the three dots on the top right of the cell, and select the Full screen option. Then, it needs to be minimized and maximised again.


In [None]:
# @markdown #Start interactive segmentation interface

# Restart the trainer if necessary
try:
  trainer = InteractiveTrainer.get_instance()
except Exception:
  # NOTE(review): get_instance() is assumed to raise when no trainer exists - confirm
  instance_exist = False
else:
  instance_exist = True

if instance_exist:
  print("Trainer already exists. Restarting trainer!")
  trainer.stop()


# With an empty Model_name the joined path below is Model_folder itself, and the
# rmtree call would then delete that whole folder: refuse to go further.
if not Model_name.strip():
  raise ValueError("Model_name is empty. Please provide a model name in section 3.2 before starting.")

model_dir = os.path.join(Model_folder, Model_name)
if os.path.exists(model_dir):
  print('Deleting pre-existing model folder...')
  rmtree(model_dir)

os.mkdir(model_dir)

# Configuration handed to the interactive trainer; diameter, pretrained model,
# model type and resume flag come from section 3.2
model_config = dict(type="cellpose",
                     model_dir=model_dir,
                     use_gpu=True,
                     channels=[0, 0],
                     style_on=0,
                     batch_size=1,
                     default_diameter = default_diameter,
                     pretrained_model = pretrained_model,
                     model_type = model_type,
                     resume = resume)

start_interactive_segmentation(model_config,
                               Data_folder,
                               Data_split,
                               object_name="cell",
                               scale_factor=1.0,
                               restore_test_annotation=True)


# start_interactive_segmentation(model_config,
#                                "/content/DATA TEMP",
#                                ["data_image.tif"],
#                                object_name="cell",
#                                scale_factor=1.0,
#                                restore_test_annotation=True)



## **4.2. Create masks from annotations**
---

<font size = 4> This cell will allow you to create and visualise instance segmentation masks. It will be created from the annotations made from the interface and will be saved into a folder called **`Paired training dataset`** in your data folder (**`Data_folder`**). This data can be used to train another segmentation model if necessary.

In [None]:
#@markdown ##Create and check annotated training images
paired_folder = os.path.join(Data_folder,'Paired training dataset')
if os.path.exists(paired_folder):
  rmtree(paired_folder)

os.mkdir(paired_folder)
os.mkdir(os.path.join(paired_folder,'Images'))
os.mkdir(os.path.join(paired_folder,'Masks'))


train_folder = os.path.join(Data_folder, "train")
# Keep only the sample folders that hold an annotation file: any other entry of
# the train folder would otherwise be handed to geojson_to_masks.
dir_list = [entry for entry in os.listdir(train_folder)
            if os.path.isfile(os.path.join(train_folder, entry, "annotation.json"))]
# _, ext = os.path.splitext(Data_tag)

# The loop variable is not called `dir`, so the builtin of that name stays usable
for sample_name in dir_list:
  sample_folder = os.path.join(train_folder, sample_name)
  annotation_file = os.path.join(sample_folder, "annotation.json")
  mask_dict = geojson_to_masks(annotation_file, mask_types=["labels"])
  labels = mask_dict["labels"]

  # The label image is stored next to the sample and in the paired dataset
  imsave(os.path.join(sample_folder, "label.tif"), labels)
  imsave(os.path.join(paired_folder,'Masks', sample_name+".tif"), labels)

  # Copy the raw patch (any extension) under the sample name
  for entry in os.listdir(sample_folder):
    entry_stem, this_ext = os.path.splitext(entry)
    if entry_stem == 'Raw_data':
      copyfile(os.path.join(sample_folder, entry), os.path.join(paired_folder,'Images', sample_name+this_ext))



# ------------- For display ------------
print('--------------------------------------------------------------')
@interact
def show_labels(dir=dir_list):
  """Show the image of the selected training sample next to its label image."""
  sample_folder = os.path.join(Data_folder, "train", dir)
  plt.figure(figsize=(13,10))

  imgSource = cv2.imread(os.path.join(sample_folder, Data_tag))
  imgLabel = imread(os.path.join(sample_folder, "label.tif"))

  # (subplot position, image, colormap, title) of the two panels
  panels = (
    (121, imgSource, 'gray', 'Source image'),
    (122, imgLabel, 'nipy_spectral', 'Label'),
  )
  for position, panel_image, colormap, panel_title in panels:
    plt.subplot(position)
    plt.imshow(panel_image, cmap=colormap, interpolation='nearest')
    plt.title(panel_title)
    plt.axis('off')


## **4.3. Stop training**
---

<font size = 4> Once training has started, the training will carry on until stopped. Here, the training can be stopped. This will automatically create the final model, which can be used for Quality Control (Section 5 below) and for predictions (Section 6 below).


In [None]:
#@markdown ##Stop training

# Stop the trainer if it exists
try:
  trainer = InteractiveTrainer.get_instance()
except Exception:
  # NOTE(review): get_instance() is assumed to raise when no trainer exists - confirm
  instance_exist = False
else:
  instance_exist = True

print('-------------')
if instance_exist:
  # Report success only once stop() has returned
  trainer.stop()
  print("Trainer stopped.")
else:
  print("No trainers currently running.")



# **5. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model.  


<font size = 4>**We highly recommend to perform quality control on all newly trained models.**




In [None]:
# model name and path
#@markdown ###Do you want to assess the model you just trained ?
Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, indicate which model you want to assess:

QC_model_path = "" #@param {type:"string"}


if Use_the_current_trained_model :
  # Location of the final model inside the model folder chosen in section 3.2
  QC_model_path = Model_folder+"/"+Model_name+"/final"

# The same check and loading code serves both the current and a user-provided model
if os.path.exists(QC_model_path):
  #model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)
  model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)

  QC_model_folder = os.path.dirname(QC_model_path)
  QC_model_name = os.path.basename(QC_model_folder)
  Saving_path = QC_model_folder

  print("The "+str(QC_model_name)+" model will be evaluated")

else:
  # Both lines are shown in the warning colour, which is reset afterwards
  print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')
  print('Please make sure you provide a valid model path before proceeding further.'+bcolors.NORMAL)



# Here we load the def that perform the QC, code taken from StarDist  https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py



# Registry of matching criteria: maps a criterion name (e.g. 'iou') to the
# function that turns an overlap matrix into a score matrix
matching_criteria = dict()

def label_are_sequential(y):
    """Tell whether the non-zero labels of ``y`` are exactly 1, 2, ..., ``y.max()``."""
    labels = np.unique(y)
    found = set(labels)
    found.discard(0)  # the background label is ignored
    expected = set(range(1, 1 + labels.max()))
    return found == expected


def is_array_of_integers(y):
    """Return True when ``y`` is a numpy array whose dtype is an integer type."""
    if not isinstance(y, np.ndarray):
        return False
    return np.issubdtype(y.dtype, np.integer)


def _check_label_array(y, name=None, check_sequential=False):
    """Check that ``y`` is a valid label array and report a problem without raising.

    ``y`` is expected to be a numpy integer array whose values are
    non-negative or, when ``check_sequential`` is True, whose non-zero labels
    are sequential. A failed check prints a message naming the offending array
    (``name``, or 'labels' when no name is given); no exception is raised for
    it, so callers carry on as before.

    Returns
    -------
    True, whatever the outcome of the checks.
    """
    # The descriptive message was previously built but never shown
    message = "{label} must be an array of {integers}.".format(
        label = 'labels' if name is None else name,
        integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
    )
    if not is_array_of_integers(y):
        print(message)
    if check_sequential:
        if not label_are_sequential(y):
            print(message)
    elif not (y.min() >= 0):
        print(message)
    return True


def label_overlap(x, y, check=True):
    """Return the overlap matrix of two label images.

    Parameters
    ----------
    x, y: label arrays, passed unchanged to ``_label_overlap``.
    check: when True, both arrays go through ``_check_label_array`` (with the
        sequential-label check) and their shapes must be identical.

    Raises
    ------
    ValueError
        if ``check`` is True and the two shapes differ.
    """
    if check:
        _check_label_array(x,'x',True)
        _check_label_array(y,'y',True)
        # Explicit raise: no dependency on a `_raise` helper
        if x.shape != y.shape:
            raise ValueError("x and y must have the same shape")
    return _label_overlap(x, y)

@jit(nopython=True)
def _label_overlap(x, y):
    # Count label co-occurrences between two label images:
    # overlap[i, j] is the number of positions where x holds i and y holds j.
    x = x.ravel()
    y = y.ravel()
    # One row per value 0..x.max() and one column per value 0..y.max()
    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
    for i in range(len(x)):
        overlap[x[i],y[i]] += 1
    return overlap


def intersection_over_union(overlap):
    """Turn an overlap matrix into intersection-over-union scores.

    ``overlap[i, j]`` is the number of pixels shared by row label ``i`` and
    column label ``j``. An all-zero matrix is returned unchanged. Entries whose
    union is empty get a score of 0 instead of the NaN of a 0/0 division.
    """
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    union = n_pixels_pred + n_pixels_true - overlap
    # Divide only where the union is not empty; the other entries keep the 0 of `out`
    return np.divide(overlap, union, out=np.zeros(overlap.shape, dtype=float), where=union > 0)

matching_criteria['iou'] = intersection_over_union


def intersection_over_true(overlap):
    """Turn an overlap matrix into intersection-over-true scores.

    Each entry is divided by the sum of its row. An all-zero matrix is
    returned unchanged. Rows summing to 0 get a score of 0 instead of the NaN
    of a 0/0 division.
    """
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    # Divide only where the row holds pixels; the other entries keep the 0 of `out`
    return np.divide(overlap, n_pixels_true, out=np.zeros(overlap.shape, dtype=float), where=n_pixels_true > 0)

matching_criteria['iot'] = intersection_over_true


def intersection_over_pred(overlap):
    """Turn an overlap matrix into intersection-over-prediction scores.

    Each entry is divided by the sum of its column. An all-zero matrix is
    returned unchanged. Columns summing to 0 get a score of 0 instead of the
    NaN of a 0/0 division.
    """
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    # Divide only where the column holds pixels; the other entries keep the 0 of `out`
    return np.divide(overlap, n_pixels_pred, out=np.zeros(overlap.shape, dtype=float), where=n_pixels_pred > 0)

matching_criteria['iop'] = intersection_over_pred


def precision(tp,fp,fn):
    """Return tp/(tp+fp), or 0 when there is no true positive."""
    if tp > 0:
        return tp/(tp+fp)
    return 0
def recall(tp,fp,fn):
    """Return tp/(tp+fn), or 0 when there is no true positive."""
    if tp > 0:
        return tp/(tp+fn)
    return 0
def accuracy(tp,fp,fn):
    """Return tp/(tp+fp+fn), or 0 when there is no true positive."""
    # also known as "average precision" (?)
    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
    if tp > 0:
        return tp/(tp+fp+fn)
    return 0
def f1(tp,fp,fn):
    """Return 2*tp/(2*tp+fp+fn), or 0 when there is no true positive."""
    # also known as "dice coefficient"
    if tp > 0:
        return (2*tp)/(2*tp+fp+fn)
    return 0


def _safe_divide(x,y):
    """Return ``x/y`` when ``y`` is positive, otherwise ``0.0``."""
    if y > 0:
        return x/y
    return 0.0

def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
    """Calculate detection/instance segmentation metrics between ground truth and predicted label images.
    Currently, the following metrics are implemented:
    'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'
    Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)
    whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)
    * mean_matched_score is the mean IoUs of matched true positives
    * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of GT objects
    * panoptic_quality defined as in Eq. 1 of Kirillov et al. "Panoptic Segmentation", CVPR 2019
    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float
        threshold for matching criterion (default 0.5)
    criterion: string
        matching criterion (default IoU)
    report_matches: bool
        if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below  'thresh')
    Returns
    -------
    Matching object with different metrics as attributes
    (a tuple of Matching objects when thresh is a sequence of thresholds)
    Raises
    ------
    ValueError
        if y_true and y_pred have different shapes, or if criterion is not a key of matching_criteria
    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true,5,axis = 0)
    >>> stats = matching(y_true, y_pred)
    >>> print(stats)
    Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)
    """
    # Imported locally so that the function does not rely on an import made in another cell
    from skimage.segmentation import relabel_sequential

    _check_label_array(y_true,'y_true')
    _check_label_array(y_pred,'y_pred')
    # Explicit raises: no dependency on a `_raise` helper
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred))
    if criterion not in matching_criteria:
        raise ValueError("Matching criterion '%s' not supported." % criterion)
    if thresh is None: thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)

    # Relabel both images to 1..n; the reverse maps give back the original labels
    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:,1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        # Metrics for one threshold value
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2*n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind,pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp
        # assert tp+fp == n_pred
        # assert tp+fn == n_true

        # the score sum over all matched objects (tp)
        sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0

        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score    = _safe_divide(sum_matched_score, n_true)
        panoptic_quality   = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

        stats_dict = dict (
            criterion          = criterion,
            thresh             = thr,
            fp                 = fp,
            tp                 = tp,
            fn                 = fn,
            precision          = precision(tp,fp,fn),
            recall             = recall(tp,fp,fn),
            accuracy           = accuracy(tp,fp,fn),
            f1                 = f1(tp,fp,fn),
            n_true             = n_true,
            n_pred             = n_pred,
            mean_true_score    = mean_true_score,
            mean_matched_score = mean_matched_score,
            panoptic_quality   = panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update (
                    # int() to be json serializable
                    matched_pairs  = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),
                    matched_scores = tuple(scores[true_ind,pred_ind]),
                    matched_tps    = tuple(map(int,np.flatnonzero(match_ok))),
                )
            else:
                stats_dict.update (
                    matched_pairs  = (),
                    matched_scores = (),
                    matched_tps    = (),
                )
        return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))



def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """Compute matching metrics for two equally long sequences of label images.

    The ground-truth and predicted images are paired element-wise and handed
    to `matching_dataset_lazy`, which does the actual work; all keyword
    arguments are forwarded unchanged. See `stardist.matching.matching` for
    the meaning of the individual metrics.

    Raises a ValueError when `y_true` and `y_pred` differ in length.
    """
    if len(y_true) != len(y_pred):
        _raise(ValueError("y_true and y_pred must have the same length."))

    image_pairs = tuple(zip(y_true, y_pred))
    return matching_dataset_lazy(
        image_pairs,
        thresh=thresh,
        criterion=criterion,
        by_image=by_image,
        show_progress=show_progress,
        parallel=parallel,
    )



def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """Aggregate `matching` statistics over an iterable of label-image pairs.

    Parameters
    ----------
    y_gen : iterable of (y_true, y_pred) pairs
        Each pair is passed to `matching`.
    thresh : scalar or sequence of scalars
        Matching threshold(s); results are accumulated per threshold.
    criterion : str
        Forwarded to `matching` and stored in the result.
    by_image : bool
        If true, the ratio metrics are averaged over the images. Otherwise
        the counts are pooled over all images and the ratio metrics are
        recomputed from the pooled counts.
    show_progress : bool or int
        Falsy disables the progress bar; an integer > 1 is additionally used
        as the progress bar total.
    parallel : bool
        If true, the pairs are processed through a `ThreadPoolExecutor`.

    Returns
    -------
    A `DatasetMatching` namedtuple when `thresh` is a scalar, otherwise a
    tuple with one `DatasetMatching` per threshold.
    """
    required_fields = {
        'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh',
        'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality',
    }

    # a scalar threshold is handled as a one-element tuple and unwrapped on return
    scalar_thresh = np.isscalar(thresh)
    if scalar_thresh:
        thresh = (thresh,)

    progress_opts = {'disable': not bool(show_progress)}
    if int(show_progress) > 1:
        progress_opts['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    if parallel:
        def _match_pair(pair):
            return matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)

        with ThreadPoolExecutor() as pool:
            per_image = tuple(pool.map(_match_pair, tqdm(y_gen, **progress_opts)))
    else:
        per_image = tuple(
            matching(gt, pred, thresh=thresh, criterion=criterion, report_matches=False)
            for gt, pred in tqdm(y_gen, **progress_opts)
        )

    # accumulate results over all images for each threshold separately
    n_images, n_threshs = len(per_image), len(thresh)
    pooled = not bool(by_image)
    totals = [{} for _ in range(n_threshs)]
    for image_stats in per_image:
        for idx, single in enumerate(image_stats):
            bucket = totals[idx]
            for key, value in single._asdict().items():
                if pooled and key == 'mean_true_score':
                    # turn the per-image mean back into a sum of matched scores
                    bucket[key] = bucket.setdefault(key, 0) + value * single.n_true
                    continue
                try:
                    bucket[key] = bucket.setdefault(key, 0) + value
                except TypeError:
                    # non-numeric field (e.g. the criterion string): setdefault has
                    # already registered the key, its value is overwritten below
                    pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    averaged_fields = ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality')
    for thr, bucket in zip(thresh, totals):
        if set(bucket.keys()) != required_fields:
            _raise(ValueError("unexpected keys"))
        bucket['criterion'] = criterion
        bucket['thresh'] = thr
        bucket['by_image'] = bool(by_image)
        if pooled:
            tp, fp, fn, n_true = bucket['tp'], bucket['fp'], bucket['fn'], bucket['n_true']
            # in pooled mode this slot holds the summed matched scores (see above)
            score_sum = bucket['mean_true_score']
            matched_mean = _safe_divide(score_sum, tp)
            true_mean = _safe_divide(score_sum, n_true)
            pq = _safe_divide(score_sum, tp+fp/2+fn/2)
            bucket['precision'] = precision(tp,fp,fn)
            bucket['recall'] = recall(tp,fp,fn)
            bucket['accuracy'] = accuracy(tp,fp,fn)
            bucket['f1'] = f1(tp,fp,fn)
            bucket['mean_true_score'] = true_mean
            bucket['mean_matched_score'] = matched_mean
            bucket['panoptic_quality'] = pq
        else:
            for key in averaged_fields:
                bucket[key] /= n_images

    results = tuple(namedtuple('DatasetMatching', bucket.keys())(*bucket.values()) for bucket in totals)
    return results[0] if scalar_thresh else results



# copied from scikit-image master for now (remove when part of a release)
def relabel_sequential(label_field, offset=1):
    """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.
    This function also returns the forward map (mapping the original labels to
    the reduced labels) and the inverse map (mapping the reduced labels back
    to the original ones).
    Parameters
    ----------
    label_field : numpy array of int, arbitrary shape
        An array of labels, which must be non-negative integers.
    offset : int, optional
        The return labels will start at `offset`, which should be
        strictly positive.
    Returns
    -------
    relabeled : numpy array of int, same shape as `label_field`
        The input label field with labels mapped to
        {offset, ..., number_of_labels + offset - 1}.
        The data type will be the same as `label_field`, except when
        offset + number_of_labels causes overflow of the current data type;
        in that case the smallest (unsigned) integer type that can hold the
        largest new label is used.
    forward_map : numpy array of int, shape ``(label_field.max() + 1,)``
        The map from the original label space to the returned label
        space. Can be used to re-apply the same mapping. See examples
        for usage. The data type will be the same as `relabeled`.
    inverse_map : 1D numpy array of int, of length offset + number of labels
        The map from the new label space to the original space. This
        can be used to reconstruct the original label field from the
        relabeled one. The data type will be the same as `relabeled`.
    Raises
    ------
    ValueError
        If `offset` is not strictly positive or `label_field` contains
        negative values.
    Notes
    -----
    The label 0 is assumed to denote the background and is never remapped.
    The forward map can be extremely big for some inputs, since its
    length is given by the maximum of the label field. However, in most
    situations, ``label_field.max()`` is much smaller than
    ``label_field.size``, and in these cases the forward map is
    guaranteed to be smaller than either the input or output images.
    Examples
    --------
    >>> from skimage.segmentation import relabel_sequential
    >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])
    >>> relab, fw, inv = relabel_sequential(label_field)
    >>> relab
    array([1, 1, 2, 2, 3, 5, 4])
    >>> fw
    array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])
    >>> inv
    array([ 0,  1,  5,  8, 42, 99])
    >>> (fw[label_field] == relab).all()
    True
    >>> (inv[relab] == label_field).all()
    True
    >>> relab, fw, inv = relabel_sequential(label_field, offset=5)
    >>> relab
    array([5, 5, 6, 6, 7, 9, 8])
    """
    offset = int(offset)
    if offset <= 0:
        raise ValueError("Offset must be strictly positive.")
    if np.min(label_field) < 0:
        raise ValueError("Cannot relabel array that contains negative values.")
    max_label = int(label_field.max())  # plain int so it can be used as an array size
    if not np.issubdtype(label_field.dtype, np.integer):
        # non-integer input: cast to the smallest integer type holding the largest label
        label_field = label_field.astype(np.min_scalar_type(max_label))
    labels = np.unique(label_field)
    labels0 = labels[labels != 0]  # background (0) is never remapped
    new_max_label = offset - 1 + len(labels0)
    new_labels0 = np.arange(offset, new_max_label + 1)
    # Keep the input dtype unless the largest new label does not fit into it.
    # The check is made on the value range, not the item size: a signed input
    # (e.g. int8) can need the same-width unsigned type (uint8) to avoid
    # silently wrapping around to negative labels.
    output_type = label_field.dtype
    if new_max_label > np.iinfo(output_type).max:
        output_type = np.min_scalar_type(new_max_label)
    forward_map = np.zeros(max_label + 1, dtype=output_type)
    forward_map[labels0] = new_labels0
    inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
    inverse_map[offset:] = labels0
    relabeled = forward_map[label_field]
    return relabeled, forward_map, inverse_map





## **5.1. Inspection of the loss function**
---

<font size = 4>It is good practice to evaluate the training progress by looking at the training loss over training epochs. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>During training values should decrease before reaching a minimal value which does not decrease further even after more training.

<font size = 4>Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required.




In [None]:
#@markdown ##Plot loss function

# Look up the trainer of the current session (this does not stop it).
# Any ordinary failure of the lookup is taken to mean "no trainer available";
# KeyboardInterrupt/SystemExit are deliberately not swallowed.
instance_exist = True
try:
  trainer = InteractiveTrainer.get_instance()
except Exception:
  instance_exist = False

if Use_the_current_trained_model and instance_exist:

  # One report per training step; each report carries its 'loss' value.
  reports = trainer.get_reports()
  loss = [report['loss'] for report in reports]

  plt.figure(figsize=(15,10))

  # Top panel: linear y-axis
  plt.subplot(2,1,1)
  plt.plot(loss, label='Training loss')
  plt.title('Training loss vs. epoch number (linear scale)')
  plt.ylabel('Loss')
  plt.xlabel('Epoch number')

  # Bottom panel: logarithmic y-axis, makes late, small improvements visible
  plt.subplot(2,1,2)
  plt.semilogy(loss, label='Training loss')
  plt.title('Training loss vs. epoch number (log scale)')
  plt.ylabel('Loss')
  plt.xlabel('Epoch number')
  # plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)
  plt.show()

else:
  print(bcolors.WARNING+"Loss curves can currently only be obtained from a currently trained model."+bcolors.NORMAL)

## **5.2. Error mapping and quality metrics estimation**
---
<font size = 4>This section will calculate the Intersection over Union score for all the images provided in the `Source_QC_folder` and `Target_QC_folder`. The result for one of the images will also be displayed.

<font size = 4>The **Intersection over Union** (IoU) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. 

<font size = 4>Here, the IoU is both calculated over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.

<font size = 4>“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. 

<font size = 4>When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as  “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.

<font size = 4>The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).

<font size = 4>For more information about the other metrics displayed, please consult the SI of the paper describing ZeroCostDL4Mic.

<font size = 4> The results can be found in the "*Quality Control*" folder which is located inside your `model_folder`.

In [None]:
#@markdown ##Choose the folders that contain your Quality Control dataset

Source_QC_folder = "" #@param{type:"string"}
Target_QC_folder = "" #@param{type:"string"}



#@markdown ### Segmentation parameters:
Object_diameter =  0#@param {type:"number"}

Flow_threshold = 0.4 #@param {type:"slider", min:0.1, max:1.1, step:0.1}
Cell_probability_threshold=0 #@param {type:"slider", min:-6, max:6, step:1}

# Compare by value: `is 0` tests object identity, which is not a reliable way
# to compare numbers (a value of 0.0 would never be detected).
if Object_diameter == 0:
  Object_diameter = None
  print("The cell size will be estimated automatically for each image")


# Only regular files are treated as QC images; sub-folders are ignored by the
# prediction loop, the metrics loop and the display widget alike.
QC_file_names = [f for f in os.listdir(Source_QC_folder) if not os.path.isdir(os.path.join(Source_QC_folder, f))]

# Find the number of channel in the input image

random_choice = random.choice(QC_file_names)
x = io.imread(os.path.join(Source_QC_folder, random_choice))
n_channel = 1 if x.ndim == 2 else x.shape[-1]


channels=[0,0]
QC_model_folder = os.path.join(Model_folder,Model_name)
QC_model_path = os.path.join(QC_model_folder, 'final')
QC_model_name = Model_name

QC_folder = os.path.join(QC_model_folder, "Quality Control")
QC_prediction_folder = os.path.join(QC_folder, "Prediction")
QC_csv_path = os.path.join(QC_folder, "Quality_Control for "+QC_model_name+".csv")


def _qc_prediction_path(image_name):
  # Predictions are always written as "<stem>.tif", whatever the source extension.
  return os.path.join(QC_prediction_folder, os.path.splitext(image_name)[0]+".tif")


# Create the quality control folder and start from an empty prediction folder
# (makedirs also creates the parent "Quality Control" folder when missing)
if os.path.exists(QC_prediction_folder):
  rmtree(QC_prediction_folder)
os.makedirs(QC_prediction_folder)


# NOTE(review): this cell loads the model with torch=True while the prediction
# cell in section 6.1 uses torch=False - confirm which backend the trained model needs.
model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)


# Here we need to make predictions

for name in QC_file_names:

  print("Performing prediction on: "+name)
  image = io.imread(os.path.join(Source_QC_folder, name))

  masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)

  # Write with a full path instead of os.chdir(): changing the working
  # directory would break relative source paths on the next iteration.
  imsave(_qc_prediction_path(name), masks, compress=ZIP_DEFLATED)

# Here we start testing the differences between GT and predicted masks

with open(QC_csv_path, "w", newline='') as csv_file:
  writer = csv.writer(csv_file, delimiter=",")
  writer.writerow(["image","Prediction v. GT Intersection over Union", "false positive", "true positive", "false negative", "precision", "recall", "accuracy", "f1 score", "n_true", "n_pred", "mean_true_score", "mean_matched_score", "panoptic_quality"])

  for n in QC_file_names:
    print('Running QC on: '+n)
    test_prediction = io.imread(_qc_prediction_path(n))
    # the ground truth is expected under the same file name as the source image
    test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))

    # Calculate the matching (with IoU threshold `thresh`) and all metrics
    stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)

    # Whole-image Intersection over Union on the foreground (label > 0) masks
    ground_truth_foreground = test_ground_truth_image > 0
    prediction_foreground = test_prediction > 0
    intersection = np.logical_and(ground_truth_foreground, prediction_foreground)
    union = np.logical_or(ground_truth_foreground, prediction_foreground)
    iou_score =  np.sum(intersection) / np.sum(union)
    writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])



df = pd.read_csv (QC_csv_path)
print(tabulate(df, headers='keys', tablefmt='psql'))



# ------------- For display ------------
print('--------------------------------------------------------------')
@interact
def show_QC_results(file = QC_file_names):
  # Show input, ground truth, prediction and their overlay for the selected
  # QC image, and save the figure as QC_example_data.png.

  plt.figure(figsize=(25,5))
  if n_channel > 1:
    source_image = io.imread(os.path.join(Source_QC_folder, file))
  else:
    source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)

  target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)
  prediction = io.imread(_qc_prediction_path(file), as_gray = True)

  # Binary (0 / 255) versions of the label images for display
  target_image_mask = np.empty_like(target_image)
  target_image_mask[target_image > 0] = 255
  target_image_mask[target_image == 0] = 0

  prediction_mask = np.empty_like(prediction)
  prediction_mask[prediction > 0] = 255
  prediction_mask[prediction == 0] = 0

  intersection = np.logical_and(target_image_mask, prediction_mask)
  union = np.logical_or(target_image_mask, prediction_mask)
  iou_score =  np.sum(intersection) / np.sum(union)

  norm = simple_norm(source_image, percent = 99)

  #Input
  plt.subplot(1,4,1)
  plt.axis('off')
  if n_channel > 1:
    plt.imshow(source_image)
  else:
    plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')
  plt.title('Input')

  #Ground-truth
  plt.subplot(1,4,2)
  plt.axis('off')
  plt.imshow(target_image_mask, aspect='equal', cmap='Greens')
  plt.title('Ground Truth')

  #Prediction
  plt.subplot(1,4,3)
  plt.axis('off')
  plt.imshow(prediction_mask, aspect='equal', cmap='Purples')
  plt.title('Prediction')

  #Overlay
  plt.subplot(1,4,4)
  plt.axis('off')
  plt.imshow(target_image_mask, cmap='Greens')
  plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')
  plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));
  plt.savefig(os.path.join(QC_folder, 'QC_example_data.png'),bbox_inches='tight',pad_inches=0)


# full_QC_model_path = QC_model_folder+'/'
# qc_pdf_export()



# **6. Using the trained model**

---

<font size = 4>In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.

## **6.1 Generate prediction(s) from unseen dataset**
---

<font size = 4>The current trained model (from section 4.) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the model's name and path to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).

<font size = 4>**`Data_folder_prediction`:** This folder should contain the images that you want to predict using the network that you trained.

<font size = 4>**`Result_folder`:** This folder is where the results from the predictions will be saved.

<font size = 4>**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**

<font size = 4>**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks.  Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. **Default value: 0.0**

In [None]:
# -------------------------------------------------- 
#@markdown ###Data parameters
Data_folder_prediction = "" #@param {type:"string"}
Result_folder = "" #@param {type:"string"}

#@markdown ### Model parameters
#@markdown Do you want to use the model you just trained?
Use_the_current_trained_model = False #@param {type:"boolean"}
#@markdown Otherwise, please provide path to the model folder below
prediction_model_path = "" #@param {type:"string"}

#@markdown ### Segmentation parameters:
Object_diameter =  0#@param {type:"number"}
Flow_threshold = 0.4 #@param {type:"slider", min:0.1, max:1.1, step:0.1}
Cell_probability_threshold=0 #@param {type:"slider", min:-6, max:6, step:1}

# -------------------------------------------------- 
# TODO: allow to run on other file formats
extension_list = ['*.jpg', '*.tif', '*.png']
# prediction_image_list holds the paths that are read below; filenames_list
# holds the matching names used to build the output file names.
prediction_image_list, filenames_list = get_image_list(Data_folder_prediction, extension_list)

n_files = len(prediction_image_list)
print("Total number of files: "+str(n_files))

if Use_the_current_trained_model:
  prediction_model_path = os.path.join(Model_folder, Model_name)


model_path = os.path.join(prediction_model_path, "final")
# TODO: Check the line below for file compatibility
channels=[0,0] 
# NOTE(review): the QC cell in section 5.2 loads the model with torch=True
# while this cell uses torch=False - confirm which backend the model needs.
model = models.CellposeModel(gpu=True, pretrained_model=model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)


# A diameter of 0 means "let the model estimate the object size"
if (Object_diameter == 0):
  Object_diameter = None

# Make sure the output folder exists before the first mask is written
if Result_folder:
  os.makedirs(Result_folder, exist_ok=True)

for image_path, base_name in tqdm(zip(prediction_image_list, filenames_list), total=n_files):
  img = io.imread(image_path)
  masks, flows, styles = model.eval(img, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)
  imsave(os.path.join(Result_folder, base_name+'.tif'), masks)




# ------------- For display ------------
print('--------------------------------------------------------------')

# The entries of filenames_list are presumably extension-less stems (the masks
# above are saved as "<name>.tif"), so joining one with the data folder does not
# necessarily give the source file. Look up the exact path that was read instead.
source_path_by_name = dict(zip(filenames_list, prediction_image_list))

@interact
def show_labels(filename = filenames_list):
  # Show the selected source image next to its predicted label mask.
  plt.figure(figsize=(13,10))

  img = io.imread(source_path_by_name[filename])
  mask = io.imread(os.path.join(Result_folder, filename+'.tif'))

  plt.subplot(121)
  plt.imshow(img, cmap='gray', interpolation='nearest')
  plt.title('Source image')
  plt.axis('off')
  plt.subplot(122)
  plt.imshow(mask, cmap='nipy_spectral', interpolation='nearest')
  plt.title('Label')
  plt.axis('off');


## **6.2. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name.

# **7. Version log**

---
<font size = 4>**v1.13.1**: 

* The notebook now can take a folder containing masks for pre-annotated data. It creates the corresponding json file to be imported into the trainer.

---
<font size = 4>**v1.13**: 

* The section 1 and 2 are now swapped for better export of *requirements.txt*. 
* This version also now includes built-in version check and the version log that you're reading now. 
* This version also specifically pulls StarDist packages version 0.6.2, due to current incompatibilities with the newest versions.
* Better data input compatibilities.

---

---
# **Thank you for using Interactive segmentation - Cellpose 2D!**