In [1]:
import os

MODEL_NAMES = [
    "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
    "yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto", "CPx",
    "transformer_cp3", "neurips_cellpose_default", "neurips_cellpose_transformer",
    "neurips_grayscale_cyto2"
]

save_path = '/root/capsule/scratch/'

for model_name in MODEL_NAMES:
    model_save_path = os.path.join(save_path, model_name)
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

In [2]:
import albumentations as A
import os
import cv2
import numpy as np
import tifffile
from sklearn.model_selection import train_test_split
from cellpose import models, train, io
import matplotlib.pyplot as plt
from cellpose import plot
from cellpose import utils, io

io.logger_setup()  # Run this to get printing of progress

# Define paths
data_dir = '/root/capsule/data/iGluSnFR_Soma_Annotation'

# Collect all image and mask file paths
image_files = sorted([f for f in os.listdir(data_dir) if f.endswith('_merged.tif')])
mask_files = sorted([f for f in os.listdir(data_dir) if f.endswith('_segmented_v2.tif')])

# Ensure that each image has a corresponding mask
assert len(image_files) == len(mask_files), "Number of images and masks must match."

# Load all images and masks
images = [tifffile.imread(os.path.join(data_dir, img))[:, 1, :, :] for img in image_files]
masks = [tifffile.imread(os.path.join(data_dir, msk)) for msk in mask_files]

# Ensure images and masks have the same number of frames
for img, msk in zip(images, masks):
    assert img.shape[0] == msk.shape[0], "Number of frames in images and masks must match."

# Convert lists to numpy arrays
images = np.concatenate(images, axis=0)
masks = np.concatenate(masks, axis=0)

# Normalize images to 0-1 range
images = images.astype(np.float32) / 255.0

# Convert masks to uint8 if needed
masks_uint8 = masks.astype(np.uint8)

# Define an augmentation pipeline
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=45, p=0.5),
], is_check_shapes=False)

augmented_images = []
augmented_masks = []

# Augment each image multiple times
num_augmentations = 5  # Number of times to augment each image

for img, msk in zip(images, masks_uint8):
    for _ in range(num_augmentations):
        # Apply the augmentation pipeline
        transformed = transform(image=img, mask=msk)
        augmented_images.append(transformed['image'])
        augmented_masks.append(transformed['mask'])

# Convert lists to numpy arrays and combine with original data
augmented_images = np.array(augmented_images)
augmented_masks = np.array(augmented_masks)

# Combine original and augmented data
images_combined = np.concatenate((images, augmented_images), axis=0)
masks_combined = np.concatenate((masks_uint8, augmented_masks), axis=0)

# Split data into train+val and test
train_val_images, test_images, train_val_masks, test_masks = train_test_split(
    images_combined, masks_combined, test_size=0.15, random_state=42
)

# Split train+val into train and validation
train_images, val_images, train_masks, val_masks = train_test_split(
    train_val_images, train_val_masks, test_size=0.176, random_state=42  # 0.176 to make validation 15% of total
)

  check_for_updates()


2024-09-15 20:18:46,494 [INFO] WRITING LOG OUTPUT TO /root/.cellpose/run.log
2024-09-15 20:18:46,495 [INFO] 
cellpose version: 	3.0.11 
platform:       	linux 
python version: 	3.10.12 
torch version:  	2.1.0


In [3]:
import pandas as pd
import matplotlib.pyplot as plt
from cellpose import models, train, io
import torch
import gc

# Initialize an empty DataFrame to hold all results
all_results_df = pd.DataFrame()

