In [1]:
import pandas as pd
from cellpose import core, io, models, metrics
from cellpose import train
from sklearn.model_selection import KFold

# Probe once for a usable GPU; the resulting flag is handed to the Cellpose
# model constructor further down in the notebook.
use_GPU = core.use_gpu()
print('>>> GPU activated? %d' % use_GPU)

>>> GPU activated? 1


In [2]:
# Starting weights for fine-tuning. Options:
#   "cyto", "cyto3", "nuclei", "tissuenet_cp3", "livecell_cp3",
#   "yeast_PhC_cp3", "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3",
#   "deepbacs_cp3", "scratch"
# The sentinel "scratch" is translated to None right below.
initial_model = "cyto3"
if initial_model == 'scratch':
    initial_model = None

# Training hyper-parameters forwarded to the training call.
n_epochs = 350
learning_rate = 0.1
weight_decay = 1e-4

# Channel pair passed as [chan, chan2] to both training and evaluation.
chan = 0
chan2 = 0

In [3]:
# CSV file that collects one row of metrics per cross-validation fold.
results_file = "results/results.csv"

# Folder the images/masks are loaded from; also used as the save location
# when training.
dataset_dir = "datasets/yh2ax"

# Base name for trained models; the fold index is appended for each fold.
model_name = "yh2ax_cyto3"

In [4]:
def precision(tp, fp):
    """Return the share of predicted objects that are true positives.

    Args:
        tp: Number of true positives.
        fp: Number of false positives.

    Returns:
        ``tp / (tp + fp)`` when ``tp`` is positive, otherwise ``0`` (this
        also covers the case where both counts are zero).
    """
    if tp > 0:
        return tp / (tp + fp)
    return 0
def recall(tp, fn):
    """Return the share of ground-truth objects that were detected.

    Args:
        tp: Number of true positives.
        fn: Number of false negatives.

    Returns:
        ``tp / (tp + fn)`` when ``tp`` is positive, otherwise ``0`` (this
        also covers the case where both counts are zero).
    """
    if tp > 0:
        return tp / (tp + fn)
    return 0

# Load every image and its label mask from dataset_dir. The third return value
# is not used in this notebook; it is bound to a named throwaway instead of
# `_`, because assigning to `_` at notebook top level shadows IPython's
# "last output" variable.
images, labels, _unused = io.load_images_labels(dataset_dir)

