In [None]:
# Build the Cellpose training dataset from the raw data found under `input_root`.
from spacr.io import prepare_cellpose_dataset

# TODO: point this at the real data root before running ('path' is a placeholder).
input_root = 'path'

prepare_cellpose_dataset(
    input_root,
    augment_data=True,   # enable spacr's data augmentation
    train_fraction=0.8,  # presumably the share of samples assigned to the train split — confirm in spacr docs
    n_jobs=None,
)

In [None]:
# train cellpose model
from spacr.submodules import train_cellpose
%matplotlib inline
#'/nas_mnt/carruthers/training_data/plaque/cellpose_dataset'
settings = {'src':'/nas_mnt/carruthers/training_data/plaque/test',
            'test':False,
            'normalize':False,
            'percentiles':None,
            'invert':False,
            'grayscale':True,
            'rescale':False,
            'circular':False,
            'channels':[0,0],
            'model_name':'test',
            'model_type':'cyto',
            'Signal_to_noise':10,
            'background':200,
            'remove_background':False,
            'learning_rate':0.2,
            'weight_decay':1e-05,
            'batch_size':8,
            'n_epochs':25000,
            'from_scratch':False,
            'diameter':30,
            'resize':False,
            'target_dimensions':1000,
            'verbose':True}

train_cellpose(settings)