for model_name in MODEL_NAMES:
    # Initialize Cellpose model
    model = models.CellposeModel(gpu=True, model_type=model_name)

    # Train the model (example; adjust parameters as needed)
    train.train_seg(
        model.net,
        train_data=train_images,
        train_labels=train_masks,
        test_data=val_images, 
        test_labels=val_masks,
        channels=[0, 0],  # Adjust channels if needed
        normalize=True,
        weight_decay=1e-4,
        SGD=False,
        learning_rate=0.1,
        n_epochs=1000,
        save_path=save_path,
        model_name=f'{model_name}_cellpose_model.pth'
    )

    # Evaluate the model and calculate metrics
    results_list = []
    for i, (image, true_mask) in enumerate(zip(test_images, test_masks)):
        results = model.eval(image, channels=[0, 0])
        if len(results) == 3:
            masks_pred, flows, styles = results
        else:
            masks_pred, flows, styles, diams = results

        metrics = calculate_metrics(true_mask, masks_pred)
        results_list.append({
            'Image_Index': i,
            'Dice_Score': metrics['Dice Score'],
            'IoU_Score': metrics['IoU Score'],
            'Pixel_Accuracy': metrics['Pixel Accuracy'],
            'Number_of_True_ROIs': metrics['Number of True ROIs'],
            'Number_of_Predicted_ROIs': metrics['Number of Predicted ROIs']
        })

    # After training and visualization, delete references
    del model
    # Run garbage collector
    gc.collect()
    # Clear CUDA cache
    torch.cuda.empty_cache()

    # Convert the list of results to a DataFrame
    results_df = pd.DataFrame(results_list)

    # Append model name to results and add to all_results_df
    results_df['Model_Name'] = model_name
    all_results_df = pd.concat([all_results_df, results_df], ignore_index=True)

    # Save model-specific results to a CSV file
    results_df.to_csv(os.path.join(save_path, f'{model_name}_results.csv'), index=False)

    # Visualize and save plots for a subset of test images
    random_indices = random.sample(range(len(test_images)), 20)
    for idx in random_indices:
        results = model.eval(test_images[idx], channels=[0, 0])
        if len(results) == 3:
            masks_pred, flows, styles = results
        else:
            masks_pred, flows, styles, diams = results

        metrics = calculate_metrics(test_masks[idx], masks_pred)
        title = (f"Image_Index: {idx}, "
                f"Dice_Score: {metrics['Dice Score']:.2f}, "
                f"IoU_Score: {metrics['IoU Score']:.2f}, "
                f"Pixel_Accuracy: {metrics['Pixel Accuracy']:.2f}, "
                f"Number_of_True_ROIs: {metrics['Number of True ROIs']}, "
                f"Number_of_Predicted_ROIs: {metrics['Number of Predicted ROIs']}")

        fig, ax = plt.subplots(1, 4, figsize=(16, 6))
        fig.suptitle(title, fontsize=12)

        ax[0].imshow(test_images[idx])
        ax[0].set_title('Original Image')
        ax[0].axis('off')

        ax[1].imshow(test_images[idx], cmap='gray')
        ax[1].imshow(test_masks[idx], cmap='jet', alpha=0.5)
        ax[1].set_title('Ground Truth Mask')
        ax[1].axis('off')

        ax[2].imshow(test_images[idx], cmap='gray')
        ax[2].imshow(masks_pred, cmap='jet', alpha=0.5)
        ax[2].set_title('Predicted Mask')
        ax[2].axis('off')

        ax[3].imshow(flows[0], cmap='gray')
        ax[3].set_title('Flow Field')
        ax[3].axis('off')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig(os.path.join(save_path, f'{model_name}_image_{idx}.png'), bbox_inches='tight')
        plt.close(fig)  # Close the figure to prevent it from being displayed

# Save combined results to a CSV file
all_results_df.to_csv(os.path.join(save_path, 'all_models_results.csv'), index=False)

2024-09-15 20:18:53,092 [INFO] >> cyto3 << model set to be used
2024-09-15 20:18:54,010 [INFO] ** TORCH CUDA version installed and working. **
2024-09-15 20:18:54,010 [INFO] >>>> using GPU (CUDA)
2024-09-15 20:18:54,095 [INFO] >>>> loading model /root/.cellpose/models/cyto3
2024-09-15 20:18:54,144 [INFO] >>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)
2024-09-15 20:18:54,145 [INFO] computing flows for labels


  6%|▌         | 43/705 [00:03<00:44, 14.96it/s]



  9%|▉         | 63/705 [00:04<00:39, 16.35it/s]



 15%|█▍        | 103/705 [00:06<00:27, 21.99it/s]



 19%|█▉        | 134/705 [00:08<00:39, 14.62it/s]



 20%|██        | 144/705 [00:09<00:27, 20.34it/s]



 24%|██▍       | 171/705 [00:10<00:33, 15.99it/s]



 30%|██▉       | 209/705 [00:13<00:55,  8.94it/s]



 41%|████▏     | 292/705 [00:18<00:20, 20.53it/s]



 42%|████▏     | 298/705 [00:18<00:24, 16.91it/s]



 43%|████▎     | 302/705 [00:19<00:20, 19.91it/s]



 55%|█████▌    | 391/705 [00:24<00:17, 17.83it/s]



 56%|█████▋    | 397/705 [00:24<00:15, 20.35it/s]



 61%|██████▏   | 432/705 [00:26<00:18, 14.77it/s]



 67%|██████▋   | 474/705 [00:29<00:15, 14.68it/s]



 69%|██████▉   | 485/705 [00:30<00:16, 13.22it/s]



 73%|███████▎  | 515/705 [00:32<00:13, 13.90it/s]



 75%|███████▍  | 527/705 [00:33<00:12, 14.24it/s]



 92%|█████████▏| 646/705 [00:39<00:03, 16.71it/s]



 99%|█████████▉| 701/705 [00:43<00:00, 11.46it/s]



