<a href="https://colab.research.google.com/github/Pablo1990/3D-deep-segmentation-protocol/blob/main/3D_deep_segmentation_protocol.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3D deep segmentation protocol
#### By Paci and Vicente-Munuera et al., 2025 https://arxiv.org/abs/2501.19203

Here, we explain the 3D segmentation protocol for deep tissues (specifically, the *Drosophila* wing disc) with the following steps:
0.   [Installation setup](#scrollTo=IvyuR08OZfw4)
1.   [Initial segmentation: Cellpose](#scrollTo=R7Zz3cKE6UMG)
2.   [Automated corrections: Tracking cells in 3D with TrackMate in FIJI](#scrollTo=3-up0gcY_a9O)
3.   [Manual segmentation](#scrollTo=IyW0d9L-lV3M)
4.   [Refining the segmentation: Cellpose fine-tuning](#scrollTo=L7DxZhik4aQd)

This notebook was inspired by the Cellpose 2.0 notebook (https://github.com/MouseLand/cellpose) by Carsen Stringer et al. (https://mouseland.github.io/) and the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki).

# 0. Installation setup

We will first install Cellpose and other dependencies, check that the GPU is working, and mount Google Drive to access your models and images.

## Local installation

To work outside Google Colab, run the following commands in a terminal:

In [None]:
# Create an environment with python 3.10.15
conda create --name cellpose_3d python=3.10.15
# Activate that environment
conda activate cellpose_3d
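# Install cellpose with its Graphical User Interface and the other dependencies
# (the quotes keep shells such as zsh from interpreting the square brackets)
pip install "cellpose[all]==3.1.0" matplotlib==3.7.3 plotly scikit-learn gdown notebook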
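
To confirm that the environment picked up the pinned versions, a quick check like the one below can help. It is only a sketch: the helper name `installed_version` is ours, and it relies solely on the standard-library `importlib.metadata`.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Versions pinned in the install command above
expected = {"cellpose": "3.1.0", "matplotlib": "3.7.3"}
for name, wanted in expected.items():
    found = installed_version(name)
    status = "OK" if found == wanted else "MISMATCH"
    print(f"{name}: expected {wanted}, found {found} [{status}]")
```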

If your computer has a GPU, you can now connect Google Colab to it as a local runtime by following: https://research.google.com/colaboratory/local-runtimes.html

## Colab installation

Install Cellpose. In a Colab notebook, the GPU version of torch is installed by default.

In [None]:
!pip install cellpose[all]==3.1.0 matplotlib==3.7.3 plotly scikit-learn

### All is working?

Check CUDA version and that GPU is working in cellpose and import other libraries.

In [None]:
!nvcc --version
!nvidia-smi

import os, shutil
import numpy as np
import matplotlib.pyplot as plt
from cellpose import core, utils, io, models, metrics
from glob import glob

use_GPU = core.use_gpu()
yn = ['NO', 'YES']
print(f'>>> GPU activated? {yn[use_GPU]}')

zsh:1: command not found: nvcc
zsh:1: command not found: nvidia-smi
>>> GPU activated? YES


## Mount google drive

Mount your Google Drive to access all your image files, segmentations, and custom models. This also ensures that any models you train are saved to your Google Drive. If you'd like to try out the notebook without your own files, please download the sample images provided (optional step below).

In [None]:

#@markdown ###Run this cell to connect your Google Drive to Colab

#@markdown * Click on the URL.

#@markdown * Sign in your Google Account.

#@markdown * Copy the authorization code.

#@markdown * Enter the authorization code.

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available there as "gdrive".

#mounts user's Google Drive to Google Colab.

from google.colab import drive
drive.mount('/content/gdrive')




Mounted at /content/gdrive


## Download sample images (optional)

If you want to test this protocol with some sample images, run the next code. These images are described [here](https://www.ebi.ac.uk/bioimage-archive/galleries/S-BIAD843-ai.html).

In [None]:
# !rm -rf labelled_data/

In [None]:
import gdown
from natsort import natsorted

!rm -rf labelled_data/

# Download data from google drive
url = 'https://drive.google.com/uc?id=1gXmuJlNYXLxAZypwEb2dhOs0t9LxL04p'
gdown.download(url, 'labelled_data.tar.gz', quiet=False)

!tar -xzvf labelled_data.tar.gz
!rm labelled_data.tar.gz

# Copy the raw images into the working folders used in the following steps
!cp -r labelled_data/raw labelled_data/initial_segmentation/
!cp -r labelled_data/raw labelled_data/denoised_raw/
!cp -r labelled_data/raw labelled_data/improved_model/

Downloading...
From (original): https://drive.google.com/uc?id=1gXmuJlNYXLxAZypwEb2dhOs0t9LxL04p
From (redirected): https://drive.google.com/uc?id=1gXmuJlNYXLxAZypwEb2dhOs0t9LxL04p&confirm=t&uuid=85dbc97c-91c9-40cb-9470-b06a6d0a3a27
To: /Users/wei-tunghsu/Documents/GitHub/3D-deep-segmentation-protocol/labelled_data.tar.gz
100%|██████████| 46.7M/46.7M [00:29<00:00, 1.60MB/s]


x labelled_data/
x labelled_data/.DS_Store
x labelled_data/segmented/
x labelled_data/raw/
x labelled_data/raw/WD3.2_21-03_WT_MP.tif
x labelled_data/raw/WD1_15-02_WT_confocalonly.tif
x labelled_data/raw/WD2.1_21-02_WT_confocalonly.tif
x labelled_data/raw/WD1.1_17-03_WT_MP.tif
x labelled_data/segmented/WD1.1_17-03_WT_MP_segmented.tif
x labelled_data/segmented/WD3.2_21-03_WT_MP_segmented.tif
x labelled_data/segmented/WD2.1_21-02_WT_confocalonly_segmented.tif
x labelled_data/segmented/WD1_15-02_WT_confocalonly_segmented.tif
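
The shell commands above assume a Unix-like system. On other platforms, the same extract-and-copy step can be sketched with the standard library alone. The helper name `prepare_working_copies` and its `dest` argument are hypothetical; the folder names are the ones used in this notebook.

```python
import shutil
import tarfile
from pathlib import Path


def prepare_working_copies(archive, dest=".", root="labelled_data",
                           copies=("initial_segmentation", "denoised_raw", "improved_model")):
    """Extract `archive` into `dest` and copy `<root>/raw` into each working folder."""
    dest = Path(dest)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)  # the archive is expected to contain the `root` folder
    raw = dest / root / "raw"
    targets = [dest / root / name for name in copies]
    for target in targets:
        # Copy the raw images; existing folders are updated in place
        shutil.copytree(raw, target, dirs_exist_ok=True)
    return targets


# Example call, once the archive has been downloaded:
# prepare_working_copies("labelled_data.tar.gz")
```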


# 1. Initial segmentation: Cellpose

Cellpose is a deep learning tool that segments cells in 2D and 3D.

**We highly recommend doing the initial segmentation with the Cellpose graphical user interface (GUI).**


## (LOCAL ONLY) Graphical User Interface (GUI) with a single image

Open one of the 3D images and select the best parameters based on visual inspection.

In [None]:
!python -m cellpose --Zstack

2025-12-16 11:34:06,726 [INFO] WRITING LOG OUTPUT TO /Users/wei-tunghsu/.cellpose/run.log
2025-12-16 11:34:06,726 [INFO] 
cellpose version: 	3.1.0 
platform:       	darwin 
python version: 	3.10.15 
torch version:  	2.8.0
2025-12-16 11:34:07,260 [INFO] ** TORCH MPS version installed and working. **


## Code to run with all images

In [None]:
# model name and path

#@markdown ###Name of the pretrained model:
from cellpose import models
initial_model = "cyto3" #@param ["cyto", "cyto3","nuclei","tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "scratch"]

#@markdown ###Path to images:

input_dir = "/Users/wei-tunghsu/Documents/GitHub/3D-deep-segmentation-protocol/labelled_data/initial_segmentation" #@param {type:"string"}

#@markdown ###Channel Parameters:

Channel_to_use_for_segmentation = "Grayscale" #@param ["Grayscale", "Blue", "Green", "Red"]

# Here we match the channel to number
if Channel_to_use_for_segmentation == "Grayscale":
  chan = 0
elif Channel_to_use_for_segmentation == "Blue":
  chan = 3
elif Channel_to_use_for_segmentation == "Green":
  chan = 2
elif Channel_to_use_for_segmentation == "Red":
  chan = 1

#@markdown ### GPU (default) or CPU:

use_GPU = True #@param {type:"boolean"}

#@markdown ### Segmentation parameters:

#@markdown Diameter of cells (set to zero to use diameter from training set):
diameter =  60#@param {type:"number"}
#@markdown Threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)):
cellprob_threshold=0 #@param {type:"slider", min:-6, max:6, step:1}
#@markdown Stitch 2D masks into a 3D volume using a stitch_threshold on IOU:
stitch_threshold=0.05 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown Smooth flows with gaussian filter of this stddev
dP_smooth=0.0 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown Volumetric stacks do not always have the same sampling in XY as they do in Z
anisotropy=1.0 #@param {type:"slider", min:0, max:2, step:0.01}
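
As a worked example of the anisotropy setting: as we understand the Cellpose documentation, the value is the ratio between the Z step and the XY pixel size (for instance, 2.0 when Z is sampled half as densely as XY). The helper name and the voxel sizes below are made up for illustration and are not the acquisition settings of the sample images.

```python
def anisotropy_from_voxel_size(z_step_um, xy_pixel_um):
    """Ratio of Z spacing to XY pixel size (1.0 means isotropic voxels)."""
    if z_step_um <= 0 or xy_pixel_um <= 0:
        raise ValueError("voxel sizes must be positive")
    return z_step_um / xy_pixel_um


# Hypothetical stack: 0.5 um between planes, 0.25 um per pixel in XY
print(anisotropy_from_voxel_size(0.5, 0.25))  # 2.0
```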

In [None]:
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model {initial_model} --chan {chan} --diameter {diameter} --stitch_threshold {stitch_threshold} --dP_smooth {dP_smooth} --anisotropy {anisotropy} --cellprob_threshold {cellprob_threshold}'
else:
  run_str = f'python -m cellpose --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model {initial_model} --chan {chan} --diameter {diameter} --stitch_threshold {stitch_threshold} --dP_smooth {dP_smooth} --anisotropy {anisotropy} --cellprob_threshold {cellprob_threshold}'
print(run_str)
!$run_str

python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir /Users/wei-tunghsu/Documents/GitHub/3D-deep-segmentation-protocol/labelled_data/initial_segmentation --pretrained_model cyto3 --chan 0 --diameter 60 --stitch_threshold 0.05 --dP_smooth 0.0 --anisotropy 1.0
2025-12-16 11:43:30,251 [INFO] WRITING LOG OUTPUT TO /Users/wei-tunghsu/.cellpose/run.log
2025-12-16 11:43:30,251 [INFO] 
cellpose version: 	3.1.0 
platform:       	darwin 
python version: 	3.10.15 
torch version:  	2.8.0
2025-12-16 11:43:30,288 [INFO] ** TORCH MPS version installed and working. **
2025-12-16 11:43:30,288 [INFO] >>>> using GPU (MPS)
2025-12-16 11:43:30,291 [INFO] >>>> running cellpose on 4 images using chan_to_seg GRAY and chan (opt) NONE
2025-12-16 11:43:30,291 [INFO] ** TORCH MPS version installed and working. **
2025-12-16 11:43:30,291 [INFO] >>>> using GPU (MPS)
2025-12-16 11:43:30,291 [INFO] >> cyto3 << model set to be used
2025-12-16 11:43:30,382 [INFO] >>>> loading model /Users/wei-tunghsu/.cellp
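
The GPU and CPU branches above differ only in the `--use_gpu` flag, so the command can also be assembled from a list of arguments. The sketch below does that with a hypothetical helper, `build_cellpose_command`; every flag in it is one already used in this notebook.

```python
import shlex


def build_cellpose_command(input_dir, model, chan, diameter, use_gpu=True, **options):
    """Assemble the cellpose command-line call used above as a list of arguments."""
    cmd = ["python", "-m", "cellpose"]
    if use_gpu:
        cmd.append("--use_gpu")
    cmd += ["--save_tif", "--Zstack", "--verbose",
            "--dir", str(input_dir),
            "--pretrained_model", str(model),
            "--chan", str(chan),
            "--diameter", str(diameter)]
    # Extra options, e.g. stitch_threshold=0.05, become "--stitch_threshold 0.05"
    for name, value in options.items():
        cmd += [f"--{name}", str(value)]
    return cmd


cmd = build_cellpose_command("labelled_data/initial_segmentation", "cyto3", 0, 60,
                             stitch_threshold=0.05, cellprob_threshold=0)
print(shlex.join(cmd))
```

Passing the list to `subprocess.run(cmd, check=True)` instead of a formatted string avoids shell quoting problems when `input_dir` contains spaces.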

### Visualising images

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from cellpose import io
import numpy as np
from scipy import ndimage

def visualize_3d_sections(image, masks, segmented=True, num_sections=3):
    """
    Visualizes random Z sections of a 3D image next to a second volume.

    Args:
        image: A 3D numpy array (Z, Y, X) representing the image.
        masks: A 3D numpy array with the cell masks (or a second image).
        segmented: If True, show 'masks' as labels with a colormap;
            otherwise show it as a grayscale image.
        num_sections: The number of random sections to visualize.
    """

    z_dim = image.shape[0]

    # Draw 'num_sections' random Z indices and sort them in ascending order
    random_sections = np.sort(np.random.randint(0, z_dim, num_sections))

    # Use the same colormap for all the sections
    cmap = matplotlib.colormaps['prism']

    for z_slice in random_sections:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(image[z_slice], cmap='gray')
        plt.title(f"Image - Z Slice: {z_slice}")

        plt.subplot(1, 2, 2)
        if segmented:
            plt.imshow(masks[z_slice], cmap=cmap)
            plt.title(f"Segmentation - Z Slice: {z_slice}")
        else:
            plt.imshow(masks[z_slice], cmap='gray')
            plt.title(f"Image - Z Slice: {z_slice}")
        plt.show()

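# Get images and masks (the '_cp_masks.tif' files are written by the segmentation cell above,
# so that cell must have finished before running this one)
files = io.get_image_files(input_dir, '_cp_masks')
images = [io.imread(f) for f in files]
masks = [io.imread(os.path.splitext(f)[0] + '_cp_masks.tif') for f in files]
# Show one random section of the last image
visualize_3d_sections(images[-1], masks[-1], num_sections=1)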

100%|██████████| 105/105 [00:00<00:00, 9733.28it/s]
100%|██████████| 36/36 [00:00<00:00, 7201.21it/s]
100%|██████████| 38/38 [00:00<00:00, 10122.16it/s]
100%|██████████| 60/60 [00:00<00:00, 8313.78it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/Users/wei-tunghsu/Documents/GitHub/3D-deep-segmentation-protocol/labelled_data/initial_segmentation/WD1.1_17-03_WT_MP_cp_masks.tif'

## Segmentation evaluation metrics: biology-based
Common metrics for evaluating segmentation quality, such as IoU (Intersection over Union) and the Dice coefficient, are not always suitable for biological images. We therefore developed a more biologically relevant metric, the cell persistence score: it counts the cells whose label spans more than a minimum percentage of the z-planes of the stack. It also helped us to improve the segmentation results from our manual annotations.
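
For reference, the two overlap metrics named above can be written out on a toy example. The sketch below represents each region as a set of pixel coordinates; the helper names and the toy regions are ours.

```python
def iou(a, b):
    """Intersection over Union of two pixel sets."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def dice(a, b):
    """Dice coefficient of two pixel sets."""
    total = len(a) + len(b)
    return 2 * len(a & b) / total if total else 0.0


# Two 2x2 squares that share a strip of two pixels
pred = {(0, 0), (0, 1), (1, 0), (1, 1)}
truth = {(0, 1), (0, 2), (1, 1), (1, 2)}
print(iou(pred, truth))   # 2 shared pixels out of 6 in the union
print(dice(pred, truth))  # 2 * 2 / (4 + 4) = 0.5
```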




In [None]:
from cellpose import io
import numpy as np
from scipy import ndimage


def calculate_cell_persistence_score(mask_img, min_percentage=85):
  """
  Count the number of cells that are present in a 'min_percentage' of slices.
  Developed by Giulia Paci
  @mask_img: 3D mask image
  @min_percentage: minimum percentage of slices that a cell must be present in to be considered a good cell
  return: number of good cells and number of bad cells
  """
  z_planes = mask_img.shape[0]

  # Minimum of number of slices for a cell to be correct
  target_n_planes = (min_percentage / 100 ) * z_planes
  #print(f'Minimum number of z-planes is: {target_n_planes}')

  # Count the number of good cells
  unique_ids = np.unique(mask_img)
  count_good = 0
  count_bad = 0

  # Loop
  for cell_id in unique_ids:
    if cell_id == 0:
      continue

    # Get the voxels of the current cell
    current_img = mask_img == cell_id

    # Get the position of the voxels
    binary_img_pos = np.where(current_img)

    # Count the connected components of this label: a label split into
    # more than two separate pieces is counted as a bad cell
    _, num_objects = ndimage.label(current_img)

    if num_objects > 2:
      count_bad = count_bad + 1
      continue

    # Get only the unique Z position of the voxels
    unique_z_position = np.unique(binary_img_pos[0])

    # Count the number of slices that the cell is present in
    if len(unique_z_position) > target_n_planes:
        count_good = count_good + 1
    else:
        count_bad = count_bad + 1

  return count_good, count_bad


# Get evaluation of segmentation
files = io.get_image_files(input_dir, '_cp_masks')
for file in files:
  print(f'File name: {file}')
  mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')
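
To see what the score measures, consider a toy volume. The sketch below is a simplified version that only looks at how many z-planes each label spans and skips the connected-component check, so it is not a replacement for `calculate_cell_persistence_score`; the helper name `z_extent_per_label` and the toy data are ours.

```python
import numpy as np


def z_extent_per_label(mask_img):
    """Map each non-zero label to the number of z-planes in which it appears."""
    extents = {}
    for z_plane in mask_img:
        for label in np.unique(z_plane):
            if label != 0:
                extents[int(label)] = extents.get(int(label), 0) + 1
    return extents


# Toy stack with 10 z-planes: label 1 spans all planes, label 2 only three
toy = np.zeros((10, 4, 4), dtype=np.uint16)
toy[:, :2, :2] = 1
toy[2:5, 2:, 2:] = 2

extents = z_extent_per_label(toy)
threshold = 0.85 * toy.shape[0]  # same default as min_percentage=85
good = [label for label, n in extents.items() if n > threshold]
print(extents)  # {1: 10, 2: 3}
print(good)     # [1]
```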



### Comparison 0: Cyto3 without extras

In [None]:
!cp -r labelled_data/raw labelled_data/initial_segmentation_only_cyto3/
input_dir = "labelled_data/initial_segmentation_only_cyto3"
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter}'
else:
  run_str = f'python -m cellpose --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter}'
print(run_str)
!$run_str

# Get evaluation of segmentation
files = io.get_image_files(input_dir, '_cp_masks')
for file in files:
  print(f'File name: {file}')
  img = io.imread(file)
  mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')

### Comparison 1: Stitch threshold 0.5

In [None]:
!cp -r labelled_data/raw labelled_data/initial_segmentation_stitch_0_5/
input_dir = "labelled_data/initial_segmentation_stitch_0_5"
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter} --stitch_threshold 0.5'
else:
  run_str = f'python -m cellpose --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter} --stitch_threshold 0.5'
print(run_str)
!$run_str

# Get evaluation of segmentation
files = io.get_image_files(input_dir, '_cp_masks')
for file in files:
  print(f'File name: {file}')
  img = io.imread(file)
  mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')


### Comparison 2: Final best

In [None]:
!cp -r labelled_data/raw labelled_data/initial_segmentation_stitch_0_0_5/
input_dir = "labelled_data/initial_segmentation_stitch_0_0_5"
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter} --stitch_threshold 0.05'
else:
  run_str = f'python -m cellpose --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model cyto3 --chan 0 --diameter {diameter} --stitch_threshold 0.05'

print(run_str)
!$run_str

# Get evaluation of segmentation
files = io.get_image_files(input_dir, '_cp_masks')
for file in files:
  print(f'File name: {file}')
  img = io.imread(file)
  mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')

### Comparing Cell Persistence Score, DICE and Intersection Over Union


In [None]:
import numpy as np
import os

def calculate_metrics(pred, gt, pred_id, gt_id):
    pred_region = (pred == pred_id)
    gt_region = (gt == gt_id)
    pred_region_num_pixels = len(np.where(pred_region)[0])
    gt_region_num_pixels = len(np.where(gt_region)[0])
    intersection = len(np.where(pred_region & gt_region)[0])
    union = len(np.where(pred_region | gt_region)[0])
    dice = (2 * intersection) / (pred_region_num_pixels + gt_region_num_pixels) if (pred_region_num_pixels + gt_region_num_pixels) > 0 else 0
    iou = intersection / union if union > 0 else 0
    return dice, iou

def evaluate_3d_instances(pred, gt, iou_thresh=0.5):

    pred_ids = set(np.unique(pred)) - {0}
    gt_ids = set(np.unique(gt)) - {0}

    tp, fp, fn = 0, 0, 0
    matched_pred = set()
    instance_ious = []
    instance_dices = []

    for gt_id in gt_ids:
      best_iou, best_dice, best_pred = 0, 0, None
      for pred_id in pred_ids:
          dice, iou = calculate_metrics(pred, gt, pred_id, gt_id)
          if iou > best_iou:
              best_iou, best_dice, best_pred = iou, dice, pred_id

      instance_ious.append(best_iou)
      instance_dices.append(best_dice)

      if best_iou >= iou_thresh and best_pred not in matched_pred:
          tp += 1
          matched_pred.add(best_pred)

    fp = len(pred_ids) - len(matched_pred)
    fn = len(gt_ids) - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    iou = np.mean(instance_ious) if instance_ious else 0
    dice = np.mean(instance_dices) if instance_dices else 0

    return precision, recall, f1, iou, dice

# Plot with the following folders
input_folders = ['labelled_data/initial_segmentation_only_cyto3', 'labelled_data/initial_segmentation_stitch_0_5', 'labelled_data/initial_segmentation_stitch_0_0_5'] #'labelled_data/manual_corrections_2d', 'labelled_data/trackmate',
ground_truth_folder = 'labelled_data/segmented'

for input_folder in input_folders:
  files = io.get_image_files(input_folder, '_cp_masks')
  gt_files = io.get_image_files(ground_truth_folder, '.tif')
  for file, gt_file in zip(files, gt_files):
      print(f'File name: {file}')
      print(f'GT file name: {gt_file}')
      mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
      good_cells, bad_cells = calculate_cell_persistence_score(mask)
      print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')
      precision, recall, f1, iou, dice = evaluate_3d_instances(mask, io.imread(gt_file))
      print(f'Precision: {precision}, Recall: {recall}, F1: {f1}, IOU: {iou}, DICE: {dice}')
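
The nested loops above rescan the full volume for every pair of labels, which can be slow on large stacks. One possible alternative, sketched here under the assumption that both volumes have the same shape, counts all label overlaps in a single pass. The function name `best_iou_per_gt` is ours, and the sketch only covers the IoU matching step.

```python
import numpy as np


def best_iou_per_gt(pred, gt):
    """Return {gt_id: (best_pred_id, best_iou)} from a single overlap count."""
    pred_ids, pred_counts = np.unique(pred, return_counts=True)
    gt_ids, gt_counts = np.unique(gt, return_counts=True)
    pred_size = dict(zip(pred_ids.tolist(), pred_counts.tolist()))
    gt_size = dict(zip(gt_ids.tolist(), gt_counts.tolist()))

    # Count voxels for every (gt label, pred label) pair that co-occurs
    pairs = np.stack([gt.ravel(), pred.ravel()], axis=1)
    overlaps, counts = np.unique(pairs, axis=0, return_counts=True)

    best = {g: (None, 0.0) for g in gt_size if g != 0}
    for (g, p), inter in zip(overlaps.tolist(), counts.tolist()):
        if g == 0 or p == 0:
            continue  # ignore background
        iou = inter / (gt_size[g] + pred_size[p] - inter)
        if iou > best[g][1]:
            best[g] = (p, iou)
    return best


# Toy example: gt label 1 overlaps pred label 5, gt label 2 overlaps pred label 7
gt = np.array([[1, 1, 0, 2, 2]])
pred = np.array([[5, 5, 5, 0, 7]])
print(best_iou_per_gt(pred, gt))
```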


### Ground truth

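The segmented volumes in `labelled_data/segmented` serve as ground truth. Here, we compute their cell persistence score as a reference.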

In [None]:
labelled_data = "labelled_data/segmented"
files = io.get_image_files(labelled_data, '_cp_masks')

# Get evaluation of segmentation
for file in files:
  print(f'File name: {file}')
  mask = io.imread(file)
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')

# 2. Automated corrections: Tracking cells in 3D with TrackMate in FIJI
We still had issues with our segmentation results, especially when cells were touching each other. Therefore, we used TrackMate to track cells in 3D and then stitched the cells together. This allowed us to obtain a more accurate segmentation of the cells. The process is explained in the manuscript associated with this notebook.

# 3. Manual segmentation

Manual corrections are done in napari. To install it, follow: https://napari.org/stable/tutorials/fundamentals/installation.html#napari-installation

In [None]:
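# Run these commands in a terminal: in a notebook, each `!` line runs in its
# own subshell, so `conda activate` would not carry over to the next line
conda create -y -n napari-env -c conda-forge python=3.10
conda activate napari-env
python -m pip install "napari[all]"
python -m pip install devbio-napari
napari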

# 4. Refining the segmentation: Cellpose fine-tuning
Cellpose fine-tuning lets you improve the segmentation results by training a model further on your own annotated data.

### From 3D images to 2D images

In [None]:
# Code developed by Alvaro Miranda de Larra and Pablo Vicente-Munuera
import os
import tifffile as tiff

def iterative_image_splicer(input_dir, output_dir, segmented_input=False):
  """
  Split every 3D .tif in 'input_dir' into 2D planes along XY, XZ and YZ.
  Planes from segmented volumes get the '_masks' suffix used for training.
  """

  # Create the output directory if it doesn't exist
  os.makedirs(output_dir, exist_ok=True)

  # Get a list of all .tif files in the input directory (skip hidden files)
  tif_files = [f for f in os.listdir(input_dir)
               if f.endswith('.tif') and not f.startswith('.')]

  # Suffix that distinguishes masks from raw images
  suffix = '_masks' if segmented_input else ''

  # Iterate over each 3D image
  for tif_file in tif_files:
    # Load the multi-page TIFF image
    with tiff.TiffFile(os.path.join(input_dir, tif_file)) as tif:
      image = tif.asarray()

    # Get the shape of the 3D image
    z, y, x = image.shape
    print(image.shape)

    # Base name of the output files, without the '_segmented' suffix
    base_name = os.path.splitext(tif_file)[0].replace('_segmented', '')

    # XY planes: one per Z coordinate
    for z_coord in range(z):
      tiff.imwrite(os.path.join(output_dir, f'{base_name}_XY_Z{z_coord}{suffix}.tif'), image[z_coord, :, :])

    # XZ planes: one per Y coordinate (looping over y, so non-square images also work)
    for y_coord in range(y):
      tiff.imwrite(os.path.join(output_dir, f'{base_name}_XZ_Y{y_coord}{suffix}.tif'), image[:, y_coord, :])

    # YZ planes: one per X coordinate
    for x_coord in range(x):
      tiff.imwrite(os.path.join(output_dir, f'{base_name}_YZ_X{x_coord}{suffix}.tif'), image[:, :, x_coord])

!rm -rf labelled_data_2D/
iterative_image_splicer('labelled_data/raw/', 'labelled_data_2D')
iterative_image_splicer('labelled_data/segmented/', 'labelled_data_2D', segmented_input=True)
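
As a quick illustration of the three slicing directions, the sketch below uses a small random array in `(z, y, x)` order (made-up data, not one of the sample images) and prints the shape of one plane per direction.

```python
import numpy as np

rng = np.random.default_rng(0)
volume = rng.integers(0, 255, size=(5, 6, 7))  # axes are (z, y, x)

xy_plane = volume[2, :, :]  # fixed Z -> shape (y, x)
xz_plane = volume[:, 3, :]  # fixed Y -> shape (z, x)
yz_plane = volume[:, :, 4]  # fixed X -> shape (z, y)

print(xy_plane.shape, xz_plane.shape, yz_plane.shape)  # (6, 7) (5, 7) (5, 6)
```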

Let's look at a few training images and their labels.

In [None]:
from natsort import natsorted
from glob import glob

train_files = natsorted([f for f in glob('labelled_data_2D/*.tif')
                        if '_masks' not in f])

num_images_to_show = 5

# Pick 'num_images_to_show' random training images
random_sections = np.random.randint(0, len(train_files), num_images_to_show)

# Visualize a few training and segmentation images
for f in random_sections:
    img = io.imread(train_files[f])
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img, cmap='gray')
    plt.title(f"Image - name: {train_files[f]}")

    # Get the corresponding segmentation image from the image file name
    seg = io.imread(train_files[f].replace('.tif', '_masks.tif'))
    plt.subplot(1, 2, 2)
    plt.imshow(seg, cmap='prism')
    plt.show()


### Split into training and test sets


In [None]:
!rm -rf test/
!rm -rf train/

# Divide the 2D images into training and test sets
from sklearn.model_selection import train_test_split

train_files = natsorted([f for f in glob('labelled_data_2D/*.tif')
                        if '_masks.tif' not in f])
train_files, test_files = train_test_split(train_files, test_size=0.2, random_state=42)

# Save files from train into 'train'
os.makedirs('train', exist_ok=True)
for f in train_files:
    shutil.copy(f, 'train')
    # Copy the matching '_masks' file of image f
    shutil.copy(f.replace('.tif', '_masks.tif'), 'train')

# Save files from test into 'test'
os.makedirs('test', exist_ok=True)
for f in test_files:
    shutil.copy(f, 'test')
    # Copy the matching '_masks' file of image f
    shutil.copy(f.replace('.tif', '_masks.tif'), 'test')
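
Because each training image is paired with a mask file carrying the `_masks` suffix, it can be worth checking that no image lost its mask during the copy. The sketch below uses a hypothetical helper, `find_unpaired_images`, built on the standard library only.

```python
from pathlib import Path


def find_unpaired_images(folder, mask_suffix="_masks"):
    """Return the names of the .tif images in `folder` that have no matching mask file."""
    folder = Path(folder)
    unpaired = []
    for image in sorted(folder.glob("*.tif")):
        if image.stem.endswith(mask_suffix):
            continue  # this file is itself a mask
        mask = image.with_name(f"{image.stem}{mask_suffix}.tif")
        if not mask.exists():
            unpaired.append(image.name)
    return unpaired


# Example: find_unpaired_images("train") should return an empty list
```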

### Train model on manual annotations

Skip this step if you already have a pretrained model.

Fill out the form below with the paths to your data and the parameters to start training.

### Training parameters

<font size = 4> **Paths for training, predictions and results**


<font size = 4>**`train_dir`, `test_dir`:** Paths to the folders with the training data (`train_dir`) and the test data (`test_dir`), each containing images and their masks. You can leave `test_dir` blank, but we recommend keeping some test images to check the model's performance. To find a folder's path, open Files on the left of the notebook, navigate to the folder, right-click on it, choose **Copy path**, and paste the path into the corresponding box below.

<font size = 4>**`initial_model`:** Choose a model from the cellpose [model zoo](https://cellpose.readthedocs.io/en/latest/models.html#model-zoo) to start from.

<font size = 4>**`model_name`:** Name of the new model. Once trained, it is saved in the `models` folder inside `train_dir`.

<font size = 4>**Training parameters**

<font size = 4>**`number_of_epochs`:** Number of epochs (rounds) for which the network will be trained. At least 100 epochs are recommended, but sometimes 250 epochs are necessary, particularly when training from scratch. **Default value: 100**



In [None]:
#@markdown ###Path to images and masks:

train_dir = "train" #@param {type:"string"}
test_dir = "test" #@param {type:"string"}
#Define where the patch file will be saved
base = "/content"

# model name and path
#@markdown ###Name of the pretrained model to start from and new model name:
from cellpose import models
initial_model = "cyto3" #@param ["cyto", "cyto3","nuclei","tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "scratch"]
model_name = "CP_improved_100" #@param {type:"string"}
use_GPU = True #@param {type:"boolean"}

# other parameters for training.
#@markdown ###Training Parameters:
#@markdown Number of epochs:
n_epochs =  100#@param {type:"number"}

Channel_to_use_for_training = "Grayscale" #@param ["Grayscale", "Blue", "Green", "Red"]

# @markdown ###If you have a secondary channel that can be used for training, for instance nuclei, choose it here:

Second_training_channel= "None" #@param ["None", "Blue", "Green", "Red"]


#@markdown ###Advanced Parameters

Use_Default_Advanced_Parameters = True #@param {type:"boolean"}
#@markdown ###If not, please input:
learning_rate = 0.1 #@param {type:"number"}
weight_decay = 0.0001 #@param {type:"number"}

if (Use_Default_Advanced_Parameters):
  print("Default advanced parameters enabled")
  learning_rate = 0.1
  weight_decay = 0.0001

# Check whether a model with the same name already exists
# (trained models are saved in '<train_dir>/models')
model_path = os.path.join(train_dir, 'models')
if os.path.exists(os.path.join(model_path, model_name)):
  print("!! WARNING: "+model_name+" already exists and will be overwritten in the following cell !!")

if len(test_dir) == 0:
  test_dir = None

# Here we match the channel to number
if Channel_to_use_for_training == "Grayscale":
  chan = 0
elif Channel_to_use_for_training == "Blue":
  chan = 3
elif Channel_to_use_for_training == "Green":
  chan = 2
elif Channel_to_use_for_training == "Red":
  chan = 1


if Second_training_channel == "Blue":
  chan2 = 3
elif Second_training_channel == "Green":
  chan2 = 2
elif Second_training_channel == "Red":
  chan2 = 1
elif Second_training_channel == "None":
  chan2 = 0

if initial_model=='scratch':
  initial_model = 'None'

Default advanced parameters enabled


### Train new model

Using the settings from the form above, train the model in the notebook.

In [None]:
# Build the training command; --chan2 passes the optional second channel chosen in the form (0 = none)
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --verbose --train --dir {train_dir} --pretrained_model {initial_model} --chan {chan} --chan2 {chan2} --n_epochs {n_epochs} --learning_rate {learning_rate} --weight_decay {weight_decay} --model_name_out {model_name}'
else:
  run_str = f'python -m cellpose --verbose --train --dir {train_dir} --pretrained_model {initial_model} --chan {chan} --chan2 {chan2} --n_epochs {n_epochs} --learning_rate {learning_rate} --weight_decay {weight_decay} --model_name_out {model_name}'

# Evaluate on the test set during training if one was provided
if test_dir is not None:
    run_str += f' --test_dir {test_dir}'
print(run_str)
!$run_str

## Evaluate on test data (optional)

In [None]:
# model name and path

#@markdown ###Name of the pretrained model:
from cellpose import models
initial_model = "CP_improved_100" #@param {type:"string"}

#@markdown ###Path to images:

input_dir = "labelled_data/improved_model" #@param {type:"string"}

#@markdown ###Channel Parameters:

Channel_to_use_for_segmentation = "Grayscale" #@param ["Grayscale", "Blue", "Green", "Red"]

# Here we match the channel to number
if Channel_to_use_for_segmentation == "Grayscale":
  chan = 0
elif Channel_to_use_for_segmentation == "Blue":
  chan = 3
elif Channel_to_use_for_segmentation == "Green":
  chan = 2
elif Channel_to_use_for_segmentation == "Red":
  chan = 1

#@markdown ### Segmentation parameters:

#@markdown Diameter of cells (set to zero to use diameter from training set):
diameter =  60#@param {type:"number"}
#@markdown Threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)):
cellprob_threshold=0 #@param {type:"slider", min:-6, max:6, step:1}
#@markdown Stitch 2D masks into a 3D volume using a stitch_threshold on IOU:
stitch_threshold=0.05 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown Smooth flows with gaussian filter of this stddev
dP_smooth=0.0 #@param {type:"slider", min:0, max:1, step:0.01}
#@markdown Volumetric stacks do not always have the same sampling in XY as they do in Z
anisotropy=1.0 #@param {type:"slider", min:0, max:2, step:0.01}

In [None]:
if use_GPU:
  run_str = f'python -m cellpose --use_gpu --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model train/models/{initial_model} --chan {chan} --diameter {diameter} --stitch_threshold {stitch_threshold} --dP_smooth {dP_smooth} --anisotropy {anisotropy} --cellprob_threshold {cellprob_threshold}'
else:
  run_str = f'python -m cellpose --save_tif --Zstack --verbose --dir {input_dir} --pretrained_model train/models/{initial_model} --chan {chan} --diameter {diameter} --stitch_threshold {stitch_threshold} --dP_smooth {dP_smooth} --anisotropy {anisotropy} --cellprob_threshold {cellprob_threshold}'

print(run_str)
!$run_str

In [None]:
# Get evaluation of segmentation
files = io.get_image_files(input_dir, '_cp_masks')
for file in files:
  print(f'File name: {file}')
  mask = io.imread(file.replace('.tif', '_cp_masks.tif'))
  good_cells, bad_cells = calculate_cell_persistence_score(mask)
  print(f'Number of good cells: {good_cells} and bad cells: {bad_cells}')