Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ scikit-learn>=0.24.0
pandas>=1.3.0
opencv-python>=4.5.0
Pillow>=8.3.0
numba==0.61.0

# Deep learning
torch>=1.9.0
Expand All @@ -18,6 +19,7 @@ albumentations>=1.0.0
scikit-learn-extra>=0.2.0
connected-components-3d>=3.0.0
SimpleITK>=2.1.0
cellpose ==3.1.1.1

# Development and utilities
jupyterlab>=3.0.0
Expand All @@ -26,9 +28,9 @@ black>=21.6b0
isort>=5.9.0
flake8>=3.9.0
tqdm>=4.64.0
pyyaml==6.0.2
# Logging and monitoring
loguru>=0.7.0

# Optional GPU acceleration (install with: pip install cupy-cuda12x)
# cupy-cuda12x>=13.0.0 # Uncomment if you have CUDA 12.x GPU support
cellpose
7 changes: 3 additions & 4 deletions scripts/do_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,17 @@ def main():
else:
pool = multiprocessing.Pool(processes=args.workers)
"""
# Completed: IF slide_id is provided, use the slide_id and data_dir to load the slides
# TODO: Multiprocessing data loader
log.logger.debug("Loading slides...")
slides = data_loader.load_slides(args.data_dir)

# TODO: A non offset based composite creation should be implemented in the data loader
log.logger.debug("Creating composites...")
composite_images = data_loader.get_composites(slides, config.SLIDE_INDEX_OFFSET) # creaing composites should be preprocessing which is in segmentor
composite_images = cellposeSegmentor.preprocess(slides)

log.logger.debug("Running Segmentation...")
binary_masks = cellposeSegmentor.segment(composite_images)

image_crops, mask_crops, centers = cellposeSegmentor.postprocess()

log.logger.debug("Saving masks...")
cellposeSegmentor.save_masks(binary_masks)

Expand Down
22 changes: 12 additions & 10 deletions src/deep_learning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base class for all traditional segmentation algorithms.
"""
from abc import ABC, abstractmethod
from scipy import ndimage as ndi
import numpy as np

class BaseSegmenter(ABC):
"""Base class that all traditional segmentation algorithms should inherit from."""
Expand All @@ -17,7 +17,7 @@ def __init__(self, config=None):
self.config = config or {}

@abstractmethod
def segment(self, images):
def segment(self, images) -> np.ndarray:
"""
Segment the input images.

Expand All @@ -29,7 +29,8 @@ def segment(self, images):
"""
pass

def preprocess(self, images_dir): # get_composites shouuld be here
@abstractmethod
def preprocess(self, images) -> np.ndarray: # get_composites shouuld be here
"""
Preprocess the input image before segmentation.

Expand All @@ -41,15 +42,16 @@ def preprocess(self, images_dir): # get_composites shouuld be here
"""
pass

def postprocess(self, mask):
@abstractmethod
def postprocess(self, masks=None, images=None) -> list[np.ndarray]:
"""
Postprocess the segmentation mask.
Args:
mask (numpy.ndarray): Segmentation mask to postprocess.

Postprocess the segmentation mask. Extracts cropped cell images using the segmented masks.

Arguments:
masks (np.ndarray): Array of segmented masks with shape (N, C, H, W).
images (np.ndarray): Array of original images with shape (N, C, H, W).
Returns:
numpy.ndarray: Postprocessed mask.
List[np.ndarray]: List of cropped cell images.
"""
pass

183 changes: 161 additions & 22 deletions src/deep_learning/cellpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,187 @@
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt # use in debug console
from src.deep_learning.base import BaseSegmenter
import os
import numpy as np
import cv2
import multiprocessing
from .base import BaseSegmenter
from .utils.config import Config
from .utils.loader import load_img
from .utils.image import compute_composite
from .utils.crop import crop_single_image
from .utils.mask import binary_masks
import loguru as log