100%|██████████| 705/705 [00:43<00:00, 16.20it/s]

2024-09-15 20:19:38,785 [INFO] computing flows for labels



 35%|███▌      | 53/151 [00:03<00:04, 20.05it/s]



 38%|███▊      | 57/151 [00:03<00:04, 22.73it/s]



 48%|████▊     | 72/151 [00:04<00:03, 20.11it/s]



 52%|█████▏    | 79/151 [00:04<00:04, 17.50it/s]



 57%|█████▋    | 86/151 [00:04<00:03, 21.00it/s]



 85%|████████▍ | 128/151 [00:07<00:01, 18.02it/s]



100%|██████████| 151/151 [00:08<00:00, 17.35it/s]

2024-09-15 20:19:47,734 [INFO] >>> computing diameters



  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 705/705 [00:03<00:00, 217.85it/s]
100%|██████████| 151/151 [00:00<00:00, 218.01it/s]

2024-09-15 20:19:51,668 [INFO] >>> using channels [0, 0]
2024-09-15 20:19:51,668 [INFO] >>> normalizing {'lowhigh': None, 'percentile': None, 'normalize': True, 'norm3D': False, 'sharpen_radius': 0, 'smooth_radius': 0, 'tile_norm_blocksize': 0, 'tile_norm_smooth3D': 1, 'invert': False}





2024-09-15 20:20:18,317 [INFO] >>> n_epochs=1000, n_train=584, n_test=151
2024-09-15 20:20:18,317 [INFO] >>> AdamW, learning_rate=0.10000, weight_decay=0.00010
2024-09-15 20:20:18,558 [INFO] >>> saving model to /root/capsule/scratch/models/cyto3_cellpose_model.pth


  return F.conv2d(input, weight, bias, self.stride,


2024-09-15 20:20:32,862 [INFO] 0, train_loss=0.6359, test_loss=0.6461, LR=0.0000, time 14.31s
2024-09-15 20:21:31,721 [INFO] 5, train_loss=0.0943, test_loss=0.0749, LR=0.0556, time 73.16s
2024-09-15 20:22:30,486 [INFO] 10, train_loss=0.0794, test_loss=0.0722, LR=0.1000, time 131.93s
2024-09-15 20:24:27,006 [INFO] 20, train_loss=0.0747, test_loss=0.0725, LR=0.1000, time 248.45s
2024-09-15 20:26:22,696 [INFO] 30, train_loss=0.0714, test_loss=0.0607, LR=0.1000, time 364.14s
2024-09-15 20:28:18,913 [INFO] 40, train_loss=0.0697, test_loss=0.0734, LR=0.1000, time 480.36s
2024-09-15 20:30:14,948 [INFO] 50, train_loss=0.0701, test_loss=0.0750, LR=0.1000, time 596.39s
2024-09-15 20:32:11,295 [INFO] 60, train_loss=0.0718, test_loss=0.0676, LR=0.1000, time 712.74s
2024-09-15 20:34:07,378 [INFO] 70, train_loss=0.0708, test_loss=0.0603, LR=0.1000, time 828.82s
2024-09-15 20:36:03,608 [INFO] 80, train_loss=0.0701, test_loss=0.0624, LR=0.1000, time 945.05s
2024-09-15 20:37:59,952 [INFO] 90, train_los