In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from the original `predictor_example.ipynb`

# Installation

In [None]:
# Install segment-anything straight from the GitHub default branch (unpinned).
# TODO(review): pin a commit (e.g. `...segment-anything.git@<sha>`) for reproducibility.
%pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
import os

import wget

# Download the SAM ViT-H checkpoint into $DATA_ROOT once; later cells load it from there.
root_dir = os.getenv('DATA_ROOT')
if root_dir is None:
    # Without this check `None + str` fails with an unhelpful TypeError.
    raise RuntimeError("Set the DATA_ROOT environment variable to the data root directory.")
sam_checkpoint = os.path.join(root_dir, "sam_vit_h_4b8939.pth")
if not os.path.exists(sam_checkpoint):  # skip the ~2.5 GB download when already present
    wget.download('https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', sam_checkpoint)

# Object masks from prompts with SAM

The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding that allows high-quality masks to be efficiently produced from a prompt.

The `SamPredictor` class provides an easy interface for prompting the model. It allows the user to first set an image using the `set_image` method, which calculates the necessary image embeddings. Then, prompts can be provided via the `predict` method to efficiently predict masks from those prompts. The model can take as input both point and box prompts, as well as masks from the previous iteration of prediction.

In [1]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.transforms import EnsureType, LoadImage, SaveImage, ScaleIntensityRangePercentiles
from monai.utils import set_determinism
from scipy.ndimage import center_of_mass
from scipy.ndimage import label as scipy_label
from segment_anything import SamPredictor, sam_model_registry
from tqdm import tqdm

## Generate pseudo annotations

In [2]:
# Configuration: dataset choice, device, SAM predictor and the list of cases to process.
# This notebook expects the 3D NIfTI masks (`{case}_mask.nii.gz`) to have been generated
# beforehand (for COSMOS, from the XML annotation files).
set_determinism(42)  # fix random seeds for reproducibility
dataset = 'COSMOS'  # 'COSMOS' or 'careII'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # SAM also runs (slowly) on CPU

data_root = os.getenv("DATA_ROOT")
if data_root is None:
    raise RuntimeError("Set the DATA_ROOT environment variable to the data root directory.")
root_dir = Path(data_root)

sam_checkpoint = root_dir / "sam_vit_h_4b8939.pth"
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# Sort so the processing order does not depend on the filesystem's listing order.
if dataset == 'COSMOS':
    data_dir = root_dir / 'COSMOS' / 'train_data'
    all_cases = sorted(f for f in os.listdir(data_dir) if f.isdigit())
elif dataset == 'careII':
    data_dir = root_dir / 'careII' / 'train_data'
    all_cases = sorted(folder for folder in os.listdir(data_dir) if folder.startswith("0_P"))
else:
    # Previously any other value silently fell through to careII.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'COSMOS' or 'careII'.")

## Define centerline extraction functions

In [4]:
def shortest_path(
    mask: np.ndarray, max_components: int = 3, lookahead: int = 50
) -> tuple[np.ndarray, np.ndarray]:
    r"""Greedily trace a centerline through sparsely annotated slices of a mask.

    Every slice (last axis) with a label in channel 0 is split into connected
    components. The centers of mass of the first ``max_components`` of them become
    candidate points. "First" means ``scipy.ndimage.label`` order (labels
    1..max_components), not size.

    The walk starts from the highest labeled slice index. It is seeded at the mean
    (row, col) of every slice's first component. It then moves through the labeled
    slices in decreasing index order. At each step it picks, among the candidates of
    the next ``lookahead`` labeled slices, the one with the smallest squared
    Euclidean distance (in (row, col, slice) index space) to the previously chosen
    point.

    Args:
        mask: channel-first array-like of shape (C, H, W, D). Only channel 0 is
            used; voxels > 0 count as labeled.
        max_components: number of components considered per slice (must be >= 1).
        lookahead: number of labeled slices searched ahead at each step.

    Returns:
        centers: (N, 3) float array of (row, col, slice) points. The walk order is
            reversed, so the seed point comes last. The same point may appear in
            consecutive entries.
        shortest_mask: float32 array shaped like ``mask``. The chosen components are
            summed into channel 0 of their slices.

    Raises:
        ValueError: if channel 0 of ``mask`` contains no labeled voxel.
    """
    mask = EnsureType("numpy", dtype=np.float32)(mask)
    _, H, W, _ = mask.shape
    # Only channel 0 is decomposed into components below. Collecting slice indices
    # from other channels would add slices whose only candidates are placeholders.
    slice_indices = np.unique(np.where(mask[0] > 0)[-1])
    if slice_indices.size == 0:
        raise ValueError("shortest_path: channel 0 of `mask` has no labeled voxels")

    center_matrix = np.zeros((len(slice_indices), max_components, 3))
    component_matrix = np.zeros((len(slice_indices), max_components, H, W))
    # valid[d, i] is False when slice d has fewer than i + 1 components.
    valid = np.zeros((len(slice_indices), max_components), dtype=bool)
    for d, slice_idx in enumerate(slice_indices):
        labeled_slice, _ = scipy_label(mask[0, ..., slice_idx])
        for i in range(max_components):
            component = labeled_slice == i + 1
            if np.any(component):
                center = center_of_mass(component)
                center_matrix[d, i] = np.array([center[0], center[1], slice_idx])
                valid[d, i] = True
            component_matrix[d, i] = np.float32(component)

    # Walk from the highest slice index downwards (noted in the original as
    # "from top of brain to neck" - depends on volume orientation).
    center_matrix = np.flip(center_matrix, 0)
    component_matrix = np.flip(component_matrix, 0)
    valid = np.flip(valid, 0)
    # Every row has a valid first component (channel-0 slices only), so this mean
    # is not biased by placeholders.
    voxel_center = np.mean(center_matrix[:, 0], axis=0)
    shortest_centers = [np.array([voxel_center[0], voxel_center[1], center_matrix[0][0][-1]])]
    shortest_components = [np.zeros_like(component_matrix[0][0])]
    for i in range(len(center_matrix)):
        current_center = shortest_centers[-1]
        end = min(i + lookahead, len(center_matrix))
        centers_window = center_matrix[i:end]
        compos_window = component_matrix[i:end]
        distances = np.sum((centers_window - current_center[None, None, ...]) ** 2, axis=-1)
        # Missing components are stored as (0, 0, 0) and must never be chosen.
        distances[~valid[i:end]] = np.inf
        row, col = np.argwhere(distances == distances.min())[0]
        shortest_centers.append(centers_window[row, col])
        shortest_components.append(compos_window[row, col])

    shortest_mask = np.zeros_like(mask)
    for center, comp in zip(shortest_centers, shortest_components):
        shortest_mask[0, ..., int(center[-1])] += comp
    shortest_centers = np.flip(np.array(shortest_centers), 0)

    return shortest_centers, shortest_mask


