In [None]:
# Import necessary packages
import numpy as np
import matplotlib.pyplot as plt

from flax.training import checkpoints
from jax import jit, random
from skimage import segmentation

from cellori.applications.cyto import data
from cellori.applications.cyto.model import CelloriCytoModel
from cellori.utils import masks, metrics
from cellpose.models import Cellpose

In [None]:
# Create models.
# Cellpose() is built with its default arguments (for the classic Cellpose
# class: the 'cyto' model on CPU — confirm for the installed version).
cellori_model = CelloriCytoModel()
cellpose_model = Cellpose()

# Load Cellori parameters. With target=None, flax returns the raw restored
# pytree. If no checkpoint file is found, flax does NOT raise: it returns the
# passed-in target unchanged (None here), which would only surface later as an
# opaque error inside cellori_model.apply. Fail fast instead.
variables = checkpoints.restore_checkpoint('cellori_model', None)
if variables is None:
    raise FileNotFoundError("No Cellori checkpoint found at 'cellori_model'")

In [None]:
# Seed the PRNG used by the test-set transforms so runs are reproducible
rng = random.PRNGKey(42)

# Load the held-out test split, then apply the seeded transforms to it
test_ds = data.load_dataset('test', use_gpu=True)
transformed_test_ds = data.transform_dataset(test_ds, rng)

In [None]:
# Create list for masks
cellori_masks = []
cellpose_masks = []

# Build the jitted Cellori forward pass once, outside the loop. Calling
# `jit(...)` inside the loop creates a new wrapper on every iteration, which
# adds overhead and can defeat JAX's compilation cache.
# static_argnums=2 marks the third positional argument (the `False` passed
# below — presumably a training-mode flag, TODO confirm) as a compile-time
# constant.
cellori_apply = jit(cellori_model.apply, static_argnums=2)

# Run models
for image in transformed_test_ds['image']:
    # Add a leading batch axis and keep only the first two entries of the last
    # axis (presumably the two input channels) for Cellori.
    grads, semantic = cellori_apply(variables, image[None, :, :, :2], False)
    # Drop the batch axis and move the last axis of the gradient output to the
    # front, as a host NumPy array, for compute_masks_dynamics.
    grads = np.array(np.moveaxis(grads[0], -1, 0))
    cellori_masks.append(masks.compute_masks_dynamics(grads, np.array(semantic[0, :, :, 0]))[0])
    # Cellpose.eval returns a tuple (masks, flows, styles, diams); the whole
    # tuple is stored and the label mask is taken later as element [0].
    # channels=[2, 1]: per Cellpose's 1-based convention, segment channel 2
    # using channel 1 as the nuclear channel.
    cellpose_masks.append(cellpose_model.eval(image, channels=[2, 1]))

In [None]:
# Generate AP curve
thresholds = np.linspace(0, 1, 101)


def _mean_ap(true_masks, pred_masks, threshold):
    """Return the mean per-image AP at one IoU threshold, ignoring NaN scores.

    Args:
        true_masks: border-cleared ground-truth label masks.
        pred_masks: border-cleared predicted label masks, paired by position
            with ``true_masks`` (extra trailing items are ignored, as with zip).
        threshold: IoU threshold forwarded to ``MaskMetrics.calculate``.

    Returns:
        ``np.nanmean`` of the per-image scores. NaN scores (presumably images
        where AP is undefined — TODO confirm against MaskMetrics) are skipped,
        exactly as the previous explicit ``~np.isnan`` filter did.
    """
    aps = [metrics.MaskMetrics(true, pred).calculate('AP', 'f1', threshold)
           for true, pred in zip(true_masks, pred_masks)]
    return np.nanmean(aps)


# clear_border does not depend on the IoU threshold, so clear every mask once
# here instead of once per threshold (101x fewer calls). This assumes
# MaskMetrics does not modify its input arrays — NOTE(review): confirm.
# np.asarray guarantees a host NumPy array for the ground truth even if the
# dataset holds device arrays (it was loaded with use_gpu=True).
true_masks_cleared = [segmentation.clear_border(np.asarray(true_mask[:, :, 0]).astype(int))
                      for true_mask in transformed_test_ds['mask']]
cellori_masks_cleared = [segmentation.clear_border(mask) for mask in cellori_masks]
# Each Cellpose entry is the full eval() tuple; element [0] is the label mask.
cellpose_masks_cleared = [segmentation.clear_border(result[0]) for result in cellpose_masks]

cellori_mean_aps = [_mean_ap(true_masks_cleared, cellori_masks_cleared, threshold)
                    for threshold in thresholds]
cellpose_mean_aps = [_mean_ap(true_masks_cleared, cellpose_masks_cleared, threshold)
                     for threshold in thresholds]

In [None]:
# Plot the mean AP of each model as a function of the IoU matching threshold
fig, ax = plt.subplots(dpi=300)
ax.plot(thresholds, cellori_mean_aps, label='Cellori')
ax.plot(thresholds, cellpose_mean_aps, label='Cellpose')
ax.set(title='Model Benchmarking', xlabel='IoU Threshold', ylabel='Average Precision')
# Trailing ';' suppresses the <matplotlib.legend.Legend> repr in the cell output
ax.legend();