# Cellpose training 

**THIS NOTEBOOK REQUIRES A SEPARATE ENVIRONMENT WITH CELLPOSE INSTALLED**

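Use this notebook to retrain Cellpose on our own data. Each 3D volume is split into 2D slices along its first axis, and the model is trained on those slices.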

This notebook was inspired by the StarDist training notebook.

In [None]:
# --------- REQUIRES A SEPARATE CONDA ENVIRONMENT WITH CELLPOSE INSTALLED --------- #
from cellpose.models import CellposeModel
from glob import glob
from tifffile import imread
import numpy as np
import pathlib as pt

In [None]:
# --- Tunable parameters ---------------------------------------------------- #
# Fraction of the volumes (taken from the end of the sorted list) used for validation.
VAL_PERCENT = 0.2
# Name under which the trained model is saved.
SAVE_NAME = "fold2_cellpose.cellpose"
# Value passed to Cellpose as `diam_mean`.
CELL_MEAN_DIAM = 3.3

# --- Data location --------------------------------------------------------- #
# Images sit directly in this folder, their label volumes in its `labels/` subfolder.
path_images = pt.Path.home() / "Desktop/Code/CELLSEG_BENCHMARK/TPH2_mesospim/TRAINING"

# Both listings are sorted so that images and labels can be paired by position.
X_paths = glob(str(path_images / '*.tif'))
X_paths.sort()
Y_paths = glob(str(path_images / 'labels/*.tif'))
Y_paths.sort()

In [None]:
def convert_2d(images_array, images_names=None, dtype=np.float32):
    """Flatten a sequence of image stacks into one list of individual slices.

    Every element of ``images_array`` is iterated along its first axis and each
    resulting slice is cast to ``dtype``.

    Args:
        images_array: Iterable of stacks; each stack is iterated along its
            first axis and must yield objects with an ``astype`` method.
        images_names: Optional sequence of file names/paths, indexed in step
            with ``images_array``. When given, a name of the form
            ``"<stem>_<slice index>.tif"`` is produced for every slice.
        dtype: Data type each slice is converted to.

    Returns:
        A tuple ``(slices, slice_names)``. ``slice_names`` is ``None`` when
        ``images_names`` is ``None``, otherwise a list parallel to ``slices``.
    """
    with_names = images_names is not None
    planes = []
    plane_names = [] if with_names else None

    for stack_idx, stack in enumerate(images_array):
        for plane_idx, plane in enumerate(stack):
            planes.append(plane.astype(dtype))
            if not with_names:
                continue
            # Name each slice after its source file plus its index in the stack.
            stem = pt.Path(images_names[stack_idx]).stem
            plane_names.append(f"{stem}_{plane_idx}.tif")

    return planes, plane_names

In [None]:
# Images and labels are paired purely by their position in the sorted path
# lists, so fail early if the two listings cannot line up.
assert len(X_paths) > 0, f"No .tif images found in {path_images}"
assert len(X_paths) == len(Y_paths), (
    f"Found {len(X_paths)} images but {len(Y_paths)} label files; cannot pair them"
)

X = list(map(imread, X_paths))
Y = list(map(imread, Y_paths))

# Each image must have as many slices as its label volume; comparing only the
# total slice counts (as done at the end of this cell) can hide per-volume
# mismatches that cancel out and shift the 2D image/label pairing.
depth_mismatches = [
    pt.Path(p).stem for p, img, lbl in zip(X_paths, X, Y) if len(img) != len(lbl)
]
assert not depth_mismatches, (
    f"Image/label slice count differs for: {depth_mismatches}"
)

# split X and Y into training and validation sets before 2d conversion,
# so that all slices of a given volume end up in the same split
val_idx = int(len(X) * (1 - VAL_PERCENT))
X_trn, X_val = X[:val_idx], X[val_idx:]
Y_trn, Y_val = Y[:val_idx], Y[val_idx:]
X_trn_paths, X_val_paths = X_paths[:val_idx], X_paths[val_idx:]
Y_trn_paths, Y_val_paths = Y_paths[:val_idx], Y_paths[val_idx:]

print("Train:")
for volume_path in X_trn_paths:
    print(pt.Path(volume_path).stem)
print("*"*20)
print("Validation:")
for volume_path in X_val_paths:
    print(pt.Path(volume_path).stem)

# convert to 2d: images as float32 (convert_2d default), labels as uint16
X_trn_2d, X_trn_2d_paths = convert_2d(X_trn, X_trn_paths)
Y_trn_2d, Y_trn_2d_paths = convert_2d(Y_trn, Y_trn_paths, dtype=np.uint16)
X_val_2d, X_val_2d_paths = convert_2d(X_val, X_val_paths)
Y_val_2d, Y_val_2d_paths = convert_2d(Y_val, Y_val_paths, dtype=np.uint16)
assert len(X_trn_2d) == len(Y_trn_2d)
assert len(X_val_2d) == len(Y_val_2d)

In [None]:
# Echo the run configuration so it is kept alongside the training output.
print(
    "Parameters :",
    f"VAL_PERCENT : {VAL_PERCENT}",
    f"SAVE_NAME : {SAVE_NAME}",
    f"Path images : {path_images}",
    f"cell_mean_diam : {CELL_MEAN_DIAM}",
    sep="\n",
)

In [None]:
# Report how many 2D slices ended up in each split. The total is the sum of
# the two lengths; no need to build a concatenated list just to count it.
print('number of images (2d): %3d' % (len(X_trn_2d) + len(X_val_2d)))
print('- training:       %3d' % len(X_trn_2d))
print('- validation:     %3d' % len(X_val_2d))

In [None]:
# Instantiate the Cellpose network that the next cell trains.
# `model_type=None` and `pretrained_model=False` mean no built-in model or
# weight file is requested here.
model = CellposeModel(
    model_type=None,
    pretrained_model=False,
    diam_mean=CELL_MEAN_DIAM,  # 3.3
    gpu=True,
    # nchan=1,
)

In [None]:
# Train on the 2D training slices and evaluate on the 2D validation slices.
model.train(
    # data
    train_data=X_trn_2d,
    train_labels=Y_trn_2d,
    # train_files=X_trn_paths,
    test_data=X_val_2d,
    test_labels=Y_val_2d,
    # test_files=X_val_paths,
    channels=[0,0],
    # schedule
    n_epochs=50,
    save_every=10,
    # output
    save_path="./",
    model_name=SAVE_NAME,
)