def fix_missing_centers(centers: np.ndarray) -> list[np.ndarray]:
    r"""Fill slice gaps in a centerline by linear interpolation.

    Args:
        centers: sequence of (row, col, slice) points, expected in ascending slice
            order. A consecutive pair whose slice index does not increase by more
            than one (adjacent or repeated slices) is left as-is.

    Returns:
        List of int32 (row, col, slice) points. It contains every input center, in
        order. Each gap between consecutive centers also gets one point per missing
        slice, placed on the straight line between the two centers. The int32 cast
        truncates coordinates toward zero. An empty input returns an empty list.
    """
    new_centers = []
    for idx in range(len(centers) - 1):
        cur, nxt = centers[idx], centers[idx + 1]
        new_centers.append(cur.astype(np.int32))
        cur_z = int(cur[-1])
        next_z = int(nxt[-1])
        if next_z > cur_z + 1:
            # Divide by the full slice distance so the line reaches `nxt` exactly at
            # next_z (dividing by the number of missing slices overshoots).
            step = (nxt - cur) / (next_z - cur_z)
            for z in range(cur_z + 1, next_z):
                new_center = cur + (z - cur_z) * step
                new_center[-1] = z
                new_centers.append(new_center.astype(np.int32))
    if len(centers) > 0:
        # The loop above only emits the first center of each pair, so add the last one.
        new_centers.append(np.asarray(centers[-1]).astype(np.int32))
    return new_centers

## Generate SAM pseudo-masks for unannotated slices

In [None]:
# Propagate the sparse manual lumen annotations to unannotated slices with SAM.
# SAM is prompted with one point per slice, taken from an interpolated centerline.
input_dir = os.path.join(root_dir, dataset, 'preprocessed', 'mri_nii_raw')
output_dir = os.path.join(root_dir, dataset, 'preprocessed', 'mri_nii_raw')
image_loader = LoadImage(image_only=True, ensure_channel_first=True)
image_saver = SaveImage(output_dir=output_dir, output_postfix='', separate_folder=False, print_log=True)
normalize_intensity = ScaleIntensityRangePercentiles(0, 98, 0, 1, clip=True)
for case in all_cases:
    sam_save_dir = os.path.join(output_dir, f'{case}_sam.nii.gz')  # output file path
    if os.path.exists(sam_save_dir):  # skip cases that were already generated
        continue
    image = image_loader(os.path.join(input_dir, f'{case}_image.nii.gz'))
    # Read the annotation once; each side below works on its own copy.
    case_labels = np.asarray(image_loader(os.path.join(input_dir, f'{case}_mask.nii.gz')))
    sam_mask = torch.zeros_like(image)
    # Normalize once per case. Inside the side loop, the second side would see an
    # image that had been rescaled twice.
    image_norm = normalize_intensity(image)
    half_width = image.shape[1] // 2
    for i in range(2):  # segment each side separately so one side's failed prediction can't ruin the whole slice
        side_mask = np.asarray(case_labels == 1, dtype=np.float32)
        side_mask[:, half_width * i: half_width * (i + 1)] = 0  # zero out one half along axis 1
        labeled_slices = np.unique(np.where(side_mask > 0)[-1])
        if len(labeled_slices) <= 3:
            print(case, i, '3 or fewer annotated slices, skipped')
            continue
        print(case, i, 'processing')
        center_line, mask = shortest_path(side_mask)  # centerline through the sparse annotations
        center_line = fix_missing_centers(center_line)  # interpolate centers on unannotated slices
        for center in tqdm(center_line):
            center = np.int32(center)
            if center[-1] in labeled_slices:
                continue  # keep manual annotations untouched
            image_slice = np.asarray(image_norm[0, :, :, center[-1]] * 255, dtype=np.uint8)
            image_slice = cv2.cvtColor(image_slice, cv2.COLOR_GRAY2RGB)
            predictor.set_image(image_slice)

            # SAM point prompts are (x, y) = (column, row).
            input_point = np.array([[center[1], center[0]]]).astype(int)
            input_label = np.array([1])
            masks_pred, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )
            # NOTE(review): always takes the first multimask output; `scores` is unused.
            mask_pred = (masks_pred[0] > 0).astype(int)
            mask_pred[half_width * i: half_width * (i + 1)] = 0  # drop predictions on the excluded half
            side_mask[0, :, :, center[-1]] += mask_pred
        sam_mask += torch.as_tensor(side_mask)
    sam_mask[sam_mask > 0] = 1
    sam_mask.meta['filename_or_obj'] = sam_save_dir
    image_saver(sam_mask)