In [1]:
import os
import os.path as osp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from cellpose.models import CellposeModel
from cellpose.io import load_train_test_data
from cellpose.metrics import average_precision

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Dataset layout: one sub-directory per split under the Sartorius data root.
BASE_PATH = "../Data/Cellpose/Sartorius"
train_dir, val_dir, test_dir = (
    osp.join(BASE_PATH, split) for split in ("train", "val", "test")
)

# Experiment layout: <EXP_PATH>/<MODEL>_<DATASET>/models is where checkpoints
# are read from (and where the evaluation CSVs are written later on).
EXP_PATH = "../Experiments"
MODEL = "cellpose"
DATASET = "sartorius"

model_dir = osp.join(EXP_PATH, f"{MODEL}_{DATASET}", "models")

In [4]:
def _epoch_order(checkpoint_name):
    """Sort key for checkpoint file names.

    Names whose last "_"-separated token is a decimal number are ordered by
    that number (the epoch); any other name is placed after them, by name.
    """
    suffix = checkpoint_name.split("_")[-1]
    if suffix.isdecimal():
        return (0, int(suffix), checkpoint_name)
    return (1, 0, checkpoint_name)


# os.listdir() returns entries in arbitrary order, so list the directory once
# and sort by epoch: both lists then come from the same snapshot and the
# evaluation below runs in ascending epoch order.
model_file_list = sorted(
    (f for f in os.listdir(model_dir) if f.startswith("cellpose")),
    key=_epoch_order,
)
# Epochs are kept as strings (last "_"-separated token of each file name) and
# are index-aligned with model_file_list.
epoch_list = [f.split("_")[-1] for f in model_file_list]

In [5]:
# Annotation tables, grouped by image id so per-image metadata can be looked up.
gb_val_csv = pd.read_csv(osp.join(val_dir, "val.csv")).groupby("id")
gb_test_csv = pd.read_csv(osp.join(test_dir, "test.csv")).groupby("id")

# The validation directory is passed in the loader's first (train) slot; only
# evaluation data is loaded in this notebook.
output = load_train_test_data(val_dir, test_dir=test_dir, mask_filter='_mask')
(val_images, val_labels, image_names_val,
 test_images, test_labels, image_names_test) = output

# Append an all-zero array of the same shape along axis 2 to every image.
# NOTE(review): presumably this pads single-channel images to the two channels
# the model is built with (nchan=2) — confirm against the loader's output shape.
val_images = [np.concatenate((img, np.zeros(img.shape)), axis=2) for img in val_images]
test_images = [np.concatenate((img, np.zeros(img.shape)), axis=2) for img in test_images]

# Keep only element 0 along the first axis of each label array, as integers.
# NOTE(review): presumably the instance-mask plane — TODO confirm.
val_labels = [lbl[0].astype(int) for lbl in val_labels]
test_labels = [lbl[0].astype(int) for lbl in test_labels]

# Image id = file name (text after the last "/") up to its first ".".
val_image_ids = [name.split("/")[-1].split(".")[0] for name in image_names_val]
test_image_ids = [name.split("/")[-1].split(".")[0] for name in image_names_test]

# Cell type of each image, read from the first annotation row of its group.
val_cell_type_list = [
    gb_val_csv.get_group(image_id)["cell_type"].iloc[0] for image_id in val_image_ids
]
test_cell_type_list = [
    gb_test_csv.get_group(image_id)["cell_type"].iloc[0] for image_id in test_image_ids
]

In [None]:
# Sanity check: number of validation and test images that were loaded.
# (A stray trailing backslash previously joined these two statements into one
# logical line, which is a SyntaxError.)
print(len(val_image_ids))
print(len(test_image_ids))

In [None]:
# Thresholds passed to cellpose's average_precision; one CSV column per value.
AP_THRESHOLDS = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
# Column names: epoch, image_id, cell_type, ap_50 ... ap_95, ap (row mean).
EVAL_HEADER = (
    ["epoch", "image_id", "cell_type"]
    + [f"ap_{round(t * 100)}" for t in AP_THRESHOLDS]
    + ["ap"]
)