In [None]:
def _load_images_and_labels(image_files, label_files, invert=False):
    
    from .utils import invert_image
    
    images = []
    labels = []

    image_names = sorted([os.path.basename(f) for f in image_files]) if image_files else []
    label_names = sorted([os.path.basename(f) for f in label_files]) if label_files else []

    if image_files and label_files:
        for img_file, lbl_file in zip(image_files, label_files):
            image = cellpose.io.imread(img_file)
            if image is None:
                print(f"WARNING: Could not load image: {img_file}")
                continue
            if invert:
                image = invert_image(image)
            if image.max() > 1:
                image = image / image.max()

            label = cellpose.io.imread(lbl_file)
            if label is None:
                print(f"WARNING: Could not load label: {lbl_file}")
                continue

            images.append(image)
            labels.append(label)

    elif image_files:
        for img_file in image_files:
            image = cellpose.io.imread(img_file)
            if image is None:
                print(f"WARNING: Could not load image: {img_file}")
                continue
            if invert:
                image = invert_image(image)
            if image.max() > 1:
                image = image / image.max()
            images.append(image)

    elif label_files:
        for lbl_file in label_files:
            label = cellpose.io.imread(lbl_file)
            if label is None:
                print(f"WARNING: Could not load label: {lbl_file}")
                continue
            labels.append(label)

    image_dir = os.path.dirname(image_files[0]) if image_files else None
    label_dir = os.path.dirname(label_files[0]) if label_files else None

    print(f'Loaded {len(images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
    if images and labels:
        print(f'image shape: {images[0].shape}, image type: {images[0].dtype}; '
              f'label shape: {labels[0].shape}, label type: {labels[0].dtype}')

    return images, labels, image_names, label_names

def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None,  
                                       invert=False, visualize=False, remove_background=False, 
                                       background=0, Signal_to_noise=10, target_height=None, target_width=None):
    
    from .plot import normalize_and_visualize, plot_resize
    from .utils import invert_image, apply_mask
    from skimage.transform import resize as resizescikit

    # Ensure percentiles are valid
    if isinstance(percentiles, list) and len(percentiles) == 2:
        try:
            percentiles = [int(percentiles[0]), int(percentiles[1])]
        except ValueError:
            percentiles = None
    else:
        percentiles = None

    signal_thresholds = float(background) * float(Signal_to_noise)
    lower_percentile = 2

    images, labels, orig_dims = [], [], []
    num_channels = 4
    percentiles_1 = [[] for _ in range(num_channels)]
    percentiles_99 = [[] for _ in range(num_channels)]

    image_names = [os.path.basename(f) for f in image_files]
    image_dir = os.path.dirname(image_files[0])

    if label_files is not None:
        label_names = [os.path.basename(f) for f in label_files]
        label_dir = os.path.dirname(label_files[0])
    else:
        label_names, label_dir = [], None

    # Load, normalize, and resize images
    for i, img_file in enumerate(image_files):
        image = cellpose.io.imread(img_file)
        orig_dims.append((image.shape[0], image.shape[1]))

        if invert:
            image = invert_image(image)

        # Select specific channels if needed
        if channels is not None and image.ndim == 3:
            image = image[..., channels]

        if remove_background:
            image = np.where(image < background, 0, image)

        if image.ndim < 3:
            image = np.expand_dims(image, axis=-1)

        # Calculate percentiles if not provided
        if percentiles is None:
            for c in range(image.shape[-1]):
                p1 = np.percentile(image[..., c], lower_percentile)
                percentiles_1[c].append(p1)

                # Ensure `signal_thresholds` and `p` are floats for comparison
                for percentile in [98, 99, 99.9, 99.99, 99.999]:
                    p = np.percentile(image[..., c], percentile)
                    if float(p) > signal_thresholds:
                        percentiles_99[c].append(p)
                        break

        # Resize image if required
        if target_height and target_width:
            image_shape = (target_height, target_width) if image.ndim == 2 else (target_height, target_width, image.shape[-1])
            image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)

        images.append(image)

    # Calculate average percentiles if needed
    if percentiles is None:
        avg_p1 = [np.mean(p) for p in percentiles_1]
        avg_p99 = [np.mean(p) if p else avg_p1[i] for i, p in enumerate(percentiles_99)]

        print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')

        normalized_images = [
            np.stack([rescale_intensity(img[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
                      for c in range(img.shape[-1])], axis=-1) for img in images
        ]

    else:
        normalized_images = [
            np.stack([rescale_intensity(img[..., c], 
                                        in_range=(np.percentile(img[..., c], percentiles[0]),
                                                  np.percentile(img[..., c], percentiles[1])), 
                                        out_range=(0, 1)) for c in range(img.shape[-1])], axis=-1) 
            for img in images
        ]

    # Load and resize labels if provided
    if label_files is not None:
        labels = [resizescikit(cellpose.io.imread(lbl_file), 
                               (target_height, target_width) if target_height and target_width else orig_dims[i], 
                               order=0, preserve_range=True, anti_aliasing=False).astype(np.uint8)
                  for i, lbl_file in enumerate(label_files)]

    print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')

    if visualize and images and labels:
        plot_resize(images, normalized_images, labels, labels)

    return normalized_images, labels, image_names, label_names, orig_dims



def train_cellpose(settings):
    """Train a Cellpose model on ``<src>/train/{images,masks}``, optionally validating on ``<src>/test``.

    Images and masks are loaded with ``_load_normalized_images_and_labels`` when
    ``settings['normalize']`` is set, otherwise with ``_load_images_and_labels``. They
    are then passed to ``cellpose.train.train_seg``. Settings are saved under the derived
    model name, and the model is trained under that same name.

    NOTE(review): this definition is shadowed by the ``train_cellpose`` defined in the next
    cell, and the loaders imported below from ``spacr.io`` shadow the ones in this cell.

    Args:
        settings: User settings, completed with spacr defaults via
            ``get_train_cellpose_default_settings``.

    Returns:
        None. Returns early, after printing why, when a requested test split or the
        training data is missing.
    """
    from spacr.io import _load_normalized_images_and_labels, _load_images_and_labels
    from spacr.settings import get_train_cellpose_default_settings
    from spacr.utils import save_settings

    settings = get_train_cellpose_default_settings(settings)

    img_src = os.path.join(settings['src'], 'train', 'images')
    mask_src = os.path.join(settings['src'], 'train', 'masks')
    test_img_src = os.path.join(settings['src'], 'test', 'images')
    test_mask_src = os.path.join(settings['src'], 'test', 'masks')

    def _list_tifs(folder):
        # Sorted so that images and masks pair up by position in every loading path.
        return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.tif'))

    # Resolved unconditionally: from-scratch model names embed it even when resize is
    # off (it used to be unbound in that case).
    target_dimensions = settings.get('width_dimensions', settings.get('target_dimensions'))

    if settings['test']:
        if os.path.exists(test_img_src) and os.path.exists(test_mask_src):
            print("Found test set")
        else:
            print(f"could not find test folders: {test_img_src} and {test_mask_src}")
            return

    test_images, test_masks, test_image_names, test_mask_names = None, None, None, None

    base_name = f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}"
    dims_tag = f"_X{target_dimensions}_Y{target_dimensions}"
    if settings['from_scratch']:
        model_name = f"scratch_{base_name}{dims_tag}.CP_model"
    elif settings['resize']:
        model_name = f"{base_name}{dims_tag}.CP_model"
    else:
        model_name = f"{base_name}.CP_model"

    model_save_path = os.path.join(settings['src'], 'models', 'cellpose_model')
    print(model_save_path)
    os.makedirs(model_save_path, exist_ok=True)
    save_settings(settings, name=f"{model_name}")

    if settings['from_scratch']:
        # cellpose loads built-in weights for any recognised model_type even when
        # pretrained_model is None, so model_type must be None for an untrained net.
        model = cp_models.CellposeModel(gpu=True, model_type=None, diam_mean=settings['diameter'], pretrained_model=None)
    else:
        model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'])

    if settings['normalize']:
        normalize_args = (settings['channels'],
                          settings['percentiles'],
                          settings['invert'],
                          settings['verbose'],
                          settings['remove_background'],
                          settings['background'],
                          settings['Signal_to_noise'],
                          settings['target_dimensions'],
                          settings['target_dimensions'])
        images, masks, image_names, mask_names, _ = _load_normalized_images_and_labels(
            _list_tifs(img_src), _list_tifs(mask_src), *normalize_args)
        if settings['test']:
            # The normalized loader returns five values (the last is orig_dims).
            test_images, test_masks, test_image_names, test_mask_names, _ = _load_normalized_images_and_labels(
                _list_tifs(test_img_src), _list_tifs(test_mask_src), *normalize_args)
    else:
        images, masks, image_names, mask_names = _load_images_and_labels(
            _list_tifs(img_src), _list_tifs(mask_src), settings['invert'])
        if settings['test']:
            test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(
                _list_tifs(test_img_src), _list_tifs(test_mask_src), settings['invert'])

    images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
    if test_images is not None:
        test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]

    if not images or not masks:
        print(f"No training images/masks loaded from {img_src} and {mask_src}")
        return

    # 'nuclei' is cellpose's name for the nuclear model; 'nucleus' is kept for compatibility.
    channel_map = {'cyto': [0, 1], 'cyto2': [0, 2], 'cyto3': [0, 2], 'nucleus': [0, 0], 'nuclei': [0, 0]}
    cp_channels = channel_map.get(settings['model_type'], [0, 0])
    if settings['grayscale']:
        cp_channels = [0, 0]
        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]

    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]

    print(f'image shape: {images[0].shape}, image type: {images[0].dtype} '
          f'mask shape: {masks[0].shape}, mask type: {masks[0].dtype}')

    save_every = int(settings['n_epochs'] / 10)
    if save_every < 10:
        save_every = settings['n_epochs']

    if settings['test'] and (not test_images or not test_masks):
        print("WARNING: Test data missing or empty. Proceeding without test set.")
        test_images, test_masks = None, None
        test_image_names, test_mask_names = None, None

    result = train_cp.train_seg(model.net,
                                train_data=images,
                                train_labels=masks,
                                train_files=image_names,
                                train_labels_files=mask_names,
                                train_probs=None,
                                test_data=test_images,
                                test_labels=test_masks,
                                test_files=test_image_names,
                                test_labels_files=test_mask_names,
                                test_probs=None,
                                load_files=True,
                                batch_size=settings['batch_size'],
                                learning_rate=settings['learning_rate'],
                                n_epochs=settings['n_epochs'],
                                weight_decay=settings['weight_decay'],
                                momentum=0.9,
                                SGD=False,
                                channels=cp_channels,
                                channel_axis=None,
                                normalize=False,
                                compute_flows=False,
                                save_path=model_save_path,
                                save_every=save_every,
                                nimg_per_epoch=None,
                                nimg_test_per_epoch=None,
                                rescale=settings['rescale'],
                                min_train_masks=1,
                                # Same name the settings were saved under, so the two match.
                                model_name=model_name)

    # NOTE(review): recent cellpose versions return the saved model path, alone or as the
    # first element of a tuple; fall back to the configured location otherwise.
    saved_path = result[0] if isinstance(result, tuple) else result
    if saved_path is None:
        saved_path = os.path.join(model_save_path, model_name)
    print(f"Model saved at: {saved_path}")

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from spacr.settings import get_train_cellpose_default_settings
from spacr.utils import save_settings, invert_image
from cellpose import models as cp_models
from cellpose import io as cp_io
from cellpose import train as train_cp
from skimage.transform import resize as sk_resize
from skimage.exposure import rescale_intensity

