<a href="https://colab.research.google.com/github/WinetraubLab/3d-segmentation/blob/emilie-8-1-2025/3D-segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/WinetraubLab/3D-segmentation/blob/main/3D-segmentation.ipynb" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>
<a href="https://github.com/WinetraubLab/3D-segmentation/blob/main/3D-segmentation.ipynb" target="_blank">
  <img src="https://img.shields.io/badge/view%20in-GitHub-blue" alt="View in GitHub"/>
</a>

# Segmentation with MedSAM2
Use MedSAM2 to propagate annotated classes through a stack of OCT images and segment every frame.

Make sure to use a GPU runtime (T4 on Colab).
> **Runtime → Change runtime type → GPU**  

INPUTS:
1. A folder containing the full OCT image sequence and the COCO JSON annotation file downloaded from Roboflow.

**Make sure the very first and last images contain all segmented classes.**

OUTPUTS:
1. A folder containing segmentation mask images for each frame.
2. A COCO JSON file of the segmentations.
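Because propagation starts from the annotated frames, it can help to confirm the first/last-frame requirement before running the model. The following is a minimal sketch, not part of the repository: `load_coco` and `missing_classes_at_ends` are hypothetical helpers, and they assume a standard COCO layout (`images`, `annotations`, `categories`) with frame order given by the sorted (zero-padded) `file_name` values.

```python
import json


def load_coco(path):
    """Read a COCO annotation file into a dict."""
    with open(path) as f:
        return json.load(f)


def missing_classes_at_ends(coco):
    """Hypothetical check: category ids absent from the first and last frames.

    Assumes frame order follows the sorted `file_name` values.
    Returns {file_name: [missing category ids]}; an empty dict means all good.
    """
    all_cats = {c["id"] for c in coco["categories"]}
    images = sorted(coco["images"], key=lambda im: im["file_name"])
    # Collect the category ids annotated on each image
    cats_per_image = {}
    for ann in coco["annotations"]:
        cats_per_image.setdefault(ann["image_id"], set()).add(ann["category_id"])
    report = {}
    for im in (images[0], images[-1]):
        missing = all_cats - cats_per_image.get(im["id"], set())
        if missing:
            report[im["file_name"]] = sorted(missing)
    return report


# Toy example: class 2 ("tumor") is not annotated on the last frame
coco = {
    "images": [
        {"id": 1, "file_name": "000.jpg"},
        {"id": 2, "file_name": "001.jpg"},
        {"id": 3, "file_name": "133.jpg"},
    ],
    "annotations": [
        {"image_id": 1, "category_id": 1},
        {"image_id": 1, "category_id": 2},
        {"image_id": 3, "category_id": 1},
    ],
    "categories": [{"id": 1, "name": "vessel"}, {"id": 2, "name": "tumor"}],
}
print(missing_classes_at_ends(coco))  # {'133.jpg': [2]}
```

On a real dataset, pass `load_coco(...)` the path to your own JSON file. Roboflow COCO exports may also include a parent category with no annotations of its own; if so, drop its id from `all_cats` before checking.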


## Setup and Dependencies

In [None]:
!git clone -b emilie-8-1-2025 https://github.com/WinetraubLab/3D-segmentation.git
!pip install -r 3D-segmentation/requirements.txt

In [None]:
!git clone https://github.com/bowang-lab/MedSAM2.git
%cd MedSAM2
!sh download.sh

In [None]:
# @title Configuration and Dataset
import os
from google.colab import files
from google.colab import drive
import numpy as np
import cv2
import shutil
from PIL import Image

import sys
sys.path.append('/content/3D-segmentation')

import import_data_from_roboflow, propagate_mask_medsam2, export_coco, analyze_volumes

drive.mount('/content/drive')

# LOAD SEGMENTATION DATA
# @markdown Enter the directory containing your image stack to segment. This folder should also contain a COCO-style annotations JSON file if you are not loading segmentations from a dataset hosted on Roboflow. [Click here](https://docs.google.com/document/d/1jWf6Geef_1qd8TU8pD0xVAMV31jp7ab7GkvzoACwI70) for instructions on how to set up this folder if your data is hosted on Roboflow.
image_dataset_and_annotations_folder_path = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2025-05-10 Automatic Segmentation/OCT_sequence" # @param {type:"string"}

if not os.path.isdir(image_dataset_and_annotations_folder_path):
    raise NotADirectoryError(f"‘{image_dataset_and_annotations_folder_path}’ is not a valid directory")

# Load image files and annotations from folder, if provided
image_dataset_paths = {
    "json": [],
    "images": []
}

for filename in os.listdir(image_dataset_and_annotations_folder_path):
    if os.path.isfile(os.path.join(image_dataset_and_annotations_folder_path, filename)):
        if filename.lower().endswith('.json'):
            image_dataset_paths["json"].append(filename)
        else:
            image_dataset_paths["images"].append(filename)

assert len(image_dataset_paths['json']) > 0, f"COCO annotation file missing from folder {image_dataset_and_annotations_folder_path}"
assert len(image_dataset_paths['json']) < 2, f"Found {len(image_dataset_paths['json'])} JSON files; please consolidate them into a single COCO annotation file."

# @markdown If your images are very large (larger than 750x750), enter the factor by which to downsize them. 1 = no downsizing.
downsample_factor = 1  # @param {type:"slider", min:1, max:10, step:0.5}

class_ids = import_data_from_roboflow.init_from_folder(image_dataset_and_annotations_folder_path)

MODEL_CONFIG = "configs/sam2.1_hiera_t512.yaml"
MODEL_CHECKPOINT = "checkpoints/MedSAM2_latest.pt"
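The cell above treats every non-JSON file in the folder as an image, so stray files such as `.DS_Store` or a text note would end up in `image_dataset_paths['images']`. A minimal sketch of a stricter split, assuming the hypothetical `IMAGE_EXTENSIONS` set below covers your image formats (`split_dataset_folder` is also hypothetical, not a repository function):

```python
import os
import tempfile

# Assumed set of image extensions; extend it if your stack uses another format
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}


def split_dataset_folder(folder):
    """Hypothetical helper: return (json_files, image_files), each sorted by name."""
    json_files, image_files = [], []
    for name in sorted(os.listdir(folder)):
        if not os.path.isfile(os.path.join(folder, name)) or name.startswith("."):
            continue  # skip subfolders and hidden files such as .DS_Store
        ext = os.path.splitext(name)[1].lower()
        if ext == ".json":
            json_files.append(name)
        elif ext in IMAGE_EXTENSIONS:
            image_files.append(name)
    return json_files, image_files


# Toy example with a temporary folder
with tempfile.TemporaryDirectory() as d:
    for name in ["a.json", "000.jpg", "001.JPG", ".DS_Store", "notes.txt"]:
        open(os.path.join(d, name), "w").close()
    result = split_dataset_folder(d)
print(result)  # (['a.json'], ['000.jpg', '001.JPG'])
```

Sorting the names also makes "the first image" deterministic, which `os.listdir` alone does not guarantee.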

In [None]:
# @title Initialize and run model

# load an image to check size
image_files = image_dataset_paths['images']

if not image_files:
    raise FileNotFoundError(f"No images found in {image_dataset_and_annotations_folder_path}")

image_path = os.path.join(image_dataset_and_annotations_folder_path, image_files[0])
image = Image.open(image_path)
image_array = np.array(image)
# Target (height, width) after downsampling; cast to int since the slider allows factors like 1.5
downsample_2d_size = (np.array(image_array.shape[:2]) // downsample_factor).astype(int)

# Preprocess images
preprocessed_images_path = "/content/preprocessed_images/"
import_data_from_roboflow.preprocess_images(image_dataset_and_annotations_folder_path, preprocessed_images_path, downsample_hw_size=downsample_2d_size)

# Run model
model = propagate_mask_medsam2.CustomMEDSAM2(MODEL_CONFIG, MODEL_CHECKPOINT)

indiv_class_masks = []
frame_names = import_data_from_roboflow.list_all_images()
binary_segmentations = np.empty(len(frame_names), dtype=object)
binary_segmentations[:] = None

for class_id in class_ids:
    # construct segmentations for this class
    binary_segmentations = import_data_from_roboflow.create_mask_volume(class_id, downsample_hw_size=downsample_2d_size)
    class_mask = model.propagate(preprocessed_images_path, binary_segmentations, sigma_xy=1/downsample_factor, sigma_z=1/downsample_factor)
    indiv_class_masks.append(class_mask)

# List all classes and ids for user reference in future cells
import_data_from_roboflow.get_all_class_name_id()

Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.
frame loading (JPEG): 100%|██████████| 134/134 [00:01<00:00, 81.51it/s]
/content/MedSAM2/sam2/_C.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
frame loading (JPEG): 100%|██████████| 134/134 [00:01<00:00, 86.20it/s]
frame loading (JPEG): 100%|██████████| 134/134 [00:02<00:00, 53.65it/s]
frame loading (JPEG): 100%|██████████| 134/134 [00:02<00:00, 57.33it/s]
frame loading (JPEG): 100%|██████████| 134/134 [00:01<00:00, 74.34it/s]
frame loading (JPEG): 100%|██████████| 134/134 [00:01<00:00, 85.00it/s]
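The run cell pre-allocates `binary_segmentations` as an object array with one entry per frame, initialized to `None`. The sketch below illustrates that sparse layout, with 2D masks on annotated frames and `None` elsewhere. It is an illustration only: `build_sparse_mask_volume` is hypothetical, and we have not confirmed that `create_mask_volume` returns exactly this structure.

```python
import numpy as np


def build_sparse_mask_volume(num_frames, annotated_masks):
    """Hypothetical: one entry per frame, either a 2D bool mask or None.

    annotated_masks maps frame index -> (H, W) mask for frames that were annotated.
    """
    volume = np.empty(num_frames, dtype=object)
    volume[:] = None
    for idx, mask in annotated_masks.items():
        volume[idx] = np.asarray(mask, dtype=bool)
    return volume


# Toy example: 6 frames, annotations only on the first and last frame
first = np.zeros((4, 5), dtype=bool)
first[1:3, 1:4] = True
last = np.zeros((4, 5), dtype=bool)
last[0:2, 2:5] = True
vol = build_sparse_mask_volume(6, {0: first, 5: last})
annotated = [i for i, m in enumerate(vol) if m is not None]
print(annotated)  # [0, 5]
```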


In [None]:
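#@title Downsample Z
# Average every z_factor consecutive slices so Z is binned like XY.
# Reshape-based binning needs an integer factor, so non-integer slider values are rounded down.
z_factor = int(downsample_factor)
downsampled_indiv_class_masks = []
for mask_volume in indiv_class_masks:
    Z, Y, X = mask_volume.shape
    new_Z = Z // z_factor

    # Trim so Z is divisible by z_factor
    trimmed = mask_volume[:new_Z * z_factor]

    # Reshape and average every z_factor slices
    downsampled_volume = trimmed.reshape(new_Z, z_factor, Y, X).mean(axis=1)
    downsampled_indiv_class_masks.append(downsampled_volume)

indiv_class_masks = downsampled_indiv_class_masks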
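A toy example of the slice averaging used above, with a factor of 2 on a 5-slice volume: the trailing slice that does not fill a complete group is dropped, and each remaining pair is averaged.

```python
import numpy as np

# Toy volume: 5 slices of size 1x1 with values 0, 1, 2, 3, 4
vol = np.arange(5, dtype=float).reshape(5, 1, 1)
factor = 2
new_z = vol.shape[0] // factor      # 2 complete groups of 2 slices
trimmed = vol[: new_z * factor]     # drops slice 4, which has no partner
binned = trimmed.reshape(new_z, factor, 1, 1).mean(axis=1)
print(binned.ravel())  # [0.5 2.5]
```

Averaging turns binary masks into fractional values (0, 0.5 or 1 for a factor of 2). If a later step needs strictly binary masks, one option is to threshold, e.g. `binned >= 0.5`; whether `analyze_volumes` expects this is not stated in this notebook.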

In [None]:
# @title Combine class masks
output_dir = "/content/final_masks/"
propagate_mask_medsam2.combine_class_masks(indiv_class_masks, output_dir=output_dir, show=True)

# Save as COCO annotation file
export_coco.save_segmentations_as_coco(indiv_class_masks, coco_output_dir="predicted_segmentations_coco.json")
files.download("predicted_segmentations_coco.json")
# Save TIFF
export_coco.coco_to_tiff("predicted_segmentations_coco.json", "output_volume.tiff")
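`combine_class_masks` lives in the repository and is not shown here. As a rough illustration of what merging per-class masks into a single label volume can look like, the hypothetical `combine_to_label_volume` below assigns label `k + 1` to class `k`, keeps 0 for background, and lets later classes overwrite earlier ones where they overlap; the repository function may resolve overlaps differently.

```python
import numpy as np


def combine_to_label_volume(class_masks, threshold=0.5):
    """Hypothetical: merge per-class (Z, Y, X) masks into one uint8 label volume.

    0 = background, k + 1 = class_masks[k]; later classes win on overlap.
    uint8 supports up to 255 classes.
    """
    labels = np.zeros(class_masks[0].shape, dtype=np.uint8)
    for k, mask in enumerate(class_masks):
        labels[np.asarray(mask) >= threshold] = k + 1
    return labels


# Toy example: two overlapping classes in a 1x2x3 volume
a = np.zeros((1, 2, 3))
a[0, 0, :2] = 1
b = np.zeros((1, 2, 3))
b[0, :, 1] = 1
print(combine_to_label_volume([a, b])[0])
# [[1 2 0]
#  [0 2 0]]
```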

In [None]:
shutil.make_archive('/content/final_masks', 'zip', '/content/final_masks')
# files.download("/content/final_masks.zip")


### Save distance-thresholded heatmap
The code below generates and downloads a volumetric heatmap that highlights tissue regions based on their proximity to a segmented class of interest (in this example, blood vessels).

Specifically, this example marks areas that are closer than 50 µm or farther than 200 µm from any blood vessel, providing spatial context for proximity-based analysis. The thresholds are converted to pixels as `int(µm / downsample_factor)`, which assumes a pixel size of 1 µm in the original images.
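A minimal sketch of the near/far idea using SciPy's Euclidean distance transform, assuming the repository's `generate_distance_heatmap` does something comparable (its implementation is not shown here). `near_far_regions` and its `voxel_um` and `per_slice` parameters are hypothetical; `per_slice=True` loosely mirrors `use_2d_xy_distances`. Passing physical voxel sizes through `sampling` is an alternative to a pure pixel threshold and stays correct when Z and XY spacing differ.

```python
import numpy as np
from scipy import ndimage


def _distance_to_object(obj, sampling):
    """Distance from every voxel to the nearest True voxel in `obj`."""
    if not obj.any():
        # No object at all: every voxel is infinitely far away
        return np.full(obj.shape, np.inf)
    # distance_transform_edt measures distance to the nearest zero, so invert
    return ndimage.distance_transform_edt(~obj, sampling=sampling)


def near_far_regions(mask, near_um, far_um, voxel_um=(1.0, 1.0, 1.0), per_slice=False):
    """Hypothetical sketch: near/far boolean masks around a (Z, Y, X) object mask.

    voxel_um is the assumed physical (z, y, x) voxel size in µm.
    per_slice=True measures 2D distances within each XY slice only.
    """
    obj = np.asarray(mask) > 0.5
    if per_slice:
        dist = np.stack([_distance_to_object(s, voxel_um[1:]) for s in obj])
    else:
        dist = _distance_to_object(obj, voxel_um)
    near = (dist <= near_um) & ~obj  # exclude the object itself (a choice of this sketch)
    far = dist >= far_um
    return near, far, dist


# Toy example: one object voxel at x = 0 on an 11-voxel line, 1 µm spacing
m = np.zeros((1, 1, 11))
m[0, 0, 0] = 1
near, far, dist = near_far_regions(m, near_um=2, far_um=8)
print(np.flatnonzero(near).tolist(), np.flatnonzero(far).tolist())  # [1, 2] [8, 9, 10]
```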

In [None]:
heatmap_volume = analyze_volumes.generate_distance_heatmap(
    indiv_class_masks[0],
    distance_threshold_px_near=int(50 / downsample_factor),
    distance_threshold_px_far=int(200 / downsample_factor),
    overlay=True,
    show=True,
    output_path="/content/distance_threshold_vol.tiff",
    use_2d_xy_distances=False,
)

# Save as individual images in a sequence for 3D Viewer
os.makedirs("/content/heatmap_volume_images", exist_ok=True)
for i,v in enumerate(heatmap_volume):
    filename = os.path.join("/content/heatmap_volume_images", f"{i}.tif")
    v = cv2.cvtColor(v, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, v)
shutil.make_archive('/content/heatmap_volume_images', 'zip', '/content/heatmap_volume_images')
# files.download("/content/heatmap_volume_images.zip")

# create coco json file of heatmap
analyze_volumes.export_near_far_regions_as_coco(
    mask_volume=indiv_class_masks[0],
    distance_threshold_px_near=int(50/downsample_factor),
    distance_threshold_px_far=int(200/downsample_factor),
    output_path="/content/near_far_regions_coco.json",
    use_2d_xy_distances=False
)
files.download("/content/near_far_regions_coco.json")

### Proximity volume between multiple classes
The code below generates a volumetric heatmap that highlights tissue regions based on their proximity to multiple segmented classes of interest (in this example, axon bundles and tumors). Specifically, it marks areas closer than 50 µm to both axon bundles and tumors.
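One way to compute "closer than a threshold to every listed class" is to compute a distance transform per class, threshold it, and intersect the results. The sketch below does that; `close_to_all_classes` is hypothetical and only assumes that `regions_close_to_object_types` applies similar logic. Distances are in pixels, matching the `thresh=int(50/downsample_factor)` argument.

```python
import numpy as np
from scipy import ndimage


def close_to_all_classes(class_masks, thresh_px):
    """Hypothetical sketch: True where a voxel lies within thresh_px of every class.

    class_masks: list of (Z, Y, X) masks, one per class; distances are in pixels.
    """
    result = None
    for mask in class_masks:
        obj = np.asarray(mask) > 0.5
        if not obj.any():
            # An empty class can never be "near", so no voxel qualifies
            return np.zeros(obj.shape, dtype=bool)
        # Distance from each voxel to the nearest voxel of this class
        near = ndimage.distance_transform_edt(~obj) <= thresh_px
        result = near if result is None else (result & near)
    return result


# Toy example: class A at x = 0, class B at x = 9, threshold 5 px
a = np.zeros((1, 1, 10))
a[0, 0, 0] = 1
b = np.zeros((1, 1, 10))
b[0, 0, 9] = 1
print(np.flatnonzero(close_to_all_classes([a, b], 5)).tolist())  # [4, 5]
```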

In [None]:
proximity_volume = analyze_volumes.regions_close_to_object_types(indiv_class_masks[1:], thresh=int(50/downsample_factor))