val_eval_path = osp.join(model_dir, 'val_eval.csv')
test_eval_path = osp.join(model_dir, 'test_eval.csv')


def evaluate_split(model, split_name, images, labels, image_ids, cell_types, epoch, csv_path):
    """Run `model` on one data split and append one CSV row per image.

    Parameters
    ----------
    model : CellposeModel
        Model used to predict masks for `images`.
    split_name : str
        Label used in the printed summary line (e.g. "Val", "Test").
    images, labels : list
        Input images and their ground-truth label arrays, index-aligned.
    image_ids, cell_types : list of str
        Per-image id and cell type, index-aligned with `images`.
    epoch : str
        Epoch tag written in the first CSV column.
    csv_path : str
        Existing CSV file (header already written) that rows are appended to.
    """
    masks, _flows, _styles = model.eval(images, batch_size=8, diameter=17., channels=[0, 0], net_avg=False)
    # First element of the returned tuple: one row per image, one column per threshold.
    ap_per_image = average_precision(labels, masks, threshold=AP_THRESHOLDS)[0]
    print(f"{split_name} Mean AP {np.mean(ap_per_image)}")
    # Open the file once per split instead of once per image.
    with open(csv_path, 'a') as f:
        for image_id, cell_type, ap_row in zip(image_ids, cell_types, ap_per_image):
            row = [epoch, image_id, cell_type, *ap_row, np.mean(ap_row)]
            f.write(','.join(map(str, row)) + '\n')


# Start both result files from scratch; mode 'w' truncates an existing file,
# so no explicit removal is needed.
for csv_path in (val_eval_path, test_eval_path):
    with open(csv_path, 'w') as f:
        f.write(','.join(EVAL_HEADER) + '\n')

for epoch, model_file in zip(epoch_list, model_file_list):
    print(f">>> Epoch {epoch}")
    model_path = osp.join(model_dir, model_file)
    # Announce the checkpoint before loading so a failed load is attributable.
    print(f"Loading model {model_path}")
    model = CellposeModel(gpu=True, pretrained_model=model_path, net_avg=False,
                          diam_mean=17., device=None, residual_on=True, style_on=True,
                          concatenation=False, nchan=2)
    # Cell types are loop-invariant: reuse the lists computed when the data was loaded.
    evaluate_split(model, "Val", val_images, val_labels, val_image_ids,
                   val_cell_type_list, epoch, val_eval_path)
    evaluate_split(model, "Test", test_images, test_labels, test_image_ids,
                   test_cell_type_list, epoch, test_eval_path)

>>> Epoch 51
['../Experiments/cellpose_sartorius/models/cellpose_residual_on_style_on_concatenation_off_cellpose_sartorius_2022_07_30_02_15_19.233628_epoch_51']
Loading model ../Experiments/cellpose_sartorius/models/cellpose_residual_on_style_on_concatenation_off_cellpose_sartorius_2022_07_30_02_15_19.233628_epoch_51
Val Mean AP 0.24541544914245605


In [7]:
# Read back the per-image validation scores written to val_eval.csv and show them.
val_eval_csv = osp.join(model_dir, 'val_eval.csv')
df_val_eval = pd.read_csv(val_eval_csv)
df_val_eval