class CellposeLazyDataset(Dataset):
    """Map-style dataset that reads and preprocesses one image/mask pair per access.

    For each item the image is read with ``cp_io.imread`` and then processed in order:
    - optionally inverted;
    - divided by its maximum when that exceeds 1;
    - optionally percentile-rescaled to [0, 1];
    - resized to a ``width_dimensions`` x ``width_dimensions`` square.

    The mask (if any) is resized to the same size with nearest-neighbour
    interpolation (``order=0``). Trailing singleton channel axes are squeezed.

    Args:
        image_files: Image paths.
        label_files: Mask paths paired by position with ``image_files``; when empty or
            ``None`` every item's label is ``None``.
        settings: Dict with ``normalize``, ``invert``, ``width_dimensions`` and an
            optional ``percentiles`` pair (``[2, 99]`` when missing or ``None``).
    """

    def __init__(self, image_files, label_files, settings):
        self.image_files = image_files
        self.label_files = label_files
        self.settings = settings
        self.normalize = settings['normalize']
        self.invert = settings['invert']
        self.target_dimensions = settings['width_dimensions']
        # The notebook's settings set 'percentiles': None explicitly, which dict.get's
        # default never replaced, so np.percentile(image, None) raised when normalizing.
        percentiles = settings.get('percentiles')
        self.percentiles = [2, 99] if percentiles is None else percentiles

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the preprocessed ``(image, label)`` pair at ``idx``; label may be ``None``."""
        image = cp_io.imread(self.image_files[idx])
        label = cp_io.imread(self.label_files[idx]) if self.label_files else None

        if self.invert:
            image = invert_image(image)

        if image.max() > 1:
            image = image / image.max()

        if self.normalize:
            lower_p, upper_p = np.percentile(image, self.percentiles)
            image = rescale_intensity(image, in_range=(lower_p, upper_p), out_range=(0, 1))

        image_shape = (self.target_dimensions, self.target_dimensions)
        image = sk_resize(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
        if label is not None:
            label = self._resize_label(label, image_shape)

        image = np.squeeze(image) if image.ndim == 3 and image.shape[-1] == 1 else image
        if label is not None:
            label = np.squeeze(label) if label.ndim == 3 and label.shape[-1] == 1 else label

        return image, label

    @staticmethod
    def _resize_label(label, shape):
        """Nearest-neighbour resize of an instance mask that keeps every object id intact."""
        resized = sk_resize(label, shape, order=0, preserve_range=True, anti_aliasing=False)
        # A uint8 cast (as before) wrapped ids above 255, merging distinct objects.
        out_dtype = label.dtype if np.issubdtype(label.dtype, np.integer) else np.int32
        return resized.astype(out_dtype)

def train_cellpose(settings):
    """Train a Cellpose model on ``<src>/train/{images,masks}`` via ``CellposeLazyDataset``.

    Every image/mask pair is resized to a ``width_dimensions`` square. Up to four pairs
    are plotted as a visual sanity check, and ``train_cp.train_seg`` is run on the full
    set. The test split and ``settings['resize']`` are not used by this function.

    Args:
        settings: User settings, completed with spacr defaults via
            ``get_train_cellpose_default_settings``.

    Returns:
        None. The saved model location is printed.

    Raises:
        FileNotFoundError: No ``.tif`` images in the training image folder.
        ValueError: Image and mask counts differ (pairing is by sorted filename position).
    """
    settings = get_train_cellpose_default_settings(settings)

    img_src = os.path.join(settings['src'], 'train', 'images')
    mask_src = os.path.join(settings['src'], 'train', 'masks')

    # Validate the data before building the model or writing anything to disk.
    train_image_files = sorted([os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')])
    train_label_files = sorted([os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')])
    if not train_image_files:
        raise FileNotFoundError(f"No .tif images found in {img_src}")
    if len(train_image_files) != len(train_label_files):
        raise ValueError(f"Found {len(train_image_files)} images in {img_src} "
                         f"but {len(train_label_files)} masks in {mask_src}")

    target_dimensions = settings['width_dimensions']

    model_name = f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_dimensions}_Y{target_dimensions}"
    if settings['from_scratch']:
        model_name = f"scratch_{model_name}"
    model_name += ".CP_model"

    model_save_path = os.path.join(settings['src'], 'models', 'cellpose_model')
    os.makedirs(model_save_path, exist_ok=True)

    save_settings(settings, name=model_name)

    # cellpose loads built-in weights for any recognised model_type even when
    # pretrained_model is None, so both must be None to train from scratch.
    model_type = None if settings['from_scratch'] else settings['model_type']
    model = cp_models.CellposeModel(gpu=True,
                                    model_type=model_type,
                                    diam_mean=settings['diameter'],
                                    pretrained_model=model_type)

    # 'nuclei' is cellpose's name for the nuclear model; 'nucleus' is kept for compatibility.
    channel_map = {'cyto': [0, 1], 'cyto2': [0, 2], 'cyto3': [0, 2], 'nucleus': [0, 0], 'nuclei': [0, 0]}
    cp_channels = [0, 0] if settings['grayscale'] else channel_map.get(settings['model_type'], [0, 0])

    train_dataset = CellposeLazyDataset(train_image_files, train_label_files, settings=settings)
    samples = [train_dataset[i] for i in range(len(train_dataset))]
    images = [image for image, _ in samples]
    labels = [label for _, label in samples]
    print(f"Loaded {len(images)} training images from {img_src}")

    # Plot up to four image/label pairs; squeeze=False keeps axs 2-D even for one column.
    n_show = min(len(images), 4)
    fig, axs = plt.subplots(2, n_show, figsize=(12, 6), squeeze=False)
    for i in range(n_show):
        axs[0, i].imshow(images[i], cmap='gray')
        axs[0, i].set_title(f'Image {i+1}')
        axs[0, i].axis('off')
        axs[1, i].imshow(labels[i], cmap='gray')
        axs[1, i].set_title(f'Label {i+1}')
        axs[1, i].axis('off')
    plt.show()

    result = train_cp.train_seg(model.net,
                                train_data=images,
                                train_labels=labels,
                                channels=cp_channels,
                                save_path=model_save_path,
                                n_epochs=settings['n_epochs'],
                                batch_size=settings['batch_size'],
                                learning_rate=settings['learning_rate'],
                                weight_decay=settings['weight_decay'],
                                model_name=model_name,
                                save_every=max(1, (settings['n_epochs'] // 10)),
                                rescale=settings['rescale'])

    # NOTE(review): recent cellpose versions return the saved model path, alone or as the
    # first element of a tuple; fall back to the configured location otherwise.
    saved_path = result[0] if isinstance(result, tuple) else result
    if saved_path is None:
        saved_path = os.path.join(model_save_path, model_name)
    print(f"Model saved at: {saved_path}")

In [None]:
# train cellpose model
#from spacr.submodules import train_cellpose
%matplotlib inline
#'/nas_mnt/carruthers/training_data/plaque/cellpose_dataset'
settings = {'src':'/nas_mnt/carruthers/training_data/plaque/test',
            'test':False,
            'normalize':False,
            'percentiles':None,
            'invert':False,
            'grayscale':True,
            'rescale':False,
            'circular':False,
            'channels':[0,0],
            'model_name':'test',
            'model_type':'cyto',
            'Signal_to_noise':10,
            'background':200,
            'remove_background':False,
            'learning_rate':0.2,
            'weight_decay':1e-05,
            'batch_size':8,
            'n_epochs':100,
            'from_scratch':False,
            'diameter':30,
            'resize':False,
            'width_dimensions':1000,
            'verbose':True}

train_cellpose(settings)