# 🦒 Fine-tuning Cellpose with BioEngine ⚙️☁️

## Installation and module imports

In [1]:
try:
    # For pyodide in the browser
    import micropip

    await micropip.install(["pyodide-http", "hypha-rpc", "httpx"])

    # 2. Patch requests
    import pyodide_http

    pyodide_http.patch_all()  # Patch all libraries
except ImportError:
    # For native python with pip
    import subprocess

    subprocess.call(["pip", "install", "hypha-rpc", "kaibu-utils", "matplotlib"])

import os
from pathlib import Path

import httpx
from hypha_rpc import connect_to_server, login



In [2]:
# Base URL of the Hypha server used by every cell below.
# For this demo we use the hypha.aicell.io server.
SERVER_URL = "https://hypha.aicell.io"

### Connect to the server

In [3]:
token = await login({"server_url": SERVER_URL})

server = await connect_to_server(
    {"server_url": SERVER_URL, "token": token}
)
workspace = server.config.workspace

print(f"Connected to workspace: {workspace}")

artifact_manager = await server.get_service("public/artifact-manager")

Please open your browser and login at https://hypha.aicell.io/public/apps/hypha-login/?key=kfVg2CU76xLLBjQL7edrbh
Connected to workspace: ws-user-github|49943582


### Access the BioEngine deployments

A public BioEngine instance is available under the service ID `bioimage-io/bioengine-apps`.

In [4]:
# Service ID of the public BioEngine instance
PUBLIC_BIOENGINE = "bioimage-io/bioengine-apps"

# Fetch the BioEngine service through the already-connected server
bioengine = await server.get_service(PUBLIC_BIOENGINE)

### Create the dataset collection (if not already created)

In [5]:
# Check if the data collection exists
collection_id = f"{workspace}/bioimageio-colab"

try:
    await artifact_manager.list(collection_id)
except Exception as e:
    expected_error = f'KeyError: "Artifact with ID \'{collection_id}\' does not exist."'
    if str(e).strip().endswith(expected_error):
        print(f"Collection '{collection_id}' does not exist. Creating it.")

    collection_manifest = {
        "name": "BioImage.IO Colab",
        "description": "A collection of annotated images from BioImage.IO Colab.",
    }
    collection = await artifact_manager.create(
        alias=collection_id,
        type="collection",
        manifest=collection_manifest,
        config={"permissions": {"*": "r", "@": "r+"}}
    )
    print(f"BioImage.IO Colab data collection created with ID: {collection.id}")

### Upload the annotated dataset to the collection

In [6]:
# Load the dataset content
data_path = Path.cwd().parent / "data" / "hpa_demo" / "data.zip"
dataset_content = data_path.read_bytes()

# Create or update the dataset artifact
dataset_manifest = {
    "name": "HPA Demo",
    "description": "An annotated dataset for Cellpose finetuning",
    "type": "data",
}
data_artifact_alias = "hpa-demo"

try:
    # Edit the existing deployment and stage it for review
    artifact = await artifact_manager.edit(
        artifact_id=f"{workspace}/{data_artifact_alias}",
        manifest=dataset_manifest,
        type=dataset_manifest["type"],
        version="stage",
    )
except:
    # If the artifact does not exist, create it
    artifact = await artifact_manager.create(
        alias=data_artifact_alias,
        parent_id=collection_id,
        manifest=dataset_manifest,
        type=dataset_manifest["type"],
        version="stage",
    )
    print(f"Artifact created with ID: {artifact.id}")

# Upload manifest.yaml
manifest_url = await artifact_manager.put_file(artifact.id, file_path="manifest.yaml")
async with httpx.AsyncClient(timeout=30) as client:
    response = await client.put(manifest_url, data=dataset_manifest)
    response.raise_for_status()
    print(f"Uploaded manifest.yaml to artifact")

# Upload the dataset content as a zip file
data_url = await artifact_manager.put_file(artifact.id, file_path="data.zip")
async with httpx.AsyncClient(timeout=30) as client:
    response = await client.put(data_url, data=dataset_content)
    response.raise_for_status()
    print(f"Uploaded data.zip to artifact")

# Commit the artifact
await artifact_manager.commit(
    artifact_id=artifact.id,
    version="new",
)
print(f"Committed artifact with ID: {artifact.id}")

Uploaded manifest.yaml to artifact
Uploaded data.zip to artifact
Committed artifact with ID: ws-user-github|49943582/hpa-demo


