Cellpose Training Arena

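This notebook trains Cellpose models over a grid of hyperparameters (learning rate × number of epochs) and collects validation metrics for every configuration so the settings can be compared. It is a grid search scored on a single held-out validation set, not k-fold cross-validation.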




In [1]:
# Install dependencies (Colab): an older headless OpenCV, cellpose, and stardist
# (stardist is only needed for stardist.matching.matching_dataset).
# NOTE(review): the opencv-python-headless<4.3 pin downgrades the preinstalled
# OpenCV (4.8.1.78 -> 3.4.18.65 in the recorded output), and pip then reports that
# albumentations and qudida require a newer version. Presumably the pin works
# around a cellpose/OpenCV incompatibility - TODO confirm it is still needed.
# NOTE(review): cellpose and stardist are installed unpinned, so a re-run may pick
# up newer releases with different APIs; consider pinning exact versions.
!pip install "opencv-python-headless<4.3"
!pip install cellpose
!pip install stardist #to access matching_dataset


Collecting opencv-python-headless<4.3
  Downloading opencv_python_headless-3.4.18.65-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (45.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.7/45.7 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: opencv-python-headless
  Attempting uninstall: opencv-python-headless
    Found existing installation: opencv-python-headless 4.8.1.78
    Uninstalling opencv-python-headless-4.8.1.78:
      Successfully uninstalled opencv-python-headless-4.8.1.78
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 1.3.1 requires opencv-python-headless>=4.1.1, but you have opencv-python-headless 3.4.18.65 which is incompatible.
qudida 0.0.4 requires opencv-python-headless>=4.0.1, but you have opencv-python-headless 3.4.18.65 which is incompatible.[0m[31m
[0mSuc

In [2]:
# Record the CUDA compiler version and the GPU(s) visible to this runtime.
!nvcc --version
!nvidia-smi

import os, shutil
import numpy as np
import matplotlib.pyplot as plt
from cellpose import core, utils, io, models, metrics
from glob import glob
import pandas as pd

# Ask cellpose whether a GPU can be used; this flag is later passed to
# CellposeModel(gpu=...).
use_GPU = core.use_gpu()
yn = ['NO', 'YES']
gpu_answer = yn[1] if use_GPU else yn[0]
print(f'>>> GPU activated? {gpu_answer}')


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Sat Dec 16 20:59:42 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8              11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                      

In [9]:
# Hyperparameter grid search: train one Cellpose model per
# (learning rate, number of epochs) pair, each initialised from the same
# pretrained base model, then score it on a held-out validation set.
# This is a grid search against a single validation split, not k-fold
# cross-validation.
from stardist.matching import matching, matching_dataset

# --- Paths ----------------------------------------------------------------------
# Training images from the human-in-the-loop workflow; their annotations are
# read from the '_seg.npy' files (see the load_train_test_data call below).
train_dir = "/content/drive/MyDrive/MLP2/traindataHIL"
test_dir = None
# Validation images and their ground-truth masks. Named `val_dir` instead of
# `dir` so the Python builtin `dir` is not shadowed.
val_dir = "/content/drive/MyDrive/MLP2/cp_validate_focus"
maskdir = "/content/drive/MyDrive/MLP2/cp_validate_focus_masks"
# Define where the patch file will be saved. NOTE(review): unused in this cell.
base = "/content"
# Relative path, i.e. written to the current working directory.
results_csv = 'model_evaluation_results.csv'
# Folder the trained models are expected under (<train_dir>/models/<model_name>).
# os.path.join keeps the separator that plain string concatenation dropped
# ('.../traindataHILmodels/'), which made the existence check below never fire.
model_path = os.path.join(train_dir, 'models')


def load_validation_set(image_dir, mask_dir):
    """Load validation images and ground-truth masks as position-paired lists.

    Images are the files returned by io.get_image_files(image_dir, '_mask');
    masks are the files returned by io.get_image_files(mask_dir, '').
    Image i is scored against mask i, so a count mismatch is reported as an
    error instead of silently mis-pairing images and masks.

    Returns:
        (images, labels, image_files, mask_files)
    Raises:
        ValueError: if the number of image files and mask files differ.
    """
    image_files = io.get_image_files(image_dir, '_mask')
    mask_files = io.get_image_files(mask_dir, '')
    if len(image_files) != len(mask_files):
        raise ValueError(
            f"{len(image_files)} validation images but {len(mask_files)} "
            f"ground-truth masks; they are paired by position and must match."
        )
    images = [io.imread(f) for f in image_files]
    labels = [io.imread(m) for m in mask_files]
    return images, labels, image_files, mask_files


def _fresh_copies(arrays):
    """Return a new list holding an independent copy of every array in `arrays`.

    Training transforms the test data it is given, so each consumer gets its
    own copy and the loaded validation set stays untouched between runs.
    """
    return [np.copy(a) for a in arrays]


# Load the validation set once. This fails fast, before any training, if
# images and masks cannot be paired.
test_data, test_labels, files, gtfiles = load_validation_set(val_dir, maskdir)
print(gtfiles)

# --- Model / training configuration ---------------------------------------------
initial_base_model = "cyto"  # pretrained cellpose model every run starts from
model_name = "CP_tissuenet"  # placeholder; replaced by a per-run name in the loop
weight_decay = 0.0001
channels = [0, 0]
nimg_per_epoch = 8  # forwarded unchanged to model.train

# Hyperparameter grid.
n_epochs = [10, 100, 250, 500]
learning_rates = [0.2, 0.1, 0.01]
# Thresholds at which every model is scored.
taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Start cellpose's logger so training progress across epochs is recorded.
# The per-epoch numbers are only emitted through this logger (there is no
# function to retrieve them), so they must be parsed from its output manually.
# NOTE(review): in Colab these messages may not appear in the cell output;
# check where io.logger_setup writes its log for the installed cellpose version.
logger = io.logger_setup()

# One record per (model, threshold); converted to a DataFrame at the end.
results = []

for l_r in learning_rates:
    for n_e in n_epochs:
        model_name = f"CP_focus_lr{l_r}_epochs{n_e}"

        # Reload the training set for every run so each one starts from the
        # same untouched inputs.
        output = io.load_train_test_data(train_dir, test_dir, mask_filter='_seg.npy')
        train_data, train_labels, _, _, _, _ = output

        if os.path.exists(os.path.join(model_path, model_name)):
            print("!! WARNING: " + model_name + " already exists and will be overwritten by this run !!")

        # A fresh model from the pretrained base for every configuration keeps
        # the runs independent of each other.
        model = models.CellposeModel(gpu=use_GPU, model_type=initial_base_model)

        print(f"#### TRAINING: {model_name} LR: {l_r} epochs: {n_e} START###")
        new_model_path = model.train(train_data, train_labels,
                                     test_data=_fresh_copies(test_data),
                                     test_labels=_fresh_copies(test_labels),
                                     channels=channels,
                                     save_path=train_dir,
                                     n_epochs=n_e,
                                     learning_rate=l_r,
                                     weight_decay=weight_decay,
                                     nimg_per_epoch=nimg_per_epoch,
                                     model_name=model_name)
        print(f"#### TRAINING: {model_name} LR: {l_r} epochs: {n_e} END###")

        print(f"#### VALIDATING: {model_name} LR: {l_r} epochs: {n_e} START###")
        # Segment the validation images at the diameter of the labels in the
        # training images.
        diam_labels = model.diam_labels.copy()
        masks = model.eval(_fresh_copies(test_data),
                           channels=channels,
                           diameter=diam_labels)[0]

        # Score against the ground truth at every threshold, with both cellpose's
        # average_precision and stardist's matching_dataset.
        ap_values, tp, fp, fn = metrics.average_precision(test_labels, masks, threshold=taus)
        all_stats = [matching_dataset(test_labels, masks, thresh=t, show_progress=False)
                     for t in taus]

        # ap_values/tp/fp/fn are indexed [image, threshold]: AP is averaged and
        # the counts are summed over the images for each threshold.
        for idx, (t, stats) in enumerate(zip(taus, all_stats)):
            results.append({
                'learning_rate': l_r,
                'num_epochs': n_e,
                'model_name': model_name,
                'threshold': t,
                'average_precision': ap_values[:, idx].mean(),
                'true_positives': tp[:, idx].sum(),
                'false_positives': fp[:, idx].sum(),
                'false_negatives': fn[:, idx].sum(),
                'precision': stats.precision,
                'recall': stats.recall,
                'f1': stats.f1,
                'n_true': stats.n_true,
                'n_pred': stats.n_pred,
                'mean_true_score': stats.mean_true_score,
                'mean_matched_score': stats.mean_matched_score,
                'panoptic_quality': stats.panoptic_quality
            })

        # Checkpoint after every configuration so a runtime disconnect partway
        # through the grid does not lose the runs that already finished.
        pd.DataFrame(results).to_csv(results_csv, index=False)

# Final results table: one row per (model, threshold).
results_df = pd.DataFrame(results)
results_df.to_csv(results_csv, index=False)


#### TRAINING: CP_focus_lr0.2_epochs100 LR: 0.2  empochs: 100END###
#### VALIDATING: CP_focus_lr0.2_epochs100 LR: 0.2 empochs: 100START###


100%|██████████| 12/12 [00:00<00:00, 122.30it/s]
100%|██████████| 12/12 [00:00<00:00, 109.79it/s]
100%|██████████| 12/12 [00:00<00:00, 118.55it/s]
100%|██████████| 12/12 [00:00<00:00, 112.65it/s]
100%|██████████| 12/12 [00:00<00:00, 113.57it/s]
100%|██████████| 12/12 [00:00<00:00, 115.26it/s]
100%|██████████| 12/12 [00:00<00:00, 111.50it/s]
100%|██████████| 12/12 [00:00<00:00, 98.13it/s] 
100%|██████████| 12/12 [00:00<00:00, 108.20it/s]


#### TRAINING: CP_focus_lr0.2_epochs250 LR: 0.2 epochs: 250 START###


100%|██████████| 5/5 [00:00<00:00, 13.28it/s]
100%|██████████| 12/12 [00:04<00:00,  2.58it/s]


#### TRAINING: CP_focus_lr0.2_epochs250 LR: 0.2  empochs: 250END###
#### VALIDATING: CP_focus_lr0.2_epochs250 LR: 0.2 empochs: 250START###


100%|██████████| 12/12 [00:00<00:00, 98.58it/s] 
100%|██████████| 12/12 [00:00<00:00, 113.82it/s]
100%|██████████| 12/12 [00:00<00:00, 115.40it/s]
100%|██████████| 12/12 [00:00<00:00, 109.13it/s]
100%|██████████| 12/12 [00:00<00:00, 123.58it/s]
100%|██████████| 12/12 [00:00<00:00, 116.32it/s]
100%|██████████| 12/12 [00:00<00:00, 116.88it/s]
100%|██████████| 12/12 [00:00<00:00, 124.43it/s]
100%|██████████| 12/12 [00:00<00:00, 122.20it/s]


#### TRAINING: CP_focus_lr0.2_epochs500 LR: 0.2 epochs: 500 START###


100%|██████████| 5/5 [00:00<00:00, 14.43it/s]
100%|██████████| 12/12 [00:04<00:00,  2.62it/s]


#### TRAINING: CP_focus_lr0.2_epochs500 LR: 0.2  empochs: 500END###
#### VALIDATING: CP_focus_lr0.2_epochs500 LR: 0.2 empochs: 500START###


100%|██████████| 12/12 [00:00<00:00, 70.71it/s]
100%|██████████| 12/12 [00:00<00:00, 70.89it/s]
100%|██████████| 12/12 [00:00<00:00, 67.72it/s]
100%|██████████| 12/12 [00:00<00:00, 72.08it/s]
100%|██████████| 12/12 [00:00<00:00, 72.46it/s]
100%|██████████| 12/12 [00:00<00:00, 71.93it/s]
100%|██████████| 12/12 [00:00<00:00, 100.24it/s]
100%|██████████| 12/12 [00:00<00:00, 104.31it/s]
100%|██████████| 12/12 [00:00<00:00, 109.45it/s]


#### TRAINING: CP_focus_lr0.1_epochs10 LR: 0.1 epochs: 10 START###


100%|██████████| 5/5 [00:00<00:00, 13.83it/s]
100%|██████████| 12/12 [00:03<00:00,  3.27it/s]


#### TRAINING: CP_focus_lr0.1_epochs10 LR: 0.1  empochs: 10END###
#### VALIDATING: CP_focus_lr0.1_epochs10 LR: 0.1 empochs: 10START###


100%|██████████| 12/12 [00:00<00:00, 69.02it/s]
100%|██████████| 12/12 [00:00<00:00, 68.30it/s]
100%|██████████| 12/12 [00:00<00:00, 68.53it/s]
100%|██████████| 12/12 [00:00<00:00, 64.07it/s]
100%|██████████| 12/12 [00:00<00:00, 70.60it/s]
100%|██████████| 12/12 [00:00<00:00, 72.84it/s]
100%|██████████| 12/12 [00:00<00:00, 69.49it/s]
100%|██████████| 12/12 [00:00<00:00, 71.35it/s]
100%|██████████| 12/12 [00:00<00:00, 74.02it/s]


#### TRAINING: CP_focus_lr0.1_epochs100 LR: 0.1 epochs: 100 START###


100%|██████████| 5/5 [00:00<00:00, 11.19it/s]
100%|██████████| 12/12 [00:03<00:00,  3.31it/s]


#### TRAINING: CP_focus_lr0.1_epochs100 LR: 0.1  empochs: 100END###
#### VALIDATING: CP_focus_lr0.1_epochs100 LR: 0.1 empochs: 100START###


100%|██████████| 12/12 [00:00<00:00, 114.74it/s]
100%|██████████| 12/12 [00:00<00:00, 71.75it/s]
100%|██████████| 12/12 [00:00<00:00, 74.58it/s]
100%|██████████| 12/12 [00:00<00:00, 74.75it/s]
100%|██████████| 12/12 [00:00<00:00, 69.00it/s]
100%|██████████| 12/12 [00:00<00:00, 69.67it/s]
100%|██████████| 12/12 [00:00<00:00, 68.11it/s]
100%|██████████| 12/12 [00:00<00:00, 67.09it/s]
100%|██████████| 12/12 [00:00<00:00, 69.48it/s]


#### TRAINING: CP_focus_lr0.1_epochs250 LR: 0.1 epochs: 250 START###


100%|██████████| 5/5 [00:00<00:00,  9.53it/s]
100%|██████████| 12/12 [00:04<00:00,  2.72it/s]


#### TRAINING: CP_focus_lr0.1_epochs250 LR: 0.1  empochs: 250END###
#### VALIDATING: CP_focus_lr0.1_epochs250 LR: 0.1 empochs: 250START###


100%|██████████| 12/12 [00:00<00:00, 62.80it/s]
100%|██████████| 12/12 [00:00<00:00, 65.85it/s]
100%|██████████| 12/12 [00:00<00:00, 66.86it/s]
100%|██████████| 12/12 [00:00<00:00, 69.14it/s]
100%|██████████| 12/12 [00:00<00:00, 66.49it/s]
100%|██████████| 12/12 [00:00<00:00, 67.57it/s]
100%|██████████| 12/12 [00:00<00:00, 67.29it/s]
100%|██████████| 12/12 [00:00<00:00, 75.01it/s]
100%|██████████| 12/12 [00:00<00:00, 73.41it/s]


#### TRAINING: CP_focus_lr0.1_epochs500 LR: 0.1 epochs: 500 START###


100%|██████████| 5/5 [00:00<00:00, 10.47it/s]
100%|██████████| 12/12 [00:04<00:00,  2.94it/s]


#### TRAINING: CP_focus_lr0.1_epochs500 LR: 0.1  empochs: 500END###
#### VALIDATING: CP_focus_lr0.1_epochs500 LR: 0.1 empochs: 500START###


100%|██████████| 12/12 [00:00<00:00, 110.23it/s]
100%|██████████| 12/12 [00:00<00:00, 104.53it/s]
100%|██████████| 12/12 [00:00<00:00, 111.20it/s]
100%|██████████| 12/12 [00:00<00:00, 103.81it/s]
100%|██████████| 12/12 [00:00<00:00, 108.11it/s]
100%|██████████| 12/12 [00:00<00:00, 111.37it/s]
100%|██████████| 12/12 [00:00<00:00, 98.19it/s] 
100%|██████████| 12/12 [00:00<00:00, 115.85it/s]
100%|██████████| 12/12 [00:00<00:00, 122.12it/s]


#### TRAINING: CP_focus_lr0.01_epochs10 LR: 0.01 epochs: 10 START###


100%|██████████| 5/5 [00:00<00:00, 13.90it/s]
100%|██████████| 12/12 [00:03<00:00,  3.41it/s]


#### TRAINING: CP_focus_lr0.01_epochs10 LR: 0.01  empochs: 10END###
#### VALIDATING: CP_focus_lr0.01_epochs10 LR: 0.01 empochs: 10START###


100%|██████████| 12/12 [00:00<00:00, 73.35it/s]
100%|██████████| 12/12 [00:00<00:00, 75.00it/s]
100%|██████████| 12/12 [00:00<00:00, 72.32it/s]
100%|██████████| 12/12 [00:00<00:00, 71.85it/s]
100%|██████████| 12/12 [00:00<00:00, 75.62it/s]
100%|██████████| 12/12 [00:00<00:00, 76.84it/s]
100%|██████████| 12/12 [00:00<00:00, 123.18it/s]
100%|██████████| 12/12 [00:00<00:00, 123.10it/s]
100%|██████████| 12/12 [00:00<00:00, 115.40it/s]


#### TRAINING: CP_focus_lr0.01_epochs100 LR: 0.01 epochs: 100 START###


100%|██████████| 5/5 [00:00<00:00, 14.33it/s]
100%|██████████| 12/12 [00:03<00:00,  3.29it/s]


#### TRAINING: CP_focus_lr0.01_epochs100 LR: 0.01  empochs: 100END###
#### VALIDATING: CP_focus_lr0.01_epochs100 LR: 0.01 empochs: 100START###


100%|██████████| 12/12 [00:00<00:00, 68.27it/s]
100%|██████████| 12/12 [00:00<00:00, 67.25it/s]
100%|██████████| 12/12 [00:00<00:00, 64.67it/s]
100%|██████████| 12/12 [00:00<00:00, 68.74it/s]
100%|██████████| 12/12 [00:00<00:00, 71.58it/s]
100%|██████████| 12/12 [00:00<00:00, 65.97it/s]
100%|██████████| 12/12 [00:00<00:00, 71.23it/s]
100%|██████████| 12/12 [00:00<00:00, 70.30it/s]
100%|██████████| 12/12 [00:00<00:00, 74.92it/s]


#### TRAINING: CP_focus_lr0.01_epochs250 LR: 0.01 epochs: 250 START###


100%|██████████| 5/5 [00:00<00:00,  9.76it/s]
100%|██████████| 12/12 [00:03<00:00,  3.04it/s]


#### TRAINING: CP_focus_lr0.01_epochs250 LR: 0.01  empochs: 250END###
#### VALIDATING: CP_focus_lr0.01_epochs250 LR: 0.01 empochs: 250START###


100%|██████████| 12/12 [00:00<00:00, 67.21it/s]
100%|██████████| 12/12 [00:00<00:00, 71.26it/s]
100%|██████████| 12/12 [00:00<00:00, 66.48it/s]
100%|██████████| 12/12 [00:00<00:00, 70.14it/s]
100%|██████████| 12/12 [00:00<00:00, 71.75it/s]
100%|██████████| 12/12 [00:00<00:00, 67.99it/s]
100%|██████████| 12/12 [00:00<00:00, 67.15it/s]
100%|██████████| 12/12 [00:00<00:00, 64.01it/s]
100%|██████████| 12/12 [00:00<00:00, 74.29it/s]


#### TRAINING: CP_focus_lr0.01_epochs500 LR: 0.01 epochs: 500 START###


100%|██████████| 5/5 [00:00<00:00, 10.01it/s]
100%|██████████| 12/12 [00:03<00:00,  3.21it/s]


#### TRAINING: CP_focus_lr0.01_epochs500 LR: 0.01  empochs: 500END###
#### VALIDATING: CP_focus_lr0.01_epochs500 LR: 0.01 empochs: 500START###


100%|██████████| 12/12 [00:00<00:00, 104.21it/s]
100%|██████████| 12/12 [00:00<00:00, 111.79it/s]
100%|██████████| 12/12 [00:00<00:00, 111.01it/s]
100%|██████████| 12/12 [00:00<00:00, 103.34it/s]
100%|██████████| 12/12 [00:00<00:00, 100.42it/s]
100%|██████████| 12/12 [00:00<00:00, 102.24it/s]
100%|██████████| 12/12 [00:00<00:00, 111.50it/s]
100%|██████████| 12/12 [00:00<00:00, 114.64it/s]
100%|██████████| 12/12 [00:00<00:00, 113.96it/s]