In [5]:
def cross_validate_cellpose(images, masks, save_file, n_splits=5):
    """Fine-tune and score a Cellpose model with K-fold cross-validation.

    For each fold a new ``models.CellposeModel`` is built from the notebook
    global ``initial_model``, trained on the fold's training split with
    ``train.train_seg``, and evaluated on the held-out split. The metrics for
    the first (and only requested) matching threshold, 0.1, are appended as
    one row to the CSV at ``save_file``. The file is rewritten after every
    fold so finished folds survive an interrupted run.

    Besides its arguments the function reads these notebook globals:
    ``use_GPU``, ``initial_model``, ``chan``, ``chan2``, ``dataset_dir``,
    ``n_epochs``, ``learning_rate``, ``weight_decay``, ``model_name``,
    ``precision`` and ``recall``.

    Args:
        images: Sequence of images, indexed by integer position.
        masks: Sequence of label masks, aligned with ``images``.
        save_file: Path of the results CSV. Existing rows are kept and the
            new rows are added; if the file (or its folder) does not exist
            yet, it is created. Re-running therefore appends further rows
            for the same model names.
        n_splits: Number of cross-validation folds.

    Returns:
        None. Results are written to ``save_file``; progress is printed.
    """
    import os  # local import: only needed to prepare the output folder

    # Create the results folder up front so that a missing directory is not
    # discovered only after the first (long) training run has finished.
    out_dir = os.path.dirname(os.fspath(save_file))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Fixed seed so the fold assignment is the same on every run.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(kf.split(images)):
        print(f"Starting fold {fold + 1}/{n_splits}...")
        fold_name = model_name + str(fold)

        # Splitting the dataset
        train_images = [images[i] for i in train_idx]
        train_masks = [masks[i] for i in train_idx]
        test_images = [images[i] for i in test_idx]
        test_masks = [masks[i] for i in test_idx]

        # Fresh model per fold so no weights carry over between folds.
        model = models.CellposeModel(gpu=use_GPU, model_type=initial_model)

        train.train_seg(model.net,
                        train_data=train_images,
                        train_labels=train_masks,
                        test_data=test_images,
                        test_labels=test_masks,
                        channels=[chan, chan2],
                        save_path=dataset_dir,
                        n_epochs=n_epochs,
                        learning_rate=learning_rate,
                        weight_decay=weight_decay,
                        SGD=True,
                        model_name=fold_name,
                        min_train_masks=1,
                        rescale=False,
                        normalize={'normalize': True, 'percentile': [1, 97]})

        # Evaluate with the diameter stored on the trained network.
        diam_labels = model.net.diam_labels.item()

        eval_masks = model.eval(test_images,
                                channels=[chan, chan2],
                                diameter=diam_labels,
                                min_size=1,
                                normalize={'normalize': True, 'percentile': [1, 97]})[0]

        ap_all, tp_all, fp_all, fn_all = metrics.average_precision(
            test_masks, eval_masks, threshold=0.1)

        # The metric arrays are indexed [image, threshold]; a single threshold
        # was requested, so column 0 holds everything that is reported.
        tp = int(tp_all[:, 0].sum())
        fp = int(fp_all[:, 0].sum())
        fn = int(fn_all[:, 0].sum())
        ap = float(ap_all[:, 0].mean())

        fold_row = pd.DataFrame({
            'model': [fold_name],
            'tp': [tp],
            'fp': [fp],
            'fn': [fn],
            'precision': [precision(tp, fp)],
            'recall': [recall(tp, fn)],
            'accuracy': [ap],
        })

        try:
            previous = pd.read_csv(save_file)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # First run, or an empty placeholder file: start a new table.
            previous = None

        if previous is None:
            merged_df = fold_row
        else:
            merged_df = pd.concat([previous, fold_row], ignore_index=True)
        merged_df = merged_df.sort_values(by='model', ascending=True)

        # Write back to the file that was read, i.e. the `save_file` argument
        # (not a notebook-level global).
        merged_df.to_csv(save_file, index=False)

        print(f"Fold {fold + 1} completed.")


In [6]:
# Run the 5-fold cross-validation on the loaded dataset; per-fold metrics are
# collected in results_file.
cross_validate_cellpose(
    images,
    labels,
    save_file=results_file,
    n_splits=5,
)

Starting fold 1/5...


100%|██████████| 231/231 [00:02<00:00, 101.35it/s]
100%|██████████| 58/58 [00:00<00:00, 198.77it/s]
100%|██████████| 231/231 [00:00<00:00, 7074.21it/s]
100%|██████████| 58/58 [00:00<00:00, 8574.89it/s]


Fold 1 completed.
Starting fold 2/5...


100%|██████████| 231/231 [00:01<00:00, 192.55it/s]
100%|██████████| 58/58 [00:00<00:00, 187.37it/s]
100%|██████████| 231/231 [00:00<00:00, 7422.52it/s]
100%|██████████| 58/58 [00:00<00:00, 5946.02it/s]


Fold 2 completed.
Starting fold 3/5...


100%|██████████| 231/231 [00:01<00:00, 198.96it/s]
100%|██████████| 58/58 [00:00<00:00, 211.92it/s]
100%|██████████| 231/231 [00:00<00:00, 8634.41it/s]
100%|██████████| 58/58 [00:00<00:00, 5898.73it/s]


Fold 3 completed.
Starting fold 4/5...


100%|██████████| 231/231 [00:01<00:00, 207.34it/s]
100%|██████████| 58/58 [00:00<00:00, 189.10it/s]
100%|██████████| 231/231 [00:00<00:00, 8101.78it/s]
100%|██████████| 58/58 [00:00<00:00, 7249.88it/s]


Fold 4 completed.
Starting fold 5/5...


100%|██████████| 232/232 [00:01<00:00, 188.35it/s]
100%|██████████| 57/57 [00:00<00:00, 184.95it/s]
100%|██████████| 232/232 [00:00<00:00, 8147.08it/s]
100%|██████████| 57/57 [00:00<00:00, 8121.04it/s]


Fold 5 completed.
