<a href="https://colab.research.google.com/github/MouseLand/cellpose/blob/main/notebooks/train_Cellpose-SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Cellpose-SAM: superhuman generalization for cellular segmentation

Marius Pachitariu, Michael Rariden, Carsen Stringer

[paper](https://www.biorxiv.org/content/10.1101/2025.04.28.651001v1) | [code](https://github.com/MouseLand/cellpose)

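This notebook shows how to fine-tune Cellpose-SAM on your own 2D images and evaluate the result. The original version reads images saved on Google Drive (that cell is kept below for reference); this version loads the Sartorius cell instance segmentation dataset from Kaggle input directories.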

This notebook is adapted from the notebook by Pradeep Rajasekhar, inspired by the [ZeroCostDL4Mic notebook series](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki).

### Install Cellpose-SAM


In [None]:
!pip install git+https://www.github.com/mouseland/cellpose.git

Check for GPU access and instantiate the model (this downloads the pretrained weights on first use).

In [None]:
import numpy as np
from cellpose import models, core, io, plot
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt

io.logger_setup()  # run this to get printing of progress

# Fail early if the notebook instance has no GPU access; training below assumes one.
if not core.use_gpu():
    raise ImportError("No GPU access, change your runtime")

# Default pretrained Cellpose-SAM weights (downloaded on first use).
# To start from your own weights instead:
# models.CellposeModel(pretrained_model='/full/path/to/model')
model = models.CellposeModel(gpu=True)

Input directory with your images (if you have them; otherwise, use the sample images below):

In [None]:
# NOTE: this cell is disabled (every line is commented out). In this version the
# training images and masks are loaded from Kaggle inputs in the "Train new model"
# section. To train on your own Google Drive folder instead, uncomment this cell
# and set train_dir / masks_ext.

# # *** change to your google drive folder path ***
# train_dir = "/content/gdrive/MyDrive/PATH-TO-FILES/"
# if not Path(train_dir).exists():
#   raise FileNotFoundError("directory does not exist")

# test_dir = None # optionally you can specify a directory with test files

# # *** change to your mask extension ***
# masks_ext = "_seg.npy"
# # ^ assumes images from Cellpose GUI, if labels are tiffs, then "_masks.tif"

# # list all files
# files = [f for f in Path(train_dir).glob("*") if "_masks" not in f.name and "_flows" not in f.name and "_seg" not in f.name]

# if(len(files)==0):
#   raise FileNotFoundError("no files found, did you specify the correct folder and extension?")
# else:
#   print(f"{len(files)} files in folder:")

# for f in files:
#   print(f.name)

### Sample images (optional)

You can use our sample images instead of mounting your Google Drive.

In [None]:
from natsort import natsorted
from cellpose import utils
from pathlib import Path

# NOTE: the sample-data download below is commented out. Uncomment it to download
# and unzip the "human_in_the_loop" sample set and point train_dir / test_dir at it
# instead of using your own data.

# url = "https://drive.google.com/uc?id=1HXpLczf7TPCdI1yZY5KV3EkdWzRrgvhQ"
# utils.download_url_to_file(url, "human_in_the_loop.zip")

# !unzip human_in_the_loop

# train_dir = "human_in_the_loop/train/"
# test_dir = "human_in_the_loop/test/"

# masks_ext = "_seg.npy"


## Train new model

In [None]:
import os
from PIL import Image
from sklearn.model_selection import train_test_split

# Set paths to the Kaggle inputs.
image_dir = '/kaggle/input/sartorius-cell-instance-segmentation/train'  # folder containing xxx.png
mask_path = '/kaggle/input/all-instance-masks-npy/all_instance_masks.npy'

# .item() extracts the Python dict (image id -> instance mask) stored in the array.
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the file;
# only load mask files from a source you trust.
mask_dict = np.load(mask_path, allow_pickle=True).item()

image_arrays = []
mask_arrays = []

# Sort filenames so the loading order (and hence the seeded train/val split) is reproducible.
for filename in sorted(os.listdir(image_dir)):
    if not filename.endswith('.png'):
        continue

    image_id = os.path.splitext(filename)[0]  # 'xxx' from 'xxx.png' is the mask_dict key

    # Look up the mask before decoding the image so unmatched files cost no image I/O.
    if image_id not in mask_dict:
        print(f"⚠️ 找不到 {image_id} 的 mask，略過。")  # "no mask found for {image_id}, skipping."
        continue
    mask = mask_dict[image_id]  # per the dataset author, masks are np.uint16 — TODO confirm

    # Read the image and convert it to an RGB numpy array; `with` closes the file handle.
    with Image.open(os.path.join(image_dir, filename)) as img:
        img_np = np.array(img.convert('RGB'))

    # A mask that doesn't cover the image pixel-for-pixel would silently corrupt training.
    if mask.shape[:2] != img_np.shape[:2]:
        raise ValueError(
            f"mask for {image_id} has shape {mask.shape}, "
            f"but the image is {img_np.shape[:2]}"
        )

    image_arrays.append(img_np)
    mask_arrays.append(mask)

    if len(image_arrays) % 50 == 0:
        print(f"已處理 {len(image_arrays)} 張圖像")  # "processed N images"

if not image_arrays:
    raise FileNotFoundError(f"no .png images with matching masks found in {image_dir}")

# Hold out 20% of the images for validation (fixed seed for reproducibility).
train_images, val_images, train_masks, val_masks = train_test_split(
    image_arrays, mask_arrays, test_size=0.2, random_state=63
)

print(len(train_images), len(val_images), len(train_masks), len(val_masks))

In [None]:
from cellpose import train


model_name = "new_model"

# Training hyperparameters for this run (tune for your dataset).
n_epochs = 20
learning_rate = 1e-5
weight_decay = 0.1
batch_size = 2

# Train the weights of `model.net` on the training split. The validation split is
# passed as test data so that `test_losses` tracks held-out loss during training;
# drop test_data/test_labels to speed training up.
new_model_path, train_losses, test_losses = train.train_seg(
    model.net,
    train_data=train_images,
    train_labels=train_masks,
    test_data=val_images,
    test_labels=val_masks,
    batch_size=batch_size,
    n_epochs=n_epochs,
    min_train_masks=0,  # keep every training image, even ones with few/no masks
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    nimg_per_epoch=max(2, len(train_images)),  # images per epoch (at least 2); can change this
    model_name=model_name,
)


## Evaluate on test data (optional)

If you have held-out data, check the model's performance on it:

In [None]:
from cellpose import metrics

# Reload the fine-tuned weights saved by train_seg.
model = models.CellposeModel(gpu=True,
                             pretrained_model=new_model_path)

# Run the model on the held-out validation images created when loading the data.
# (test_data/test_labels are never defined in this version of the notebook.)
masks = model.eval(val_images, batch_size=32)[0]

# Check performance against the ground-truth labels; column 0 of `ap` is IoU 0.5.
ap = metrics.average_precision(val_masks, masks)[0]
print('')
print(f'>>> average precision at iou threshold 0.5 = {ap[:,0].mean():.3f}')


Plot a few validation images with their predicted and ground-truth masks.

In [None]:
def _as_display_rgb(img):
    """Return `img` in a layout that matplotlib's imshow can render.

    2-D arrays and channel-last arrays (H, W, 3|4) are returned unchanged; a
    single trailing channel (H, W, 1) is squeezed. Anything else is treated as
    channel-first (C, H, W): at most 3 channels are kept, zero-padded up to 3,
    and moved to the last axis.
    """
    if img.ndim == 2 or img.shape[-1] in (3, 4):
        return img
    if img.shape[-1] == 1:
        return img[..., 0]
    chw = img[:3]
    if chw.shape[0] < 3:
        pad = np.zeros((3 - chw.shape[0],) + chw.shape[1:], dtype=chw.dtype)
        chw = np.concatenate((chw, pad), axis=0)
    return chw.transpose(1, 2, 0)


N_SHOW = 5  # number of validation images to plot (one column each)
n_show = min(N_SHOW, len(val_images))

if n_show == 0:
    print("no validation images to plot")
else:
    # Rows: input image, predicted labels, ground-truth labels.
    fig, axes = plt.subplots(3, n_show, figsize=(12, 8), dpi=150, squeeze=False)
    for k in range(n_show):
        panels = (
            ('image', _as_display_rgb(val_images[k])),
            ('predicted labels', masks[k]),
            ('true labels', val_masks[k]),
        )
        for row, (title, data) in enumerate(panels):
            ax = axes[row, k]
            ax.imshow(data)
            ax.axis('off')
            if k == 0:
                ax.set_title(title)
    fig.tight_layout()
    plt.show()