### Prepare the data for training

Create an artifact that will hold the fine-tuned Cellpose model.

In [12]:
model_manifest = {
    "name": "Finetuned Cellpose model",
    "description": "Finetuned model for Cellpose cyto3",
    "type": "model",
}
model_artifact_alias = "cellpose-cyto3-hpa-finetuned"

try:
    model_artifact = await artifact_manager.create(
        alias=model_artifact_alias,
        parent_id=collection_id,
        manifest=model_manifest,
        type=model_manifest["type"],
        version="stage",
    )
except:
    model_artifact_id = f"{workspace}/{model_artifact_alias}"
    answer = input(
        f"Artifact {model_artifact_id} already exists. Do you want to overwrite it? (y/n): "
    )
    if answer.lower() != "y":
        raise RuntimeError(
            f"Artifact {model_artifact_id} already exists and will not be overwritten."
        )

    # Overwrite the existing artifact
    model_artifact = await artifact_manager.edit(
        artifact_id=f"{workspace}/{model_artifact_alias}",
        manifest=model_manifest,
        type=model_manifest["type"],
        version="stage",
    )

In [13]:
# Create presigned URLs for data download and model upload
data_download_url = await artifact_manager.get_file(
    artifact_id=f"{workspace}/{data_artifact_alias}", file_path="data.zip"
)

model_upload_url = await artifact_manager.put_file(
    model_artifact.id, file_path=model_artifact_alias.replace("-", "_")
)

data = {
    "data_download_url": data_download_url,
    "model_upload_url": model_upload_url,
    "initial_model": "cyto3",
}

### Run the Cellpose fine-tuning

In [14]:
result = await bioengine.bioimage_io_cellpose_finetuning.train(data=data)
await artifact_manager.commit(artifact_id=model_artifact.id)
print(f"Committed artifact with ID: {model_artifact.id}")

print(f"Average precision at iou threshold 0.5: {result['final_average_precision']['0.5']:.3f}")

Committed artifact with ID: ws-user-github|49943582/cellpose-cyto3-hpa-finetuned
Average precision at iou threshold 0.5: 0.534


In [None]:
import tempfile
import httpx
import zipfile
from pathlib import Path

# Create a temporary directory to save the downloaded file
data_dir = tempfile.mkdtemp()

# Define the path to save the downloaded zip file
zip_file_path = Path(data_dir) / "data.zip"

# Download the zip file
download_url = await artifact_manager.get_file(
    artifact_id="hpa-demo", file_path="data.zip"
)

async with httpx.AsyncClient(timeout=30) as client:
    response = await client.get(download_url)
    response.raise_for_status()
    with open(zip_file_path, "wb") as f:
        f.write(response.content)

# Unzip the downloaded file
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
    zip_ref.extractall(data_dir)

print(f"Data downloaded and extracted to: {data_dir}")

In [None]:
import tempfile
from pathlib import Path
import numpy as np
from cellpose import models, train, io, metrics


image_dir = Path(data_dir) / "hpa_demo"
annotations_dir = image_dir / "annotations"

# Every annotation mask found on disk
annotation_files = list(annotations_dir.glob("*.tif"))

# Pair each mask with its image: the image name is the part of the mask file
# name that precedes "_mask_", with a ".tif" suffix, located in image_dir.
image_annotation_pairs = [
    (image_dir / f"{mask_path.name.split('_mask_')[0]}.tif", mask_path)
    for mask_path in annotation_files
]

# Print the number of annotations
print(f"Number of annotations: {len(image_annotation_pairs)}")

In [None]:
import matplotlib.pyplot as plt
from tifffile import imread


# Five distinct samples are drawn below, so at least five pairs are required
assert len(image_annotation_pairs) >= 5

# Pick five random pairs without repetition
choices = np.random.choice(len(image_annotation_pairs), 5, replace=False)
plt.figure(figsize=(15, 6))

# Top row: images, bottom row: the matching annotation masks
for column, pair_index in enumerate(choices, start=1):
    image_path, mask_path = image_annotation_pairs[pair_index]

    plt.subplot(2, 5, column)
    img = imread(image_path)
    # Reorder axes (0, 1, 2) -> (1, 2, 0) before display
    plt.imshow(img.transpose(1, 2, 0))
    plt.title(f"{image_path.stem}")
    plt.axis("off")

    plt.subplot(2, 5, column + 5)
    mask = imread(mask_path)
    plt.imshow(mask)
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# One index per (image, annotation) pair
all_indices = np.arange(len(image_annotation_pairs))

