# **Image segmentation – Basics**

<div style="color:#777777;margin-top: -15px;">
<b>Author</b>: Norman Juchler |
<b>Course</b>: MSLS CO4 |
<b>Version</b>: v1.2 <br><br>
<!-- Date: 16.04.2025 -->
<!-- Comments: Text refactored -->
</div>

In this notebook on segmentation, we will explore different approaches to segment hematological images. As a first step, we will attempt to segment the cells using simple thresholding techniques.

Several of the concepts discussed here are also covered in this insightful tutorial for the ImageJ/Fiji plugin [MorphoLibJ](https://imagej.net/plugins/morpholibj), which you may find helpful for further reference.


---

## **Preparations**

Let's begin with the usual preparatory steps...

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import PIL
from pathlib import Path

# Jupyter / IPython configuration:
# Automatically reload modules when modified
%load_ext autoreload
%autoreload 2

# Enable vectorized output (for nicer plots)
%config InlineBackend.figure_formats = ["svg"]

# Inline backend configuration
%matplotlib inline

# Enable this line if you want to use the interactive widgets
# It requires the ipympl package to be installed.
#%matplotlib widget

import sys
sys.path.insert(0, "../")
import tools

We will use the same images that were used in the previous notebook on preprocessing:

In [None]:
# Read in the data
img1 = cv.imread("../data/images/hematology-baso1.jpg", cv.IMREAD_COLOR)
img2 = cv.imread("../data/images/hematology-baso2.jpg", cv.IMREAD_COLOR)
img3 = cv.imread("../data/images/hematology-blast1.jpg", cv.IMREAD_COLOR)

img1 = cv.cvtColor(img1, cv.COLOR_BGR2RGB)
img2 = cv.cvtColor(img2, cv.COLOR_BGR2RGB)
img3 = cv.cvtColor(img3, cv.COLOR_BGR2RGB)

tools.show_image_chain([img1, img2, img3], titles=["img1", "img2", "img3"])

---

## **Method 1: Thresholding**

We can segment images using basic thresholding techniques. In this example, we explore several thresholding methods available in OpenCV:

- **Simple thresholding**: Use [`cv.threshold()`](https://docs.opencv.org/4.x/d7/d1b/group__imgproc__misc.html#gae8a4a146d1ca78c626a53577199e9c57)  
  (with flags `cv.THRESH_BINARY` or `cv.THRESH_BINARY_INV`)
- **Adaptive thresholding**: Use [`cv.adaptiveThreshold()`](https://docs.opencv.org/4.x/d7/d1b/group__imgproc__misc.html#ga72b913f352e4a1b1b397736707afcde3)
- **Otsu's thresholding**: Use [`cv.threshold()`](https://docs.opencv.org/4.x/d7/d1b/group__imgproc__misc.html#gae8a4a146d1ca78c626a53577199e9c57)  
  (with flags `cv.THRESH_BINARY + cv.THRESH_OTSU`)

Thresholding segments pixels into foreground and background based on their intensity values, making it a form of *binary segmentation*. The algorithm compares each pixel's intensity to a threshold value: pixels with values larger than the threshold are classified as *foreground*, and pixels with values less than or equal to the threshold as *background* (for `cv.THRESH_BINARY`; `cv.THRESH_BINARY_INV` swaps the two).

The threshold can be manually defined or automatically determined (e.g., by Otsu’s method).
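To make the idea of an automatically determined threshold more concrete, here is a minimal NumPy sketch of Otsu's method. It tests every candidate threshold t and keeps the one that maximizes the between-class variance of the two resulting pixel groups. The helper name `otsu_threshold` is hypothetical, and the sketch assumes an 8-bit grayscale image; `cv.threshold(..., cv.THRESH_BINARY + cv.THRESH_OTSU)` should yield a similar threshold but is not guaranteed to match it exactly.

```python
import numpy as np

def otsu_threshold(gray):
    """Sketch of Otsu's method for an 8-bit grayscale image (hypothetical helper).

    Returns the threshold t that maximizes the between-class variance.
    Pixels with gray > t are treated as foreground, as with cv.THRESH_BINARY.
    """
    hist = np.bincount(gray.ravel(), minlength=256).astype(np.float64)
    prob = hist / hist.sum()                  # normalized histogram
    levels = np.arange(256)
    omega = np.cumsum(prob)                   # weight of the class {gray <= t}
    mu = np.cumsum(prob * levels)             # cumulative first moment
    mu_total = mu[-1]                         # mean intensity of the whole image

    # Between-class variance for every candidate threshold t:
    # sigma_b^2(t) = (mu_total * omega(t) - mu(t))^2 / (omega(t) * (1 - omega(t)))
    denom = omega * (1.0 - omega)
    sigma_b2 = np.zeros(256)
    valid = denom > 0                         # skip thresholds that leave a class empty
    sigma_b2[valid] = (mu_total * omega[valid] - mu[valid]) ** 2 / denom[valid]
    return int(np.argmax(sigma_b2))

# Usage: a synthetic bimodal image (dark background 50, bright square 200)
demo = np.full((20, 20), 50, dtype=np.uint8)
demo[5:15, 5:15] = 200
t = otsu_threshold(demo)
demo_mask = demo > t                          # foreground = bright pixels
print(t, demo_mask.sum())
```

Any threshold between the two gray levels separates this toy image perfectly; on real images, the histogram peaks overlap and the choice of t matters much more.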

As preparation, please review the following OpenCV documentation on thresholding methods:  
[https://docs.opencv.org/4.x/d7/d4d/tutorial_py_thresholding.html](https://docs.opencv.org/4.x/d7/d4d/tutorial_py_thresholding.html)
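In contrast to a single global threshold, adaptive thresholding computes a separate threshold for every pixel from its neighborhood. The following NumPy sketch mimics the idea behind `cv.adaptiveThreshold()` with `cv.ADAPTIVE_THRESH_MEAN_C`: a pixel is foreground if it exceeds the mean of its `block_size` x `block_size` neighborhood minus a constant `C`. The helper name and the reflective border padding are assumptions; OpenCV's border handling and rounding may differ.

```python
import numpy as np

def adaptive_mean_threshold(gray, block_size=15, C=-10):
    """Local mean thresholding (hypothetical helper, mimics ADAPTIVE_THRESH_MEAN_C).

    A pixel is foreground if gray > (mean of its block_size x block_size
    neighborhood) - C. A negative C therefore requires pixels to be clearly
    brighter than their surroundings.
    """
    if block_size % 2 != 1:
        raise ValueError("block_size must be odd")
    r = block_size // 2
    g = gray.astype(np.float64)
    padded = np.pad(g, r, mode="reflect")     # border pixels get a full window

    # Summed-area table (integral image) with a leading row/column of zeros
    sat = np.zeros((padded.shape[0] + 1, padded.shape[1] + 1))
    sat[1:, 1:] = padded.cumsum(axis=0).cumsum(axis=1)

    # Sum over each k x k window via four lookups in the summed-area table
    h, w = g.shape
    k = block_size
    window_sum = (sat[k:k + h, k:k + w] - sat[:h, k:k + w]
                  - sat[k:k + h, :w] + sat[:h, :w])
    local_mean = window_sum / (k * k)
    return g > local_mean - C

# Usage: a small bright object on a background that gets brighter to the right
ramp = np.tile(np.linspace(20, 100, 60), (60, 1))
demo = np.round(ramp).astype(np.uint8)
demo[25:30, 10:15] += 40                      # object in the darker part of the image
demo_mask = adaptive_mean_threshold(demo, block_size=15, C=-10)
print(demo_mask.sum())
```

In this toy example, no single global threshold could isolate the object, because the background on the right side is brighter than the object itself. The local threshold follows the illumination gradient and picks out only the object.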


In [None]:
######################
###    EXERCISE    ###
######################

# Choose here the image to work with
img = img1

# 1) Summarize the three different thresholding techniques in your own words. Which
#    methods use a global threshold, and which ones apply a local threshold?

# 2) Develop a strategy to segment the white blood cells (purple), the red blood cells
#    (red) and the background (white/gray). You may want to exploit the fact that we
#    have colors to work with:
tools.show_image_chain([img[:,:,0], img[:,:,1], img[:,:,2]], titles=["R", "G", "B"])

# 3) Identify the different regions using thresholding
mask_wbc = ...
mask_rbc = ...
mask_bg = ...

# 4) Visualize the masks. Idea: combine the three masks into an RGB image
mask_seg = ...

# 5) Discuss your results. What could be improved? What are the limitations of this 
#    approach? Are the masks mutually exclusive? Are they accurate?



In [None]:
######################
###    SOLUTION    ###
######################

# 1) cv.threshold:              Apply a global threshold to the image.
#    cv.adaptiveThreshold:      Apply a local threshold to the image.
#    cv.threshold (otsu):       Apply a global threshold to the image using Otsu's method.

# 2) Segmentation strategy:
#    The information in the three different channels suggests that we can use the
#    red channel to segment the red blood cells, the blue channel to segment the white
#    blood cells and the luminance channel (or gray channel) to segment the background. 
#    We can then combine the three masks to obtain the final segmentation.
# 
#    Display the 3 channels. See note below as to why we disable normalization.
img = img1
tools.show_image_chain([img[:,:,0], img[:,:,1], img[:,:,2]], 
                       titles=["R", "G", "B"], normalize=False)

# 3) Let's try how this works in practice
def segment_blood_cells_thr(img, return_masks=False):

    # Smooth the image (to reduce noise)
    #img = cv.GaussianBlur(img, (5, 5), 0)

    # -> Convert image to grayscale
    gray = cv.cvtColor(img, cv.COLOR_RGB2GRAY)

    # -> Extract background using Otsu's method
    thr_bg, mask_bg = cv.threshold(gray, 0, 255, cv.THRESH_BINARY+cv.THRESH_OTSU)

    # -> Extract red and white blood cells
    thr_rbc, mask_rbc = cv.threshold(img[:,:,0], 130, 255, cv.THRESH_BINARY)
    thr_wbc, mask_wbc = cv.threshold(img[:,:,2], 140, 255, cv.THRESH_BINARY)

    # -> Apply the segmentation logic: 
    #     - First observe that the background takes high values in the red and blue
    #       channels. Thresholding the red and blue channels will also include the 
    #       background. Furthermore, the blue component sometimes is also present 
    #       in the red blood cells. Therefore:
    #     - Exclude the background from the masks for the red and white blood cells
    #     - Exclude the red blood cells from the white blood cells mask. Here,
    #       we use the condition that something appears purple if the blue
    #       channel is significantly higher than the red channel.

    mask_rbc = mask_rbc.astype(bool) & ~mask_bg.astype(bool)
    mask_wbc = mask_wbc.astype(bool) & ~mask_bg.astype(bool)
    mask_wbc = mask_wbc & (img[:,:,0]*1.1 < img[:,:,2])

    # Combine the information into a color image.
    result = np.ones_like(img) * 255
    result[mask_rbc.astype(bool)] = [155, 107, 132]
    result[mask_wbc.astype(bool)] = [62, 32, 152]

    if return_masks:
        return result, mask_bg, mask_rbc, mask_wbc
    else:
        return result


# Compute the segmentation and visualize the results
img = img1
ret = segment_blood_cells_thr(img, return_masks=True)
result1, mask_bg, mask_rbc, mask_wbc = ret

# Visualize the masks
tools.show_image_chain([mask_bg, mask_rbc, mask_wbc], 
                       titles=["Background", "RBC", "WBC"])
# Visualize the results
tools.show_image_chain([img, result1], titles=["Input: img1", "Output: Segmentation"]);


The segmentation looks fairly good, but it is not perfect. The main limitations are:  

- The thresholding is very sensitive to the threshold values. Small changes in the threshold can lead to very different results.  
- The assumptions and segmentation logic are tailored to the specific images and may not generalize well.  
- The masks contain holes and do not segment the cells precisely – this is a limitation of the thresholding approach used.  
- Boundary effects are visible (e.g., in red blood cells), indicating that the method is not fully reliable.  
- Nearby or overlapping cells may not be distinguishable in the segmentation. This issue is clearly visible in the result for `img3` (see below).

We can refine the results by...
- ...tuning the thresholding parameters  
- ...improving the segmentation logic  
- ...smoothing the image (see the commented-out line of code above)  
- ...using *morphological operations* to close holes and remove noise (see next tutorial)
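One way to make the masks mutually exclusive (and to check how far they are from it) is to merge them into a single label image with an explicit precedence rule, counting the pixels claimed by several masks or by none. The helper `combine_masks` below is a hypothetical sketch, and the precedence WBC > RBC > background is an assumption you may want to change.

```python
import numpy as np

def combine_masks(mask_bg, mask_rbc, mask_wbc):
    """Merge three masks into one label image (hypothetical helper).

    Labels: 0 = unassigned, 1 = background, 2 = RBC, 3 = WBC.
    Later assignments overwrite earlier ones, so WBC > RBC > background.
    Accepts boolean masks as well as 0/255 masks from cv.threshold().
    """
    bg, rbc, wbc = (np.asarray(m).astype(bool) for m in (mask_bg, mask_rbc, mask_wbc))
    labels = np.zeros(bg.shape, dtype=np.uint8)
    labels[bg] = 1
    labels[rbc] = 2
    labels[wbc] = 3
    n_claims = bg.astype(np.int32) + rbc + wbc      # number of masks claiming each pixel
    n_overlap = int((n_claims > 1).sum())
    n_unassigned = int((n_claims == 0).sum())
    return labels, n_overlap, n_unassigned

# Usage: a tiny 1 x 4 example with 0/255 masks
bg = np.array([[255, 255, 0, 0]], dtype=np.uint8)
rbc = np.array([[0, 255, 255, 0]], dtype=np.uint8)
wbc = np.array([[0, 0, 255, 0]], dtype=np.uint8)
demo_labels, n_overlap, n_unassigned = combine_masks(bg, rbc, wbc)
print(demo_labels, n_overlap, n_unassigned)
```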

Note: When displaying images, it is helpful to disable *normalization*, which automatically stretches pixel values to the full range [0, 255]. This normalization is the default behavior of [`plt.imshow()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html), which is used internally by our helper function `tools.show_image_chain()`. To accurately inspect grayscale values (e.g., using a color picker), it is better to view the image with its original intensity values.
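For single-channel images, `plt.imshow()` by default maps the smallest pixel value to the darkest and the largest value to the brightest color of the colormap. The sketch below reproduces this linear min-max stretch in NumPy to show why displayed intensities can differ from the stored ones; the helper name is hypothetical. With matplotlib itself, passing a fixed range such as `plt.imshow(gray, cmap="gray", vmin=0, vmax=255)` avoids the stretch.

```python
import numpy as np

def minmax_stretch(gray):
    """Linearly map [gray.min(), gray.max()] to [0, 255] (hypothetical helper)."""
    g = gray.astype(np.float64)
    lo, hi = g.min(), g.max()
    if hi == lo:                      # constant image: nothing to stretch
        return np.zeros_like(g)
    return (g - lo) / (hi - lo) * 255.0

demo = np.array([[100, 150, 200]], dtype=np.uint8)
print(minmax_stretch(demo))           # the stored value 150 ends up at 127.5
```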

A **color picker** is a tool that shows the color value under the mouse pointer. 
- *macOS*: [Digital Color Meter](https://support.apple.com/guide/digital-color-meter/welcome/mac) (pre-installed under /System/Applications/Utilities/)  
- *Windows:* *Color Picker* as part of the [PowerToys](https://learn.microsoft.com/en-us/windows/powertoys/)  
- *Ubuntu:* [Gpick](https://www.gpick.org/)  

These tools support multiple color formats and let you copy color values to the clipboard, etc.


In [None]:
# Display the segmentation results also for the other images:
result2 = segment_blood_cells_thr(img2)
tools.show_image_chain([img2, result2], titles=["Input: img2", "Output: Segmentation"]);
result3 = segment_blood_cells_thr(img3)
tools.show_image_chain([img3, result3], titles=["Input: img3", "Output: Segmentation"]);

---


## **Method 2: Color clustering**

Instead of segmenting the image into foreground and background, we can attempt to classify different regions based on color similarity. A common approach for this is the [K-means clustering](https://en.wikipedia.org/wiki/K-means_clustering) algorithm. K-means assigns each pixel to one of a predefined number K of clusters based on its color value. Similarity is typically measured using the Euclidean distance between color vectors (variants of the algorithm use other distance measures). After choosing K initial cluster centers (e.g., randomly selected pixels), the algorithm operates iteratively:

1. Assign each pixel to the nearest cluster center.
2. Update each cluster center as the mean of the pixels assigned to it.
3. Repeat until the cluster centers converge (or a maximum number of iterations is reached).
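The three steps can be written down in a few lines of NumPy. The following sketch of Lloyd's algorithm is for illustration only: it is not OpenCV's implementation of `cv.kmeans()`, it uses randomly chosen pixels as initial centers (OpenCV additionally offers k-means++ initialization and multiple attempts), and the function name `kmeans_colors` is hypothetical.

```python
import numpy as np

def kmeans_colors(data, K, n_iter=50, seed=0):
    """Lloyd's algorithm on an M x 3 array of colors (sketch, not cv.kmeans)."""
    data = np.asarray(data, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Initialization: K randomly chosen pixels serve as starting centers
    centers = data[rng.choice(len(data), size=K, replace=False)]

    def assign(centers):
        # Squared Euclidean distance of every pixel to every center (M x K)
        d2 = ((data[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return d2.argmin(axis=1)

    for _ in range(n_iter):
        labels = assign(centers)                        # step 1: assignment
        new_centers = np.array([                        # step 2: update
            data[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(K)
        ])
        if np.allclose(new_centers, centers):           # step 3: converged?
            break
        centers = new_centers
    return assign(centers), centers

# Usage: 50 "reddish" and 50 "bluish" pixels with a little noise
rng = np.random.default_rng(1)
reddish = np.array([200, 60, 80]) + rng.normal(0, 5, size=(50, 3))
bluish = np.array([70, 50, 180]) + rng.normal(0, 5, size=(50, 3))
demo_labels, demo_centers = kmeans_colors(np.vstack([reddish, bluish]), K=2)
print(np.round(demo_centers))
```

Empty clusters simply keep their previous center here; real implementations handle this case in different ways.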

Here is a helpful [visualization](https://www.naftaliharris.com/blog/visualizing-k-means-clustering/) of how K-means clustering works.


**Preparation:** Before you begin, check out these two tutorials:
- Jason Brownlee (Machine Learning Mastery) on [color quantization with K-means](https://machinelearningmastery.com/k-means-clustering-in-opencv-and-application-for-color-quantization/)
- Shubhang Agrawal on [image segmentation using K-means clustering](https://medium.com/swlh/image-segmentation-using-k-means-clustering-46a60488ae71). (Note that this tutorial contains a few flaws.)


<!-- 
Resources:
# Nice way of depicting the bars
https://pyimagesearch.com/2014/05/26/opencv-python-k-means-color-clustering/
# OpenCV
https://docs.opencv.org/3.4/d1/d5c/tutorial_py_kmeans_opencv.html
# Machine Learning Mastery
https://machinelearningmastery.com/k-means-clustering-in-opencv-and-application-for-color-quantization/
# Watershed
https://docs.opencv.org/4.x/d3/db4/tutorial_py_watershed.html
# Segmentation with Skimage 
https://github.com/ipython-books/cookbook-2nd-code/blob/master/chapter11_image/03_segmentation.ipynb
# Combination between thresholding and color clustering
https://towardsdatascience.com/image-color-segmentation-by-k-means-clustering-algorithm-5792e563f26e
-->

In [None]:
######################
###    EXERCISE    ###
######################

# Choose here the image to work with
img = img1

# 1) Reshape the color pixels into a Mx3 matrix (M: number of pixels)
#    and convert the data type to float32.
data = img.reshape(-1, 3).astype(np.float32)

# 2) Apply the K-means algorithm to the data. Use the cv.kmeans function.
#    Choose the number of clusters K=3.
criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 1.0)
K = 3
ret, label, centers = cv.kmeans(data, K, None, criteria, 10, cv.KMEANS_RANDOM_CENTERS)
# label contains the cluster index for each pixel
# centers contains the cluster centers (colors!)

# 3) Reshape and convert the data back to uint8
img_seg = ...

# 4) Visualize the segmented image
tools.show_image_pair(img, img_seg, title1="Original", title2="Segmented");

# 5) Repeat the process for a different color space (e.g. HSV)
#    Is the clustering more robust? Why? When does this approach fail?


In [None]:
######################
###    SOLUTION    ###
######################
def segment_blood_cells_kmeans(img, K=3, use_lab=False):
    # Blur the image to reduce noise (step is required here 
    # to yield feasible results)
    img = cv.GaussianBlur(img, (5, 5), 0)

    if use_lab:
        img = cv.cvtColor(img, cv.COLOR_RGB2LAB)

    # 1) Reshape the color pixels into a Mx3 matrix (M: number of pixels)
    #    and convert the data type to float32.
    data = img.reshape(-1, 3).astype(np.float32)
    
    # 2) Apply the K-means algorithm to the data. Use the cv.kmeans function.
    # Some parameters for the kmeans algorithm (termination criteria):
    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 0.1)
    ret, label, centers = cv.kmeans(data, 
                                    K=K, 
                                    bestLabels=None, 
                                    criteria=criteria, 
                                    attempts=10, 
                                    flags=cv.KMEANS_PP_CENTERS)
    # label contains the cluster index for each pixel
    # centers contains the cluster centers (colors!)

    # 3) Reshape and convert the data back to uint8
    img_seg = centers[label.flatten()].reshape(img.shape).astype(np.uint8)

    if use_lab:
        img_seg = cv.cvtColor(img_seg, cv.COLOR_LAB2RGB)

    # 4) Return the segmented image
    return img_seg


img = img1
img_seg = segment_blood_cells_kmeans(img, K=3, use_lab=False)
tools.show_image_chain([img, img_seg], titles=["Original", "Segmented"]);


The result here looks somewhat better than in the previous example; however, the segmentation is still not perfect:

- The effectiveness of clustering depends on the prevalence of colors in the image. If the background contains many subtle color variations, K-means may allocate more clusters to the background rather than focusing on the blood cells.
- The K-means algorithm is sensitive to the initialization of cluster centers. Poor initialization can lead to convergence to a local minimum, resulting in suboptimal clustering.
- The method still struggles with touching or overlapping cells. While the cells may be assigned the correct color, they are not individually separated.

We may improve the results by switching to a different color space. For instance, the LAB color space separates lightness (L) from the two color components (a, b), which can make the clustering less sensitive to brightness variations, and it was designed to approximate human color perception. The function above supports this via `use_lab=True`. Another strategy is to let the algorithm identify more clusters (larger K) and then apply post-processing to merge similar clusters. For example, if several cluster centers have a dominant blue component, they could be combined into a single class.
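The merging idea could look as follows: each cluster center is mapped to a class with a simple color rule, and the per-pixel cluster labels are then replaced by these class labels. The rule (very bright means background, blue > red means WBC, otherwise RBC) and the brightness threshold of 200 are assumptions for illustration that would have to be tuned, and `merge_clusters` is a hypothetical helper. The centers are assumed to be RGB colors; with `use_lab=True`, they would have to be converted back first.

```python
import numpy as np

def merge_clusters(labels, centers, bright_thr=200):
    """Map K color clusters to 3 classes (hypothetical rule).

    Classes: 0 = background, 1 = RBC, 2 = WBC.
    labels:  cluster index per pixel (any shape); centers: K x 3 RGB colors.
    """
    centers = np.asarray(centers, dtype=np.float64)
    r, b = centers[:, 0], centers[:, 2]
    cluster_class = np.where(b > r, 2, 1)                  # blue-dominant -> WBC
    cluster_class[centers.mean(axis=1) > bright_thr] = 0   # very bright -> background
    return cluster_class[np.asarray(labels)]

# Usage: K = 4 clusters, two of which are reddish
demo_centers = [[230, 230, 235], [180, 90, 110], [170, 100, 120], [90, 60, 160]]
demo_labels = np.array([0, 1, 2, 3, 3])
print(merge_clusters(demo_labels, demo_centers))
```

For an image, the result can be reshaped with `.reshape(img.shape[:2])`, since `cv.kmeans()` returns one label per pixel.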


---

## **Method 3: Watershed algorithm for segmentation**

The watershed algorithm is a powerful tool for image segmentation, particularly useful for separating complex or overlapping structures. It is based on the concept of watershed lines, which define boundaries between different regions in an image. The algorithm works by conceptually "flooding" the image from predefined seed points. As water spreads from each seed region, it continues flowing until it meets water from a neighboring region.  
The boundaries where the regions meet are defined as the watershed lines, effectively separating the regions. Watershed segmentation can be applied based on pixel intensity or color and is especially effective for images with intricate shapes and touching objects.

Our dataset has structural similarities to the image used in this [OpenCV watershed tutorial](https://docs.opencv.org/4.x/d3/db4/tutorial_py_watershed.html). We will now follow a simplified version of the tutorial to segment our images using the watershed method. The tutorial uses the following strategy:

### **Overview / Steps**

1. Convert the image to a binary mask using thresholding.
2. Apply morphological operations to reduce noise and help separate objects. These operations are also used to identify regions that likely represent the background.
3. Identify seed points for the watershed algorithm:
   - Apply the distance transform to the binary mask. This computes, for each pixel, the distance to the nearest background pixel (value 0).
   - Threshold the distance-transformed image to isolate blobs near the centers of objects of interest.
   - Use the connected components algorithm to label and enumerate the seed regions, creating a labeled mask.
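   - Shift all labels by one so that the sure background (identified in step 2) gets label `1`, and mark the remaining unknown region with label `0`. The watershed algorithm then decides to which region these unknown pixels belong.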
4. Apply the watershed algorithm to segment the regions.
5. Visualize the resulting segmentation.
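To see what step 4 does internally, here is a minimal marker-based watershed written with a priority queue. It is a simplified flooding scheme, not OpenCV's implementation: pixels are flooded in order of increasing "elevation", and each newly reached pixel inherits the label of the region that reaches it first. Unlike `cv.watershed()`, this sketch does not mark watershed lines with `-1`. Label `0` means "unknown", and every seed (including the background) needs a positive label. The function name and the use of 4-connectivity are assumptions. In this notebook, the elevation could be the negated distance transform, so that cell centers become the deepest basins.

```python
import heapq

import numpy as np

def watershed_flood(elevation, markers):
    """Marker-based watershed by priority flooding (simplified sketch).

    elevation: 2D float array; low values are flooded first.
    markers:   2D int array; 0 = unknown, positive values = seed labels.
    Returns a label image in which every pixel reachable from a seed is labeled.
    """
    labels = np.asarray(markers, dtype=np.int32).copy()
    h, w = labels.shape
    heap = []
    counter = 0          # tie-breaker: equal elevations are processed first-in, first-out
    for y, x in zip(*np.nonzero(labels)):
        heapq.heappush(heap, (float(elevation[y, x]), counter, int(y), int(x)))
        counter += 1

    while heap:
        _, _, y, x = heapq.heappop(heap)                  # lowest pixel on the flood front
        for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):  # 4-connectivity
            ny, nx = y + dy, x + dx
            if 0 <= ny < h and 0 <= nx < w and labels[ny, nx] == 0:
                labels[ny, nx] = labels[y, x]             # inherit the flooding region's label
                heapq.heappush(heap, (float(elevation[ny, nx]), counter, ny, nx))
                counter += 1
    return labels

# Usage: a 1D "landscape" with two basins separated by a ridge
elevation = np.array([[0, 1, 2, 3, 2, 1, 0]], dtype=float)
markers = np.array([[1, 0, 0, 0, 0, 0, 2]])
print(watershed_flood(elevation, markers))
```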



### **Note: Morphological operations**

Morphological operations are a set of operations used to analyze and manipulate the shape of objects in an image. Although they are defined for various image types, they are most commonly applied to binary images. These operations involve a structuring element (kernel) that probes the image and modifies pixel values based on the interaction between the kernel and the image. The most common operations include dilation (expands shapes), erosion (shrinks shapes), opening (erosion followed by dilation), and closing (dilation followed by erosion). Morphological operations are useful for removing noise, separating objects, and connecting disjoint regions in an image. There is a separate notebook on morphological operations.
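The effect of the four basic operations can be reproduced with a few lines of NumPy for a square structuring element. The sketch below is an illustration, not OpenCV's implementation, and the function names are hypothetical. Pixels outside the image are padded such that they never change the result, which is similar in spirit to the default border handling of `cv.erode()` and `cv.dilate()`.

```python
import numpy as np

def _shifted_windows(mask, k, pad_value):
    """Return the k*k shifted copies of a boolean mask (square kernel)."""
    r = k // 2
    padded = np.pad(mask.astype(bool), r, mode="constant", constant_values=pad_value)
    h, w = mask.shape
    return [padded[dy:dy + h, dx:dx + w] for dy in range(k) for dx in range(k)]

def dilate(mask, k=3):
    # A pixel becomes foreground if ANY pixel under the kernel is foreground
    return np.logical_or.reduce(_shifted_windows(mask, k, False))

def erode(mask, k=3):
    # A pixel stays foreground only if ALL pixels under the kernel are foreground
    return np.logical_and.reduce(_shifted_windows(mask, k, True))

def opening(mask, k=3):
    return dilate(erode(mask, k), k)   # removes small specks

def closing(mask, k=3):
    return erode(dilate(mask, k), k)   # fills small holes

# Usage: a 5 x 5 square, once with an isolated speck and once with a hole
square = np.zeros((10, 10), dtype=bool)
square[2:7, 2:7] = True
noisy = square.copy()
noisy[8, 8] = True
holey = square.copy()
holey[4, 4] = False
print((opening(noisy) == square).all(), (closing(holey) == square).all())
```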

**Further reading**:  
- OpenCV documentation: [Link](https://docs.opencv.org/4.x/d9/d61/tutorial_py_morphological_ops.html)  
- Beautiful illustration of morphological operations: [Link](https://penny-xu.github.io/blog/mathematical-morphology)  
- Wikipedia article on mathematical morphology: [Link](https://en.wikipedia.org/wiki/Mathematical_morphology)  
- Blog post on morphological operations: [Link](https://towardsdatascience.com/7bcf1ed11756)

### **Note: Distance transform**
The distance transform is useful for various image processing tasks. It computes the distance from each pixel to the nearest boundary (i.e., the closest background pixel) in a binary image. Distance transforms are used for operations like skeletonization, shape analysis, and segmentation. The algorithm propagates distance values from boundary pixels inward, typically using a metric such as the Euclidean distance. It is computationally efficient and widely available in image processing libraries.
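The definition can be illustrated with a brute-force sketch that compares every foreground pixel with every background pixel. This is only suitable for small masks and is not how libraries compute it; `cv.distanceTransform()` with `cv.DIST_L2` and a 5x5 mask (as in the solution below) uses a faster approximation, so its values may differ slightly. The helper name is hypothetical. The last lines show how thresholding the result keeps only the "core" of an object, which is the idea behind the seed points in step 3.

```python
import numpy as np

def distance_transform_bruteforce(mask):
    """Euclidean distance of every foreground pixel to the nearest background pixel.

    Brute force (compares all foreground/background pairs), so only suitable
    for small masks. Background pixels get distance 0.
    """
    fg = np.asarray(mask).astype(bool)
    bg_coords = np.argwhere(~fg)
    if len(bg_coords) == 0:
        raise ValueError("mask must contain at least one background pixel")
    fg_coords = np.argwhere(fg)
    dist = np.zeros(fg.shape, dtype=np.float64)
    if len(fg_coords) > 0:
        diff = fg_coords[:, None, :] - bg_coords[None, :, :]   # pairwise offsets
        dist[fg] = np.sqrt((diff ** 2).sum(axis=2)).min(axis=1)
    return dist

# Usage: a 3 x 3 object inside a 5 x 5 image
demo_mask = np.zeros((5, 5), dtype=np.uint8)
demo_mask[1:4, 1:4] = 1
demo_dist = distance_transform_bruteforce(demo_mask)
demo_seeds = demo_dist > 0.5 * demo_dist.max()   # keep only the object's core
print(demo_dist[2, 2], demo_seeds.sum())
```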

**Further reading**:  
- Application of distance transform with watershed: [Link](https://docs.opencv.org/3.4/d2/dbd/tutorial_distance_transform.html)

### **Note: Connected components**

Connected components are regions in a binary image where pixels are connected based on predefined neighborhood rules (e.g., 4-connectivity or 8-connectivity). This technique is used to identify individual objects in a segmentation mask. The algorithm labels each connected region with a unique integer value, allowing for further analysis such as counting or measuring properties of each component.
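A straightforward way to label connected components is a breadth-first search that floods each not-yet-labeled foreground region. The sketch below is hypothetical and much slower than `cv.connectedComponents()`, but it mirrors that function's return convention (number of labels including the background label 0, plus the label image) and, like it, defaults to 8-connectivity.

```python
from collections import deque

import numpy as np

def connected_components(mask, connectivity=8):
    """Label connected foreground regions with a breadth-first search (sketch).

    Returns (num_labels, labels), where label 0 is the background.
    """
    if connectivity == 4:
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    elif connectivity == 8:
        offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1)
                   if (dy, dx) != (0, 0)]
    else:
        raise ValueError("connectivity must be 4 or 8")

    fg = np.asarray(mask).astype(bool)
    h, w = fg.shape
    labels = np.zeros((h, w), dtype=np.int32)
    n = 0
    for y in range(h):
        for x in range(w):
            if fg[y, x] and labels[y, x] == 0:
                n += 1                          # start a new component
                labels[y, x] = n
                queue = deque([(y, x)])
                while queue:                    # flood the whole component
                    cy, cx = queue.popleft()
                    for dy, dx in offsets:
                        ny, nx = cy + dy, cx + dx
                        if (0 <= ny < h and 0 <= nx < w
                                and fg[ny, nx] and labels[ny, nx] == 0):
                            labels[ny, nx] = n
                            queue.append((ny, nx))
    return n + 1, labels

# Usage: two pixels that touch only diagonally
diag = np.array([[1, 0],
                 [0, 1]], dtype=np.uint8)
n4, labels4 = connected_components(diag, connectivity=4)
n8, labels8 = connected_components(diag, connectivity=8)
print(n4, n8)   # 3 and 2 labels (background included)
```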

**Further reading**:  
- Wikipedia article on connected component labeling: [Link](https://en.wikipedia.org/wiki/Connected-component_labeling)



In [None]:
######################
###    EXERCISE    ###
######################

img = img1

# Implement the approach outlined above. You can copy and paste the code 
# from the link below and plug our image into it. Try to understand the
# code and the different steps. You may have to adjust the parameters
# to get a good segmentation result.

# https://docs.opencv.org/4.x/d3/db4/tutorial_py_watershed.html

...

In [None]:
######################
###    SOLUTION    ###
######################
def segment_red_blood_cells_watershed(img):
    
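    # Smooth the image: Gaussian blur followed by a median blur
    img = cv.GaussianBlur(img, (5, 5), 0)
    img_blur = cv.medianBlur(img, 5)

    # Use the red channel as intensity image (converting to grayscale with
    # cv.cvtColor(img_blur, cv.COLOR_RGB2GRAY) would be an alternative).
    # THRESH_BINARY_INV + Otsu marks pixels darker than the threshold as foreground.
    gray = img_blur[:,:,0]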
    ret, thresh = cv.threshold(gray, 0, 255, cv.THRESH_BINARY_INV + cv.THRESH_OTSU)

    # Noise removal
    kernel = np.ones((3, 3), np.uint8)
    opening = cv.morphologyEx(thresh, cv.MORPH_OPEN, kernel, iterations=9)

    # Sure background area
    sure_bg = cv.dilate(opening, kernel, iterations=3)

    # Finding sure foreground area
    dist_transform = cv.distanceTransform(opening, cv.DIST_L2, 5)
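    # Threshold on the distance (in pixels), chosen manually for these images.
    # A relative alternative would be: thr = 0.1 * dist_transform.max()
    thr = 18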
    ret, sure_fg = cv.threshold(dist_transform, thr, 255, 0)
    
    tools.show_image_chain([sure_fg, sure_bg], titles=["Sure FG", "Sure BG"])
    tools.show_image_chain([opening, dist_transform], titles=["Opening", "Distance transform"])

    # Finding unknown region
    sure_fg = np.uint8(sure_fg)
    unknown = cv.subtract(sure_bg, sure_fg)
    
    # Marker labelling
    ret, markers = cv.connectedComponents(sure_fg)
    
    # Add one to all labels so that sure background is not 0, but 1
    markers = markers+1
    
    # Now, mark the region of unknown with zero
    markers[unknown==255] = 0
    
    markers = cv.watershed(img,markers)
    img[markers == -1] = [255,0,0]
    
    return markers, img
    
    
img = img1.copy()
markers, result = segment_red_blood_cells_watershed(img=img)
tools.show_image_chain([markers, result], 
                       titles=["Markers", "Segmented"])

---

## **AI-driven segmentation**
Deep learning is increasingly used for image segmentation tasks, with the U-Net architecture still being one of the most popular choices. U-Net is a convolutional neural network specifically designed for biomedical image segmentation. 

A specialized U-Net-based tool for medical imaging is [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). It features self-configuring preprocessing and postprocessing, allowing the network to automatically adapt to the characteristics of the input data. nnU-Net is available as a Python package and can be installed via pip.

Although machine learning and AI are not the core focus of this course, pre-trained models can still be applied to perform segmentation effectively. Unlike classical methods, deep learning models can learn features directly from the data and often generalize better to unseen examples. However, they require large labeled datasets for training, are more computationally demanding, and are often seen as "black boxes" – making it difficult to interpret their decisions.


```python
######################
###    EXERCISE    ###
######################
```

Visit the following resources and explore whether they could be useful for your own segmentation project:

- **Segment Anything** by Meta AI: [Demo](https://segment-anything.com/demo), [Paper](https://arxiv.org/abs/2304.02643), [Code](https://github.com/facebookresearch/segment-anything) 
- **Hugging Face**: Collection of public pre-trained models. [Link](https://huggingface.co/models)
  - Many models include a demo interface
  - Background removal with [RemBG](https://huggingface.co/spaces/KenjieDec/RemBG)
  - Another popular segmentation tool is [YOLO](https://huggingface.co/spaces/fcakyon/yolov8-segmentation) ([Code](https://huggingface.co/spaces/fcakyon/yolov8-segmentation))
  - To search the entire Hugging Face database for models: [Link](https://huggingface.co/models)
- **TotalSegmentator** for anatomical CT (and MR) segmentation. [Demo](https://totalsegmentator.com/), [Paper](https://arxiv.org/abs/2208.05868), [Code](https://github.com/wasserth/TotalSegmentator)


We have now explored several approaches to image segmentation. How well you can apply them will depend on your specific problem, and on a bit of engineering skill. 😊


In [None]:
######################
###    EXERCISE    ###
######################

# Try using one of the models listed above to segment the cells in the image.

In [None]:
######################
###    SOLUTION    ###
######################

# Let's use the Segment Anything model from Meta.
# The following lines may take a while to execute.
try:
    from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
except ImportError:
    print("Installing the model...")
    !pip install -q git+https://github.com/facebookresearch/segment-anything.git
    !pip install -q opencv-python pycocotools matplotlib onnxruntime onnx
    !pip install -q torch torchvision

# Download the model (if not available yet)
path_to_checkpoint = "./sam_vit_h_4b8939.pth"  # Adjust to your local model directory
url_checkpoint="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not Path(path_to_checkpoint).exists():
    print("Downloading model... This may take a while!")
    !wget -O {path_to_checkpoint} {url_checkpoint}

# Load model
print("Loading model...")
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sam = sam_model_registry["vit_h"](checkpoint=path_to_checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
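def blend(img, overlay):
    """Blend an RGB image (uint8) with an RGBA overlay (float, values in [0, 1]).
    Returns a float image with values in [0, 1] for display."""
    alpha = overlay[:,:,3]
    img = img.astype(np.float32)
    result = ((1 - alpha[:, :, None]) * img + 
              alpha[:, :, None] * overlay[:, :, :3] * 255)
    # Scale to [0, 1]; dividing by 255 (rather than by result.max())
    # preserves the original brightness of the image
    result = np.clip(result / 255, 0, 1)
    return result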

In [None]:
def segment_blood_cells_sam(img, masks):
    alpha = 0.2
    clean_masks = True

    # Sort masks by area
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)

    if clean_masks:
        # Remove masks that are fully contained in another mask. The masks are
        # sorted by area, so m2 is never larger than m1.
        to_remove = []
        for i, m1 in enumerate(masks):
            m1 = m1["segmentation"]
            for j in range(i+1, len(masks)):
                m2 = masks[j]["segmentation"]
                if i in to_remove or j in to_remove:
                    continue
                if (m1.sum()) ==( (m1 | m2).sum()):
                    to_remove.append(j)
        masks = [m for i, m in enumerate(masks) if i not in to_remove]
        
    # Check the type of cell, using the following heuristic:
    # If the mask is mostly red, it is a red blood cell, if
    # it is mostly blue, it is a white blood cell
    for m in masks:
        mask = m["segmentation"]
        r = img[:,:,0][mask].mean()
        b = img[:,:,2][mask].mean()
        m["type"] = "rbc" if r > b else "wbc"
        
    # Visualize the masks
    result = np.ones((img.shape[0], img.shape[1], 4))
    result[:,:,3] = 0
    for m in masks:
        cell = m["type"]
        m = m["segmentation"]
        contours, hierarchy = cv.findContours(m.astype(np.uint8)*255, 
                                                cv.RETR_TREE, 
                                                cv.CHAIN_APPROX_SIMPLE)
        # Random numbers for the color
        rr, rb, rg = np.random.random(3)*0.5
        overlay_color = ([255/255, rb, rg, alpha] if (cell == "rbc") 
                         else [rr, rg, 255/255, alpha])
        result[m] = overlay_color
        overlay_color[3] = 1
        cv.drawContours(result, contours, -1, overlay_color, 2)
        
    result = blend(img, result)
    tools.show_image_chain([img, result], titles=["Input", "Output"])

In [None]:
# Choose image here:
img = img1

# Generate the masks (this may take a while, about 30 s).
# This call is kept outside of segment_blood_cells_sam() because it is slow;
# this way, the visualization can be re-run without regenerating the masks.
print("Generating masks... This may take a while!")
masks = mask_generator.generate(img)

# Visualize the masks
segment_blood_cells_sam(img, masks)

In [None]:
# Create a binary mask for the red blood cells and save it to a file.
# segment_blood_cells_sam() has stored the cell type in the mask entries.
mask = np.zeros(img.shape[:2], dtype=bool)
for m in masks:
    if m.get("type") == "rbc":
        mask |= m["segmentation"]
# Apply some morphological opening to clean the mask
mask = cv.morphologyEx(mask.astype(np.uint8), cv.MORPH_OPEN, 
                       np.ones((3, 3), np.uint8), iterations=1)
tools.show_image(mask*255, title=None, suppress_info=True)
cv.imwrite("mask_rbc.png", mask*255)

The result looks good: the model segments the individual cells accurately, at least for this image. This is impressive given that the Segment Anything model was trained on a general-purpose dataset rather than on hematologic images, which suggests that the model generalizes well to new domains.

In [None]:
# Repeat for img2
img = img2
masks = mask_generator.generate(img)
segment_blood_cells_sam(img, masks)

# Repeat for img3
img = img3
masks = mask_generator.generate(img)
segment_blood_cells_sam(img, masks)