This notebook runs both cyto3 and Cellpose-SAM, so it uses cellpose==3.1.1.2 with the Cellpose-SAM network patched in.

(See benchmark_all_sam.ipynb for an example environment setup.)

In [None]:
import numpy as np
from cellpose import io, metrics, models, utils, transforms, denoise
import time
from tqdm import trange
from pathlib import Path
from natsort import natsorted
import tifffile
import matplotlib.pyplot as plt
import benchmarks

files, imgs, masks_true = benchmarks.load_dataset("cyto2")
# Ground-truth cell diameter per image: the first value returned by
# utils.diameters for each true mask. Stored as an ndarray (not a list) so it
# supports elementwise arithmetic such as scaling by a float; a list would
# raise TypeError on `list * float`.
diam_true = np.array([utils.diameters(m)[0] for m in masks_true])

## Size invariance

In [None]:
#from cellSAM import cellsam_pipeline
from cellSAM import segment_cellular_image
from cellpose import resnet_torch
import cv2
import torch
from train_subsets import TransformerMP
import warnings

io.logger_setup()

device = torch.device("cuda")

# Cellpose-SAM network (transformer backbone) swapped into a Cellpose v3 model shell.
ps = 8
backbone = "vit_l"
net = TransformerMP(ps=ps, backbone=backbone).to(device)
net.load_model("models/cpsam8_0_2100_8_402175188", strict=False, multigpu=False)

model = models.CellposeModel(gpu=True, nchan=3)
net.eval()
model.net = net

cp_model = models.CellposeModel(gpu=True, model_type="cyto3")