class CellposeSegmentor(BaseSegmenter):
    """
    Wrapper around the Cellpose deep learning model for image segmentation.

    The class keeps the intermediate data of one run on the instance so that
    each stage can be called without arguments after the previous one:

    Attributes:
        model (cellpose.models.CellposeModel): The loaded Cellpose model.
        config (Config): Configuration object containing paths and settings.
        image_data: Raw scans stored by ``load_data``.
        stacked_scans_data: Four-scan stacks built by ``preprocess``.
        composite_data: Composite frames built by ``preprocess``.
        masks: Masks returned by the last ``segment`` call.

    Methods:
        load_data(image_dir): Loads scans from a directory using multiprocessing.
        preprocess(images): Groups the scans four by four into composite frames.
        segment(images): Runs Cellpose on the composite frames.
        postprocess(masks, images): Extracts cell crops, mask crops and centers.
        save_masks(masks): Writes masks to ``config.mask_output_dir`` as PNG.
    """

    def __init__(self, config: Config):
        """
        Initialize the Cellpose segmentor.

        Args:
            config (Config): Paths and settings for the run.

        Raises:
            ImportError: If Cellpose reports that no GPU is usable.
            FileNotFoundError: If ``config.data_dir`` does not exist.
            ValueError: If ``config.pretrained_model`` is None.
        """
        self.config = config

        if not core.use_gpu():
            raise ImportError("No GPU access")

        if not self.config.data_dir.exists():
            raise FileNotFoundError(f"Data directory {self.config.data_dir} does not exist")

        if self.config.pretrained_model is None:
            raise ValueError("Pretrained model must be specified")

        self.model = models.CellposeModel(
            gpu=True,
            pretrained_model=str(self.config.pretrained_model),
            device=torch.device(self.config.device),
        )
        self._reset_buffers()

        log.logger.debug("Cellpose Segmentor initialized.")

    def _reset_buffers(self) -> None:
        """Point every per-run buffer at a small placeholder so large arrays can be freed."""
        self.image_data = np.empty(1)
        self.composite_data = np.empty(1)
        self.masks = np.empty(1)
        self.stacked_scans_data = []

    @staticmethod
    def _is_missing(data) -> bool:
        """Return True when an optional argument was omitted or is empty.

        ``not data`` cannot be used for this: it raises ValueError for any
        numpy array holding more than one element.
        """
        return data is None or len(data) == 0

    @staticmethod
    def _worker_count() -> int:
        """Return the pool size: all cores but two, and never fewer than one."""
        return max(1, multiprocessing.cpu_count() - 2)

    def segment(self, images=None) -> np.ndarray:
        """
        Segment the input images.

        Args:
            images: Composite frames to segment. When omitted or empty, the
                frames stored by the last ``preprocess`` call are used.

        Returns:
            The masks returned by ``model.eval``; they are also stored on
            ``self.masks``.
        """
        if self._is_missing(images):
            images = self.composite_data

        # Evaluate the frames that were actually selected above, not
        # unconditionally the cached composites.
        self.masks, _, _ = self.model.eval(images, diameter=15, channels=[0, 0])
        return self.masks

    def preprocess(self, images=None) -> np.ndarray:
        """
        Combine the loaded scans into composite frames for segmentation.

        The input is read as four consecutive groups of equal length (one
        group per scan type): frame ``i`` is built from entries ``i``,
        ``i + offset``, ``i + 2 * offset`` and ``i + 3 * offset``, where
        ``offset = len(images) // 4``.

        Args:
            images: Scans to combine. When omitted or empty, the scans stored
                by the last ``load_data`` call are used.

        Returns:
            numpy.ndarray: The composite frames; they are also stored on
            ``self.composite_data``. The matching four-scan stacks are stored
            on ``self.stacked_scans_data``.

        Raises:
            ValueError: If fewer than four scans are available.
        """
        if self._is_missing(images):
            images = self.image_data

        offset = len(images) // 4
        if offset == 0:
            raise ValueError("At least four scans are required to build a composite frame")

        # Collect into locals: the instance attribute is an ndarray after the
        # first run and could not be appended to on a second call.
        stacked_scans = []
        frames = []
        for i in range(offset):
            scans = [images[i + group * offset] for group in range(4)]
            stacked_scans.append(np.stack(scans, axis=-1))
            frames.append(compute_composite(*scans))

        # Keep every stack so index j here lines up with mask j in postprocess.
        self.stacked_scans_data = np.stack(stacked_scans, axis=0)
        self.composite_data = np.array(frames)
        return self.composite_data

    def postprocess(self, masks=None, images=None) -> list[np.ndarray]:
        """
        Extract cropped cell images using the segmented masks.

        Each (mask, image) pair is handed to ``crop_single_image`` in a
        worker pool. Afterwards the per-run buffers on the instance are reset
        to release memory.

        Arguments:
            masks: One mask per frame. Defaults to the masks stored by the
                last ``segment`` call.
            images: One four-scan stack per frame. Defaults to the stacks
                stored by the last ``preprocess`` call.

        Returns:
            List[np.ndarray]: Three arrays, in this order: the image crops
            with the last axis moved to position 1 (N, C, H, W), the result
            of ``binary_masks`` on the stacked mask crops, and the stacked
            centers.

        Raises:
            ValueError: If masks and images differ in length, or if no crops
                were produced.
        """
        if self._is_missing(masks):
            masks = self.masks
        if self._is_missing(images):
            images = self.stacked_scans_data

        if len(masks) != len(images):
            raise ValueError(
                f"Got {len(masks)} masks for {len(images)} images; they must match one to one"
            )

        args = list(zip(masks, images))

        with multiprocessing.Pool(processes=self._worker_count()) as pool:
            results = pool.map(crop_single_image, args)

        # Flatten the per-frame results into three flat lists.
        image_crops, mask_crops, centers = [], [], []
        for img_crops, msk_crops, ctrs in results:
            image_crops.extend(img_crops)
            mask_crops.extend(msk_crops)
            centers.extend(ctrs)

        if not image_crops:
            raise ValueError("No cell crops were produced from the provided masks")

        # Drop the references to the large arrays but keep the attributes, so
        # the instance stays usable for another run.
        self._reset_buffers()

        return [
            # (N, H, W, C) -> (N, C, H, W), the layout the extraction model expects.
            np.transpose(np.stack(image_crops, axis=0), (0, 3, 1, 2)),
            binary_masks(np.stack(mask_crops, axis=0)),
            np.stack(centers, axis=0),
        ]

    def load_data(self, image_dir) -> np.ndarray:
        """
        Load every file of a directory as an image.

        The file names are sorted so that the list index matches the scan
        order that ``preprocess`` relies on.

        Args:
            image_dir (Path): Folder with the slide data.

        Returns:
            numpy.ndarray: The loaded scans as uint16; they are also stored on
            ``self.image_data``.
        """
        # NOTE(review): config.image_extension is not applied here, so every
        # entry of the directory is passed to load_img - confirm this is intended.
        image_files = sorted(os.listdir(image_dir))
        args = [(image_dir, file_name) for file_name in image_files]

        with multiprocessing.Pool(self._worker_count()) as pool:
            frames = pool.map(load_img, args)

        self.image_data = np.array(frames, dtype=np.uint16)
        return self.image_data

    def save_masks(self, masks) -> None:
        """
        Write each mask to ``config.mask_output_dir`` as ``mask_<index>.png``.

        Boolean or 0/1 masks are scaled to 0/255. Label masks are written as
        8-bit when every label fits, otherwise as 16-bit so labels above 255
        do not wrap around.

        Args:
            masks: Iterable of numpy arrays, one per frame.
        """
        output_dir = self.config.mask_output_dir
        output_dir.mkdir(parents=True, exist_ok=True)

        for i, mask in enumerate(masks):
            mask_path = output_dir / f"mask_{i}.png"

            peak = mask.max()
            if mask.dtype == bool or peak <= 1:
                mask_to_save = (mask * 255).astype(np.uint8)
            elif peak <= np.iinfo(np.uint8).max:
                mask_to_save = mask.astype(np.uint8)
            else:
                if peak > np.iinfo(np.uint16).max:
                    log.logger.warning(
                        f"Mask {i} has labels above 65535; they will wrap in the saved PNG"
                    )
                mask_to_save = mask.astype(np.uint16)

            success = cv2.imwrite(str(mask_path), mask_to_save)
            if not success:
                log.logger.warning(f"Failed to save mask {i} to {mask_path}")
12 changes: 12 additions & 0 deletions src/deep_learning/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from pathlib import Path

@dataclass
class Config:
    """Settings for one Cellpose segmentation run.

    The path fields accept either ``pathlib.Path`` objects or plain strings;
    strings are converted to ``Path`` after construction so that callers can
    pass values read from a file or the command line directly.
    """

    # Model handed to CellposeModel as `pretrained_model`; may be None, in
    # which case the segmentor refuses to start.
    pretrained_model: Path
    # Device string handed to torch.device.
    device: str
    # Directory that must exist when the segmentor is created.
    data_dir: Path
    # NOTE(review): not read by the segmentor code visible here - confirm usage.
    image_extension: str
    # Directory the segmentor writes mask_<index>.png files into.
    mask_output_dir: Path
    # NOTE(review): not read by the segmentor code visible here - confirm usage.
    offset: int = 10

    def __post_init__(self) -> None:
        """Convert string paths to Path so `.exists()` and `/` work on them."""
        if self.pretrained_model is not None:
            self.pretrained_model = Path(self.pretrained_model)
        self.data_dir = Path(self.data_dir)
        self.mask_output_dir = Path(self.mask_output_dir)

Loading