<a href="https://colab.research.google.com/github/WinetraubLab/3D-segmentation/blob/main/3D-segmentation.ipynb" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>
<a href="https://github.com/WinetraubLab/3D-segmentation/blob/main/3D-segmentation.ipynb" target="_blank">
  <img src="https://img.shields.io/badge/view%20in-GitHub-blue" alt="View in GitHub"/>
</a>

# Segmentation with MedSAM2
Use MedSAM2 to propagate annotated segmentations through a stack of OCT images, producing a mask for every frame and every annotated class.

Make sure to use a GPU runtime (e.g. a T4 on Colab):
> **Runtime → Change runtime type → GPU**  

INPUTS:
1. Annotated OCT images, provided either as a Roboflow dataset or as a folder containing a COCO .json annotation file. Each annotation mask must be a complete segmentation of one instance of its class.
2. A folder containing the stack of OCT images to segment.

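OUTPUTS: A folder of segmentation mask images (one per frame), a COCO .json file with the predicted segmentations, and a TIFF volume (`output_volume.tiff`) that is downloaded at the end of the notebook.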


## Setup and Dependencies

In [None]:
!git clone https://github.com/WinetraubLab/3D-segmentation.git
!pip install -r 3D-segmentation/requirements.txt

In [None]:
!git clone https://github.com/bowang-lab/MedSAM2.git
%cd MedSAM2
!sh download.sh

In [None]:
# @title Configuration and Dataset
import os
from roboflow import Roboflow
from google.colab import files
from google.colab import drive
import pandas as pd
import numpy as np

import gspread
from google.auth import default
from google.auth.transport.requests import Request
from google.auth.credentials import AnonymousCredentials

import sys
sys.path.append('/content/3D-segmentation')

import import_data_from_roboflow, propagate_mask_medsam2, export_coco

drive.mount('/content/drive')

# INPUT IMAGE STACK
# @markdown Enter the directory containing your image stack to segment:
image_dataset_folder_path = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2025-05-10 Automatic Segmentation/OCT_sequence" # @param {type:"string"}

# Fail early, before any model work, when the stack folder does not exist.
image_folder_exists = os.path.isdir(image_dataset_folder_path)
if not image_folder_exists:
    raise NotADirectoryError(f"‘{image_dataset_folder_path}’ is not a valid directory")

# @markdown **If loading segmentations from Roboflow:** Enter the workspace and project of the dataset with your annotated images (the Roboflow API key is read from the "Credentials & Passwords" Google Sheet). Otherwise, leave these blank.
workspace_name = ""  # @param {type:"string"}
project_name = ""  # @param {type:"string"}
# @markdown For example: workspace_name="yolab-kmmfx"; project_name="vol1_2"

# @markdown **If loading segmentations from folder:** Enter the path to the folder containing COCO .json file with annotation data. Otherwise, leave this blank.
segmentation_data_path = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2025-05-10 Automatic Segmentation/Sample_Roboflow_Dataset"  # @param {type:"string"}

# Produces `class_ids`, the classes whose annotations are propagated later.
# A local annotation folder takes precedence when both sources are filled in.
if segmentation_data_path:
    if not os.path.isdir(segmentation_data_path):
        raise NotADirectoryError(f"'{segmentation_data_path}' is not a valid directory")
    class_ids = import_data_from_roboflow.init_from_folder(segmentation_data_path)
else:
    if not (workspace_name and project_name):
        raise ValueError(
            "Set either segmentation_data_path, or both workspace_name and project_name."
        )

    # authenticate_user() is Colab's documented way to make the user's Google
    # credentials available to google.auth.default() (used here for gspread).
    from google.colab import auth
    auth.authenticate_user()

    creds, _ = default()
    if creds and creds.expired and creds.refresh_token:
        creds.refresh(Request())
    gc = gspread.authorize(creds)

    # Find the Roboflow API key in the credentials sheet: the first data row
    # (header row skipped) that has a cell containing 'Roboflow API Key';
    # the key itself is taken from that row's second column.
    sheet_rows = gc.open("Credentials & Passwords").sheet1.get_all_values()[1:]
    api_key = next(
        (
            row[1]
            for row in sheet_rows
            if len(row) > 1 and any("Roboflow API Key" in str(cell) for cell in row)
        ),
        None,
    )
    # Previously a missing key only printed a message and then crashed with a
    # NameError on `api_key`; fail here with an actionable message instead.
    if not api_key:
        raise ValueError(
            "Could not find 'Roboflow API Key' in the 'Credentials & Passwords' sheet."
        )
    class_ids = import_data_from_roboflow.init_from_roboflow(workspace_name, project_name, api_key)

# MedSAM2 model configuration and checkpoint. Both are relative paths and
# assume the setup cell has already `%cd`-ed into the MedSAM2 checkout.
# NOTE(review): the config may be resolved by MedSAM2's own config loader
# rather than by the working directory; confirm before changing it.
MODEL_CONFIG, MODEL_CHECKPOINT = (
    "configs/sam2.1_hiera_t512.yaml",
    "checkpoints/MedSAM2_latest.pt",
)

In [None]:
# @title Initialize and run model

# Smoothing parameters forwarded to model.propagate(). Both stay 0, as this
# notebook always used. NOTE(review): presumably in-plane / across-frame
# smoothing widths; confirm their meaning in propagate_mask_medsam2 before tuning.
SIGMA_XY = 0
SIGMA_Z = 0

# Fail fast, before the expensive preprocessing and model load, if the
# annotation source produced no classes to propagate.
if len(class_ids) == 0:
    raise ValueError("No annotated classes were found in the segmentation data; nothing to propagate.")

# Preprocess images into a local folder; the model reads frames from there.
preprocessed_images_path = "/content/preprocessed_images/"
import_data_from_roboflow.preprocess_images(image_dataset_folder_path, preprocessed_images_path)

# Run model
model = propagate_mask_medsam2.CustomMEDSAM2(MODEL_CONFIG, MODEL_CHECKPOINT)

# NOTE(review): frame_names is not used in this cell. The call is kept because
# list_all_images() is opaque here and its result may be relied on elsewhere.
frame_names = import_data_from_roboflow.list_all_images()

# One propagated mask per class. Rebuilt from scratch, so re-running the cell
# does not accumulate stale masks.
indiv_class_masks = []
for class_id in class_ids:
    # Build this class's annotation volume, then let MedSAM2 propagate it
    # through the preprocessed stack.
    binary_segmentations = import_data_from_roboflow.create_mask_volume(class_id)
    class_mask = model.propagate(
        preprocessed_images_path,
        binary_segmentations,
        sigma_xy=SIGMA_XY,
        sigma_z=SIGMA_Z,
    )
    indiv_class_masks.append(class_mask)
    indiv_class_masks.append(class_mask)

In [None]:
# @title Combine class masks

# Merge the per-class masks into final mask images, a COCO .json file and a
# TIFF volume. The .json and .tiff paths are relative, so they resolve against
# the current working directory (the MedSAM2 checkout after the setup cell).
output_dir = "/content/final_masks/"
coco_json_path = "predicted_segmentations_coco.json"
tiff_path = "output_volume.tiff"

propagate_mask_medsam2.combine_class_masks(
    indiv_class_masks,
    output_dir=output_dir,
    coco_output_dir=coco_json_path,  # despite the keyword name, this is a .json file path
    show=True,
)

# Save TIFF: convert the COCO predictions just written into a volume.
export_coco.coco_to_tiff(coco_json_path, tiff_path)

In [None]:
# @title Download files
# Send the TIFF volume produced by the previous cell to the browser. The path
# is relative to the current working directory, as it was when written.
from google.colab import files

tiff_to_download = "output_volume.tiff"
files.download(tiff_to_download)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>