# Fraction of the pairs used for training (the rest is the test split)
train_ratio = 0.8
train_size = int(len(all_indices) * train_ratio)

# Shuffle in place, then cut at train_size
np.random.shuffle(all_indices)
train_indices, test_indices = all_indices[:train_size], all_indices[train_size:]

# Materialize the selected pairs, then separate images from label masks
train_pairs = [image_annotation_pairs[idx] for idx in train_indices]
test_pairs = [image_annotation_pairs[idx] for idx in test_indices]

train_files = [image_path for image_path, _ in train_pairs]
train_labels_files = [label_path for _, label_path in train_pairs]
test_files = [image_path for image_path, _ in test_pairs]
test_labels_files = [label_path for _, label_path in test_pairs]

In [None]:
# Pretrained model to start from. Options:
# ["cyto", "cyto3", "nuclei", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "None"]
initial_model = "cyto3"
output_model_name = "CP_HPA"

# Channel name -> channel index lookup
channels_lut = dict(Grayscale=0, Red=1, Green=2, Blue=3)

# First entry: channel to use for training; second: second training channel (if applicable)
channels = [channels_lut[channel_name] for channel_name in ("Grayscale", "Grayscale")]

# Optimization hyperparameters
n_epochs = 10
learning_rate = 1e-6
weight_decay = 1e-4

# Temporary directory passed to training as its save location
save_path = tempfile.mkdtemp()

In [None]:
# Enable Cellpose logging so that training progress is visible across epochs
logger = io.logger_setup()

# Cellpose model (without size model), initialized from the chosen weights
model = models.CellposeModel(gpu=True, model_type=initial_model)

# Everything except the network itself is collected here for readability
training_options = dict(
    train_files=train_files,
    train_labels_files=train_labels_files,
    test_files=test_files,
    test_labels_files=test_labels_files,
    channels=channels,
    save_path=save_path,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    SGD=True,
    nimg_per_epoch=1,
    model_name=output_model_name,
    min_train_masks=1,
)

new_model_path = train.train_seg(model.net, **training_options)

In [None]:
# Display the first element of the value returned by train.train_seg
# (presumably the path of the saved model; confirm against the installed cellpose version)
new_model_path[0]

In [None]:
# get files (during training, test_data is transformed so we will load it again)
test_data = [imread(image_path) for image_path in test_files[:2]]
test_labels = [imread(label_path) for label_path in test_labels_files[:2]]

# diameter of labels in training images
# NOTE(review): this reads the value held by the model wrapper; confirm it
# reflects the diameter learned during fine-tuning.
diam_labels = model.diam_labels.item()

# User-selectable diameter: leave at 0 to use the model diameter.
# The chosen value is now actually passed to eval; previously it was computed
# and then ignored, so a non-zero override had no effect.
diameter = 0
diameter = diam_labels if diameter == 0 else diameter

# run model on test images
masks = model.eval(test_data, channels=channels, diameter=diameter)[0]

# check performance using ground truth labels
ap = metrics.average_precision(test_labels, masks)[0]
print("")
print(f">>> average precision at iou threshold 0.5 = {ap[:,0].mean():.3f}")

In [None]:
# Re-score the predictions at three IoU thresholds; only element [0] of the result is kept
precision_results = metrics.average_precision(
    test_labels,
    masks,
    threshold=[0.5, 0.75, 0.9],
)
ap = precision_results[0]

In [None]:
# Mean of `ap` along axis 0, keyed by IoU threshold
dict(zip([0.5, 0.75, 0.9], ap.mean(axis=0)))

In [None]:
plt.figure(figsize=(9, 6))

# Column titles, shown on the first row only
panel_titles = ("Image", "Predicted Labels", "True Labels")

for i in range(2):  # Two rows, one per test sample
    # Left to right: input image, predicted labels, true labels
    panels = (test_data[i].transpose(1, 2, 0), masks[i], test_labels[i])

    for column, (panel, title) in enumerate(zip(panels, panel_titles), start=1):
        plt.subplot(2, 3, i * 3 + column)
        plt.imshow(panel)
        plt.axis("off")
        if i == 0:
            plt.title(title)

plt.tight_layout()
plt.show()

In [None]:
import os

import shutil

# save_path is a directory created with tempfile.mkdtemp(); os.remove() only
# deletes files and raises on a directory, so remove the whole tree instead.
shutil.rmtree(save_path)