def _run_cellsam(img, bsize):
    """Segment one image with CellSAM at a fixed ``bsize`` x ``bsize`` input.

    ``img`` is indexed as (channel, Ly, Lx). Channels are reordered to
    [0, 2, 1] (presumably CellSAM's expected order -- TODO confirm). If the
    longer side exceeds ``bsize`` the image is downscaled so that side equals
    ``bsize`` (aspect ratio kept), then it is zero-padded (centered) to a
    square. The predicted masks are cropped back and resized to the original
    (Ly, Lx) with nearest-neighbour interpolation so labels are preserved.

    If CellSAM raises, an empty label image is used instead (with a warning)
    so one failure does not abort the whole benchmark.
    """
    img = img[[0, 2, 1]].copy()
    Ly, Lx = img.shape[1:]
    # Scaled size: the longer side is clamped to bsize, aspect ratio kept.
    Lyr = bsize if Ly > bsize and Ly > Lx else Ly
    Lxr = bsize if Lx > bsize and Lx >= Ly else Lx
    Lxr = int(np.round(bsize * (Lx / Ly))) if Ly > Lx and Lyr == bsize else Lxr
    Lyr = int(np.round(bsize * (Ly / Lx))) if Lx >= Ly and Lxr == bsize else Lyr
    resized = Lyr != Ly or Lxr != Lx
    if resized:
        # cv2.resize takes channels-last input and dsize as (width, height).
        img = cv2.resize(img.transpose(1, 2, 0), (Lxr, Lyr), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
    if Lyr < bsize or Lxr < bsize:
        padyx = [(int(np.floor((bsize - Lyr) / 2)), int(np.ceil((bsize - Lyr) / 2))),
                 (int(np.floor((bsize - Lxr) / 2)), int(np.ceil((bsize - Lxr) / 2)))]
        img = np.pad(img, ((0, 0), padyx[0], padyx[1]), mode='constant')
    else:
        padyx = [(0, 0), (0, 0)]
    try:
        masks, _, _ = segment_cellular_image(img, device='cuda')
    except Exception as err:  # best-effort: fall back to empty masks, but report it
        warnings.warn(f"CellSAM failed on a {Ly}x{Lx} image ({err!r}); using empty masks")
        masks = np.zeros((bsize, bsize), dtype="uint16")
    masks = masks[padyx[0][0]:bsize - padyx[0][1], padyx[1][0]:bsize - padyx[1][1]]
    if resized:
        masks = cv2.resize(masks, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
    return masks


# Target cell diameters (pixels) that each image is rescaled to.
sizes = [10, 15, 30, 60, 90]
# Method indices: 0 = Cellpose-SAM, 1 = cyto3, 2 = CellSAM.
# Only Cellpose-SAM is evaluated by default; add 1 and/or 2 to run the baselines.
run_methods = [0]

aps = [[], [], []]
masks_preds = []
for sz in sizes:
    # diam_true may be a plain list (list * float raises TypeError), so
    # convert to an array before scaling.
    diameters = np.asarray(diam_true, dtype=float) * (30. / sz)
    # Resize factor 30 / diameters[i] == sz / diam_true[i]: after resizing,
    # the ground-truth cell diameter becomes sz pixels.
    rsz = [30. / d for d in diameters]
    imgs_rsz = [transforms.resize_image(img.transpose(1, 2, 0), rsz=r).transpose(2, 0, 1)
                for img, r in zip(imgs, rsz)]
    masks_true_rsz = [transforms.resize_image(m, rsz=r, no_channels=True, interpolation=cv2.INTER_NEAREST)
                      for m, r in zip(masks_true, rsz)]

    for j in run_methods:
        if j == 0:
            masks_pred, flows, styles = model.eval(imgs_rsz, diameter=30., channels=None, augment=False,
                                                   bsize=256, tile_overlap=0.1, batch_size=64,
                                                   flow_threshold=0.4, cellprob_threshold=0, normalize=False)
        elif j == 1:
            masks_pred, flows, styles = cp_model.eval(imgs_rsz, diameter=30., channels=[2, 3],
                                                      bsize=224, tile_overlap=0.5, batch_size=64, augment=True,
                                                      flow_threshold=0.4, cellprob_threshold=0, normalize=False)
        elif sz != 90:
            masks_pred = [_run_cellsam(imgs_rsz[i], bsize=1024) for i in trange(len(imgs_rsz))]
        else:
            # CellSAM is not evaluated at the largest size. Score shape-matched
            # empty placeholders (instead of stale masks from another method
            # or size) and NaN the result below.
            masks_pred = [np.zeros(m.shape, dtype="uint16") for m in masks_true_rsz]

        ap, tp, fp, fn = metrics.average_precision(masks_true_rsz, masks_pred)
        if sz == 90 and j == 2:
            ap *= np.nan
        else:
            print(ap.mean(axis=0))
        aps[j].append(ap)
        if j == 0:
            masks_preds.append(masks_pred)

In [None]:
import matplotlib.pyplot as plt
# Keep only methods that were actually evaluated: the size loop may run a
# subset, leaving empty lists that would make np.array ragged (ValueError).
aps = np.array([ap_method for ap_method in aps if len(ap_method) > 0])
# aps is indexed [method, size, image, IoU threshold]; plot the mean AP at the
# first IoU threshold versus target cell size, one line per method.
plt.plot(aps[:, :, :, 0].mean(axis=-1).T)
size_labels = ["10", "15", "30", "60", "90"]
ax = plt.gca()
ax.set_xticks(np.arange(len(size_labels)))
ax.set_xticklabels(size_labels)

In [None]:
# Persist size-invariance AP scores and Cellpose-SAM predictions (pickled dict).
np.save("size_invariance.npy", dict(aps=aps, masks_preds=masks_preds))

## Noise invariance

In [None]:
from cellSAM import segment_cellular_image
from cellpose import resnet_torch
import cv2
import torch
import importlib  # needed for importlib.reload(denoise) below
import warnings

io.logger_setup()

device = torch.device("cuda")

# Cellpose-SAM network (transformer backbone) swapped into a Cellpose v3 model shell.
ps = 8
backbone = "vit_l"
net = TransformerMP(ps=ps, backbone=backbone).to(device)
net.load_model("models/cpsam8_0_2100_8_402175188", strict=False, multigpu=False)

model = models.CellposeModel(gpu=True, nchan=3)
net.eval()
model.net = net

cp_model = models.CellposeModel(gpu=True, model_type="cyto3")


def _run_cellsam(img, bsize):
    """Segment one image with CellSAM at a fixed ``bsize`` x ``bsize`` input.

    ``img`` is indexed as (channel, Ly, Lx). Channels are reordered to
    [0, 2, 1] (presumably CellSAM's expected order -- TODO confirm). If the
    longer side exceeds ``bsize`` the image is downscaled so that side equals
    ``bsize`` (aspect ratio kept), then it is zero-padded (centered) to a
    square. The predicted masks are cropped back and resized to the original
    (Ly, Lx) with nearest-neighbour interpolation so labels are preserved.

    If CellSAM raises, an empty label image is used instead (with a warning)
    so one failure does not abort the whole benchmark.
    """
    img = img[[0, 2, 1]].copy()
    Ly, Lx = img.shape[1:]
    # Scaled size: the longer side is clamped to bsize, aspect ratio kept.
    Lyr = bsize if Ly > bsize and Ly > Lx else Ly
    Lxr = bsize if Lx > bsize and Lx >= Ly else Lx
    Lxr = int(np.round(bsize * (Lx / Ly))) if Ly > Lx and Lyr == bsize else Lxr
    Lyr = int(np.round(bsize * (Ly / Lx))) if Lx >= Ly and Lxr == bsize else Lyr
    resized = Lyr != Ly or Lxr != Lx
    if resized:
        # cv2.resize takes channels-last input and dsize as (width, height).
        img = cv2.resize(img.transpose(1, 2, 0), (Lxr, Lyr), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)
    if Lyr < bsize or Lxr < bsize:
        padyx = [(int(np.floor((bsize - Lyr) / 2)), int(np.ceil((bsize - Lyr) / 2))),
                 (int(np.floor((bsize - Lxr) / 2)), int(np.ceil((bsize - Lxr) / 2)))]
        img = np.pad(img, ((0, 0), padyx[0], padyx[1]), mode='constant')
    else:
        padyx = [(0, 0), (0, 0)]
    try:
        masks, _, _ = segment_cellular_image(img, device='cuda')
    except Exception as err:  # best-effort: fall back to empty masks, but report it
        warnings.warn(f"CellSAM failed on a {Ly}x{Lx} image ({err!r}); using empty masks")
        masks = np.zeros((bsize, bsize), dtype="uint16")
    masks = masks[padyx[0][0]:bsize - padyx[0][1], padyx[1][0]:bsize - padyx[1][1]]
    if resized:
        masks = cv2.resize(masks, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
    return masks


def _noise_kwargs(noise_type, p):
    """Return the denoise.add_noise keyword arguments for degradation strength ``p``."""
    if noise_type == "poisson":
        return {"poisson": 1.0, "blur": 0.0, "downsample": 0.0, "pscale": p}
    if noise_type == "blur":
        return {"poisson": 1.0, "pscale": 120., "blur": 1.0, "downsample": 0.0,
                "sigma0": p, "sigma1": p}
    if noise_type == "downsample":
        return {"poisson": 0.0, "pscale": 0., "blur": 1.0, "downsample": 1.0, "ds": p,
                "sigma0": p / 2, "sigma1": p / 2}
    # "aniso": sigma1 is zeroed and iso=False.
    return {"poisson": 0.0, "pscale": 0., "blur": 1.0, "downsample": 1.0, "ds": p,
            "sigma0": p / 2, "sigma1": p / 2 * 0, "iso": False}


# Method indices: 0 = Cellpose-SAM, 1 = cyto3, 2 = cyto3 on restored images,
# 3 = CellSAM. Only Cellpose-SAM is recomputed by default; scores for the other
# methods are kept from the previously saved {noise_type}_invariance.npy.
run_methods = [0]
# Set to True to overwrite {noise_type}_invariance.npy with the updated results.
save_results = False

# Restoration-model prefix per degradation type.
nstr = {"poisson": "denoise", "blur": "deblur", "downsample": "upsample", "aniso": "aniso"}
# Degradation strengths evaluated for each noise type (insertion order = run order).
noise_params = {
    "poisson": np.array([5, 2.5, 0.5]),
    "blur": np.array([2, 4, 8]),
    "downsample": np.array([2, 5, 10]),
    "aniso": np.array([2, 6, 12]),
}

for noise_type, param in noise_params.items():
    # Resume from previously saved results; only run_methods are recomputed.
    dat = np.load(f"{noise_type}_invariance.npy", allow_pickle=True).item()
    aps = dat["aps"]  # indexed [method, strength, image, IoU threshold]
    # masks_preds only holds method-0 predictions. When method 0 is rerun,
    # start from empty lists; appending to the loaded lists would duplicate
    # entries and break the masks_preds[k][i] <-> image i alignment.
    masks_preds = [[] for _ in param] if 0 in run_methods else dat["masks_preds"]
    mstr = "cyto3" if noise_type != "aniso" else "cyto2"
    dn_model = denoise.DenoiseModel(gpu=True, model_type=f"{nstr[noise_type]}_{mstr}", chan2=True)
    print(noise_type)
    print(param)

    denoise.deterministic()
    importlib.reload(denoise)
    # Methods 2 and 3 pass resized inputs for these degradations.
    resized_input = noise_type in ("downsample", "aniso")
    for k, p in enumerate(param):
        for i in trange(len(imgs)):
            img = np.maximum(0, imgs[i].copy())
            img = denoise.add_noise(torch.from_numpy(img).unsqueeze(0),
                                    **_noise_kwargs(noise_type, p)).cpu().numpy().squeeze()

            for j in run_methods:
                if j == 0:
                    masks_pred0, flows, styles = model.eval(img, diameter=30., channels=None, augment=False,
                                                            bsize=256, tile_overlap=0.1, batch_size=64,
                                                            flow_threshold=0.4, cellprob_threshold=0, normalize=False)
                    masks_preds[k].append(masks_pred0)
                elif j == 1:
                    masks_pred0, flows, styles = cp_model.eval(img, diameter=diam_true[i], channels=[2, 3],
                                                               bsize=224, tile_overlap=0.5, batch_size=64, augment=True,
                                                               flow_threshold=0.4, cellprob_threshold=0, normalize=False)
                elif j == 2:
                    # Restore the image first; for downsample/aniso it is first
                    # resized so the true cell diameter becomes 30 px.
                    if resized_input:
                        img_rsz = transforms.resize_image(img.transpose(1, 2, 0).copy(), rsz=30. / diam_true[i]).transpose(2, 0, 1)
                    else:
                        img_rsz = img.copy()
                    img_dn = dn_model.eval(img_rsz[[1, 2]], diameter=None, channels=[1, 2], channel_axis=0)
                    masks_pred0, flows, styles = cp_model.eval(img_dn, diameter=None if resized_input else diam_true[i],
                                                               channels=[1, 2],
                                                               bsize=224, tile_overlap=0.5, batch_size=64, augment=True,
                                                               flow_threshold=0.4, cellprob_threshold=0, normalize=False)
                    # Bring masks back to the un-resized image size for scoring.
                    masks_pred0 = transforms.resize_image(masks_pred0, Ly=img.shape[1], Lx=img.shape[2],
                                                          no_channels=True, interpolation=cv2.INTER_NEAREST)
                else:
                    masks_pred0 = _run_cellsam(img, bsize=512)
                # average_precision returns (ap, tp, fp, fn); ap has one row per image.
                ap = metrics.average_precision([masks_true[i]], [masks_pred0])[0]
                aps[j, k, i] = ap[0]
        print(aps[:, k, :, 0].mean(axis=-1))

    if save_results:
        np.save(f"{noise_type}_invariance.npy", {"aps": aps, "masks_preds": masks_preds})


## Color (channel-order) invariance

In [None]:
device = torch.device("cuda")

# Cellpose-SAM network (transformer backbone) swapped into a Cellpose v3 model shell.
ps = 8
backbone = "vit_l"
net = TransformerMP(ps=ps, backbone=backbone).to(device)
net.load_model("models/cpsam8_0_2100_8_402175188", strict=False, multigpu=False)

model = models.CellposeModel(gpu=True, nchan=3)
net.eval()
model.net = net

# Channel orderings to test: three fixed permutations, plus 'random', which
# draws an independent permutation for every image.
irgb = [[0,1,2], [2,0,1], [1,2,0], 'random']
rgb_title = ['RGB', 'BRG', 'GBR', 'Random \n each']

aps = []
masks_preds = []
test_data = []
for rgb in irgb:
    # Reseed for every condition so the 'random' permutations are reproducible.
    np.random.seed(42)
    test_data_copy = imgs.copy()

    # Reorder the channel axis (axis 0) of each image.
    for j, img in enumerate(test_data_copy):
        if rgb == 'random':
            test_data_copy[j] = img[np.random.permutation(3)]
        elif rgb is not None:
            test_data_copy[j] = img[rgb]

    masks_pred, flows, styles = model.eval(test_data_copy, normalize=False, tile_overlap=0.1,
                                           bsize=256, diameter=None,
                                           augment=False, channels=None, niter=None, batch_size=64)

    aps.append(metrics.average_precision(masks_true, masks_pred)[0])
    masks_preds.append(masks_pred)
    test_data.append(test_data_copy)

aps = np.array(aps)

In [None]:
# Persist color-invariance AP scores, predictions and permuted inputs (pickled dict).
np.save("color_invariance.npy", dict(aps=aps, masks_preds=masks_preds, test_data=test_data))

In [None]:
# Mean AP at the first IoU threshold for each channel-order condition.
mean_ap_per_order = aps[:, :, 0].mean(axis=-1)
plt.plot(mean_ap_per_order)