Unnamed: 0,epoch,image_id,cell_type,ap_50,ap_55,ap_60,ap_65,ap_70,ap_75,ap_80,ap_85,ap_90,ap_95,ap
0,51,0ea6df67cc77,cort,0.666667,0.625,0.625,0.547619,0.413043,0.25,0.015625,0.0,0.0,0.0,0.314295
1,51,0eb1d03df587,shsy5y,0.026706,0.005814,0.002899,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.003542
2,51,1d8ea1f865e0,cort,0.813953,0.733333,0.695652,0.695652,0.444444,0.368421,0.3,0.098592,0.0,0.0,0.415005
3,51,2d9fd17da790,astro,0.067568,0.053333,0.053333,0.039474,0.012821,0.012821,0.012821,0.012821,0.0,0.0,0.026499
4,51,3b56cced208e,shsy5y,0.271739,0.26259,0.231579,0.197952,0.166113,0.090062,0.035398,0.002857,0.0,0.0,0.125829
5,51,4cf637b37b8b,cort,0.904762,0.818182,0.818182,0.538462,0.212121,0.081081,0.052632,0.025641,0.0,0.0,0.345106
6,51,5c252798d269,shsy5y,0.343173,0.309353,0.272727,0.238095,0.189542,0.109756,0.037037,0.005525,0.002755,0.0,0.150796
7,51,7d13efbfce6d,cort,0.846154,0.777778,0.777778,0.777778,0.548387,0.263158,0.142857,0.066667,0.0,0.0,0.420056
8,51,7f27bcdc5e5d,cort,0.65,0.65,0.571429,0.5,0.375,0.269231,0.178571,0.03125,0.0,0.0,0.322548
9,51,8d0f8970d171,shsy5y,0.331579,0.324607,0.304124,0.265,0.216346,0.144796,0.081197,0.036885,0.003968,0.0,0.17085


In [None]:
def mean_ap_per_epoch(df, cell_type=None):
    """Return the mean of the "ap" column per epoch as a Series indexed by epoch.

    If `cell_type` is given, only rows with that value in the "cell_type"
    column are averaged. The "ap" column is selected *before* aggregating so
    the string columns (image_id, cell_type) are never averaged, which raises
    TypeError on pandas >= 2.0.
    """
    if cell_type is not None:
        df = df[df["cell_type"] == cell_type]
    return df.groupby("epoch")["ap"].mean()


aps = mean_ap_per_epoch(df_val_eval)
cort_aps = mean_ap_per_epoch(df_val_eval, "cort")
shsy5y_aps = mean_ap_per_epoch(df_val_eval, "shsy5y")
astro_aps = mean_ap_per_epoch(df_val_eval, "astro")

fig, ax = plt.subplots(figsize=(8, 6))
for series, label in (
    (aps, 'meanAP'),
    (cort_aps, 'cort'),
    (shsy5y_aps, 'shsy5y'),
    (astro_aps, 'astro'),
):
    ax.plot(series, label=label)
ax.set_xlabel("epoch")
ax.set_ylabel("mean AP")
ax.set_title("Validation AP per epoch")
ax.legend();

In [None]:
# Epoch whose checkpoint is inspected qualitatively below (a string, matching epoch_list).
VIS_EPOCH = "98"
if VIS_EPOCH not in epoch_list:
    # Same exception type list.index() would raise, but with an actionable message.
    raise ValueError(
        f"No checkpoint for epoch {VIS_EPOCH!r} in {model_dir}; available epochs: {epoch_list}"
    )
model_idx = epoch_list.index(VIS_EPOCH)
model_path = osp.join(model_dir, model_file_list[model_idx])
model = CellposeModel(gpu=True, pretrained_model=model_path, net_avg=False,
                      diam_mean=17., device=None, residual_on=True, style_on=True,
                      concatenation=False, nchan=2)
# Predict masks for the whole test split with the selected checkpoint.
masks, flows, styles = model.eval(test_images, batch_size=8, diameter=17., channels=[0, 0], net_avg=False)

In [None]:
# Show, for one test image, three separate figures in this order:
# channel 0 of the input image, the ground-truth labels, the predicted masks.
image_idx = 2
panels = (
    test_images[image_idx][:, :, 0],
    test_labels[image_idx],
    masks[image_idx],
)
for panel in panels:
    plt.figure()
    plt.imshow(panel)
plt.show()