In this notebook, we will:

1. Define helper functions for calculating bounding boxes, converting images to 3 channels, cropping images, and plotting comparisons.
2. Load MRI and label data, compute 3D bounding boxes, and derive slice-level bounding boxes.
3. Use a model (assumed to be defined externally as bbox_prompt) to infer a segmentation mask for each slice within the bounding box region.
4. Save the predicted segmentation results and visualize them alongside the ground truth masks and bounding boxes.

# 1. Imports and Setup

In [None]:
import os
import subprocess
import numpy as np
import nibabel as nib
import SimpleITK as sitk
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch

# 2. Extract Model - Checkpoint Conversion

This command uses the `extract_weights.py` script to convert a model checkpoint file into an extracted version compatible with MedSAM.

Adjust the input (`-from_pth`) and output (`-to_pth`) paths as needed.

In [None]:
!python /path/to/MedSAM/utils/extract_weights.py \
    -from_pth /path/to/MedSAM/work_dir/MedSAM-ViT-B-YYYYMMDD-HHMM/medsam_model_best.pth \
    -to_pth /path/to/MedSAM/work_dir/MedSAM-ViT-B-YYYYMMDD-HHMM/medsam_model_best_extracted.pth


  from_pth = torch.load(from_pth, map_location='cpu')
Weights are saved to /research/projects/Sahika/MedSAM/work_dir/MedSAM-ViT-B-20240925-2309/medsam_model_best_extracted.pth


In [None]:
# Path to the trained model, e.g., 'path/to/medsam_model_best_extracted.pth'
model_path = "<MODEL_PATH>"

# Path to the MRI test images, e.g., 'path/to/imagesTs'
test_mri_dir = "<MRI_DIR>"

# Path to the ground truth masks, e.g., 'path/to/labelsTs'
ground_truth_dir = "<GROUND_TRUTH_DIR>"

# Directory to save predicted masks, e.g., 'path/to/labelsTs_pred'
predicted_masks_dir = "<PREDICTED_MASKS_DIR>"

# 3. Helper Functions

Because all models were trained on a 7 cm cube, the same region must be assessed during evaluation; the ground-truth label is therefore used as a guide to locate the region of interest (ROI).

The calculate_bounding_box_with_margin_3d function finds the smallest 3D bounding box that contains all non-zero elements in a 3D mask volume. A margin is added in the x and y dimensions to include some surrounding area, but the z-dimension range remains fixed.

In [None]:
def calculate_bounding_box_with_margin_3d(mask_data, margin=10):
    """
    Find the tightest 3D box enclosing every non-zero element of a mask and
    pad it in-plane.

    Array axis 0 is reported as y, axis 1 as x and axis 2 as z. The padding
    is applied to x and y only and is clipped to the array bounds; the z
    extent is returned unpadded.

    Parameters:
        mask_data (np.ndarray): 3D mask array.
        margin (int): Padding added on each side of the box in x and y.

    Returns:
        list or None: [x_min, y_min, z_min, x_max, y_max, z_max], or None
        when the mask contains no non-zero element.
    """
    # One row per non-zero element, one column per array axis.
    nonzero_idx = np.argwhere(mask_data)
    if nonzero_idx.shape[0] == 0:  # Empty mask
        return None

    lowest = nonzero_idx.min(axis=0)
    highest = nonzero_idx.max(axis=0)

    # The z extent (axis 2) is taken as-is.
    z_min, z_max = lowest[2], highest[2]

    # Pad y (axis 0) and x (axis 1), keeping the box inside the array.
    last_row = mask_data.shape[0] - 1
    last_col = mask_data.shape[1] - 1
    y_min = max(lowest[0] - margin, 0)
    y_max = min(highest[0] + margin, last_row)
    x_min = max(lowest[1] - margin, 0)
    x_max = min(highest[1] + margin, last_col)

    return [x_min, y_min, z_min, x_max, y_max, z_max]

Once we have a 3D bounding box, we need to extract a 2D bounding box for a specific slice (z-index). The calculate_slice_bounding_box function uses the 3D coordinates but returns a 2D bounding box applicable for a single slice.

In [None]:
def calculate_slice_bounding_box(bbox_3d, slice_index):
    """
    Reduce a 3D bounding box to its in-plane (x, y) extent.

    The z bounds are dropped. slice_index is accepted but not used, so the
    same 2D box is returned for every slice.

    Parameters:
        bbox_3d (list): [x_min, y_min, z_min, x_max, y_max, z_max].
        slice_index (int): The slice (z) index; currently ignored.

    Returns:
        list: [x_min, y_min, x_max, y_max].
    """
    # Unpacking all six entries also rejects boxes of the wrong length.
    x_lo, y_lo, _z_lo, x_hi, y_hi, _z_hi = bbox_3d
    box_2d = [x_lo, y_lo, x_hi, y_hi]
    return box_2d


MRI slices are often single-channel (grayscale) images. The convert_to_3_channel function replicates the grayscale data into 3 channels so that we can treat it as a 3-channel image if needed (e.g., for visualization or models expecting multiple channels).

In [None]:
def convert_to_3_channel(grayscale_image):
    """
    Replicate a 2D grayscale image into three identical channels.

    Parameters:
        grayscale_image (np.ndarray): 2D image.

    Returns:
        np.ndarray: 3D image with shape (H, W, 3); every channel is a copy
        of the input.
    """
    # Append a trailing channel axis, then repeat the data along it.
    with_channel_axis = np.expand_dims(np.asarray(grayscale_image), axis=-1)
    return np.repeat(with_channel_axis, 3, axis=-1)


The crop_image_with_bbox function crops a 2D image slice using a given bounding box. It ensures that the bounding box coordinates are within the image boundaries and returns the cropped image region.

In [None]:
def crop_image_with_bbox(img_slice, bbox):
    """
    Cut the region described by a bounding box out of a 2D image slice.

    The box is clamped to the image extent and truncated to integers.
    x_max and y_max are used as exclusive slice ends. If the clamped box is
    empty, the original slice is returned unchanged. Progress messages are
    printed along the way.

    Parameters:
        img_slice (np.ndarray): 2D image slice.
        bbox (list): [x_min, y_min, x_max, y_max].

    Returns:
        np.ndarray: Cropped 2D image, or img_slice itself for an empty box.
    """
    x_min, y_min, x_max, y_max = bbox

    # Clamp the lower bounds at 0 and the upper bounds at the image size.
    x_min, y_min = int(max(x_min, 0)), int(max(y_min, 0))
    x_max = int(min(x_max, img_slice.shape[1]))
    y_max = int(min(y_max, img_slice.shape[0]))

    print(f"Cropping with bbox: x_min={x_min}, y_min={y_min}, x_max={x_max}, y_max={y_max}")

    box_is_empty = x_min >= x_max or y_min >= y_max
    if box_is_empty:
        print("Invalid bounding box, returning original image.")
        return img_slice

    # Rows are indexed by y, columns by x.
    cropped_img = img_slice[y_min:y_max, x_min:x_max]
    print(f"Cropped image size: {cropped_img.shape}")
    return cropped_img


The plot_comparison function helps visualize and compare the original image slice, ground truth mask, predicted mask, and the cropped image side-by-side. It also draws the bounding box onto the displayed images for reference.

In [None]:
def plot_comparison(img_slice_3ch, gt_slice, pred_mask, cropped_img, bbox):
    """
    Show a 1x4 figure: original image, ground truth overlay, predicted mask
    overlay, and the cropped image.

    Only the first channel of img_slice_3ch is displayed. When bbox is
    truthy, a yellow rectangle is drawn on the first three panels.

    Parameters:
        img_slice_3ch (np.ndarray): 3-channel image slice.
        gt_slice (np.ndarray): Ground truth mask for the slice.
        pred_mask (np.ndarray): Predicted segmentation mask.
        cropped_img (np.ndarray): Cropped image portion.
        bbox (list): [x_min, y_min, x_max, y_max] bounding box for
            visualization; a falsy value skips the rectangles.
    """

    def outline_bbox(line_width):
        # Draw the bounding box on the currently active axes, if one was given.
        if bbox:
            x_min, y_min, x_max, y_max = bbox
            axes = plt.gca()
            frame = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                      linewidth=line_width, edgecolor='yellow', facecolor='none')
            axes.add_patch(frame)

    plt.figure(figsize=(12, 6))

    # Panel 1: original image with a thicker box outline
    plt.subplot(1, 4, 1)
    plt.imshow(img_slice_3ch[:, :, 0], cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    outline_bbox(2)

    # Panels 2 and 3: a mask drawn semi-transparently over the image
    overlay_panels = (
        (2, gt_slice, 'Reds', 'Ground Truth Mask Overlay'),
        (3, pred_mask, 'Blues', 'Predicted Mask Overlay'),
    )
    for position, mask, colormap, heading in overlay_panels:
        plt.subplot(1, 4, position)
        plt.imshow(img_slice_3ch[:, :, 0], cmap='gray', alpha=0.7)
        plt.imshow(mask, cmap=colormap, alpha=0.3)
        plt.title(heading)
        plt.axis('off')
        outline_bbox(0.5)

    # Panel 4: cropped image
    plt.subplot(1, 4, 4)
    plt.imshow(cropped_img, cmap='gray')
    plt.title('Cropped Original Image')
    plt.axis('off')

    plt.tight_layout()
    plt.show()


# 4. Main Processing Loop

Process Description:

1. Iterate over files in image_dir looking for MRI volumes (_0000.nii.gz pattern).
2. For each volume, load its corresponding label file.
3. Compute a 3D bounding box around the non-zero regions in the label.
4. Iterate over slices within the bounding box (z_min to z_max).
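5. For each slice that contains ground-truth labels:
    1. Convert the slice to a 3-channel image.
    2. Take the slice-level (x, y) bounding box from the 3D bounding box.
    3. Use the model (bbox_prompt) to predict a mask within this bounding box.
    4. Store the predicted mask in the output volume.
6. Save the predicted volume for the scan as a NIfTI file.

The crop_image_with_bbox and plot_comparison helpers defined above can optionally be called inside the slice loop to crop the image and to inspect the original slice, ground truth, prediction, and cropped image. The loop below does not call them and does not limit the number of slices processed.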

In [None]:
# User-defined paths and model setup
image_dir = "path_to_images"
label_dir = "path_to_labels"
output_dir = "path_to_output"
bbox_prompt = ...  # Your model or object implementing _set_image() and _infer()

# Create the output directory up front so that saving cannot fail on a
# missing folder after a whole volume has already been processed.
os.makedirs(output_dir, exist_ok=True)

# Sorted so that volumes are processed in a reproducible order.
for image_file in sorted(os.listdir(image_dir)):
    print(image_file)

    # Only MRI volumes following the *_0000.nii.gz naming pattern are processed.
    if not image_file.endswith('_0000.nii.gz'):
        continue

    image_path = os.path.join(image_dir, image_file)
    label_name = image_file.replace('_0000.nii.gz', '.nii.gz')
    label_path = os.path.join(label_dir, label_name)

    # A missing label should skip this volume, not abort the whole run.
    if not os.path.isfile(label_path):
        print(f"Label file not found for {image_file}: {label_path}. Skipping.")
        continue

    label_img = nib.load(label_path)
    label_data = label_img.get_fdata()

    mri_img = nib.load(image_path)
    mri_data = mri_img.get_fdata()

    bounding_box_3d = calculate_bounding_box_with_margin_3d(label_data, margin=1)
    if bounding_box_3d is None:
        print(f"No bounding box for {image_file}. Skipping.")
        continue

    # Predicted masks are written as uint8 below, so allocate the volume as
    # uint8 instead of inheriting the floating-point dtype of the MRI data.
    pred_mask_volume = np.zeros(mri_data.shape, dtype=np.uint8)

    # Loop over each slice in the 3D image within the bounding box range
    z_min, z_max = bounding_box_3d[2], bounding_box_3d[5]
    for z in range(z_min, z_max + 1):
        # Slices without any label are left as zeros in the prediction.
        if not np.any(label_data[:, :, z]):
            continue

        # Extract the current 2D slice and convert it to 3 channels for the model
        img_slice_3ch = convert_to_3_channel(mri_data[:, :, z])

        # Hand the slice to the prompt model
        bbox_prompt._set_image(img_slice_3ch)

        # In-plane box used as the prompt for this slice
        slice_bbox = calculate_slice_bounding_box(bounding_box_3d, z)

        # Inference only: no gradients are needed
        with torch.no_grad():
            predicted_mask = bbox_prompt._infer(slice_bbox)

        # Add the predicted mask to the volume
        pred_mask_volume[:, :, z] = predicted_mask.astype(np.uint8)

    # Save the predicted mask volume as a NIfTI file, reusing the MRI affine
    pred_nifti = nib.Nifti1Image(pred_mask_volume, mri_img.affine)
    output_path = os.path.join(output_dir, image_file.replace('_0000.nii', '.nii'))
    nib.save(pred_nifti, output_path)

    print(f"Saved predicted mask for {image_file} to {output_path}")

print('Segmentation complete.')

# 5. Evaluate Model Performance

This section uses the `nnUNetv2_result.py` script to evaluate the model's segmentation performance. The script calculates metrics like Dice score and groups tumors by size using specified thresholds.


In [None]:
# Define placeholders for paths

# Path to the dataset CSV file, e.g., 'path/to/dataset.csv'.
# This CSV file should contain all dataset information, including at least
# the following columns:
#   - 'file_name': Names or paths of the MRI files.
#   - 'tumor_volume': Corresponding tumor volume measurements.
csv_path = "<CSV_PATH>"

# Thresholds for grouping tumors by size:
#   - <200: Small tumors
#   - 200-400: Medium tumors
#   - >400: Large tumors
thresholds = "[200,400]"

# Function to call the nnUNetv2_result command dynamically
def run_nnUNetv2_result(gt, pred, csv, thresholds):
    """
    Run the nnUNetv2_result script with specified parameters.

    The script's standard output is printed on success. If the script exits
    with a non-zero status, or the process cannot be started, an error
    message is printed instead; no exception is propagated for either case.

    Args:
        gt (str): Path to the ground truth masks.
        pred (str): Path to the predicted masks.
        csv (str): Path to the CSV file.
        thresholds (str): Threshold values in the format '[200,400]'.
    """
    # Pass the arguments as a list without a shell: each value reaches the
    # script verbatim, so paths containing spaces are not split and the
    # brackets in the threshold string are not treated as a glob pattern.
    command = [
        'python3', 'nnUNetv2_result.py',
        '-gt', str(gt),
        '-pred', str(pred),
        '-csv', str(csv),
        '-th', str(thresholds),
    ]
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred: {e.stderr}")
    except OSError as e:
        # Raised when the interpreter itself cannot be launched.
        print(f"Error occurred: {e}")
    else:
        # Print the output
        print(result.stdout)

# Call the function for the test set, using the ground truth and prediction
# directories configured near the top of the notebook and the CSV path and
# thresholds defined above.
# NOTE(review): the segmentation loop saves its predictions to `output_dir`,
# while this call reads `predicted_masks_dir` — confirm both point to the
# same folder.
print("Running nnUNetv2_result for the test set...")
run_nnUNetv2_result(ground_truth_dir, predicted_masks_dir, csv_path, thresholds)


# 6. Visualize Results

We visualize three key slices from the MRI:
1. First Quarter Slice,
2. Middle Slice,
3. Third Quarter Slice.

For each slice, the following are shown:
1. Original MRI Image.
2. Ground Truth Overlay.
3. Predicted Mask Overlay.
4. Combined Ground Truth and Prediction Overlay.

The visualizations are presented in a grid layout for easy comparison.

In [None]:
# Function to load a NIfTI image
def load_nifti_image(filepath):
    """
    Read an image file with SimpleITK and return its voxel data as an array.

    NOTE(review): SimpleITK's array conversion is documented to index the
    result as (z, y, x); the visualization code relies on axis 0 being the
    slice axis — confirm for your data.

    Parameters:
        filepath (str): Path of the image file to read.

    Returns:
        The array produced by sitk.GetArrayFromImage for that image.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(filepath))

# Function to apply the same validation transforms to both ground truth and predicted images
def apply_test_transforms(image_file, label_file, test_transforms):
    """
    Run a transform callable on an image/label pair.

    The pair is passed as a dict with the keys "image" and "label"; the
    values stored under the same keys in the callable's result are returned.

    Parameters:
        image_file: Value placed under the "image" key.
        label_file: Value placed under the "label" key.
        test_transforms (callable): Called once with the dict.

    Returns:
        tuple: (transformed image, transformed label).
    """
    result = test_transforms({"image": image_file, "label": label_file})
    return result["image"], result["label"]

def visualize_slices_grid(image, ground_truth, prediction):
    """
    Visualize slices in a grid:
    1st row: First quarter slice
    2nd row: Middle slice
    3rd row: Third quarter slice

    Each row contains:
    1. Original Image
    2. Ground Truth Overlay
    3. Predicted Mask Overlay
    4. Combined Overlay

    Slices are taken along axis 0 of each array.

    Parameters:
        image (np.ndarray): 3D image volume.
        ground_truth (np.ndarray): Ground truth mask, indexed like image.
        prediction (np.ndarray): Predicted mask, indexed like image.

    Raises:
        ValueError: If image is not 3-dimensional.
    """
    # Validate before any plotting so that a bad input does not leave an
    # empty figure behind.
    if len(image.shape) != 3:
        raise ValueError(f"Unexpected image shape: {image.shape}")

    # Calculate slice indices dynamically
    num_slices = image.shape[0]
    slice_indices = [num_slices // 4, num_slices // 2, 3 * num_slices // 4]

    plt.figure(figsize=(24, 18))  # Adjust figure size for three rows

    for row_idx, slice_idx in enumerate(slice_indices):
        image_slice = image[slice_idx, :, :]
        gt_slice = ground_truth[slice_idx, :, :]
        pred_slice = prediction[slice_idx, :, :]

        first_cell = row_idx * 4 + 1

        # Column 1: original image
        plt.subplot(3, 4, first_cell)
        plt.imshow(image_slice, cmap='gray')
        plt.title(f'Original Image - Slice {slice_idx}')
        plt.axis('off')

        # Columns 2-4: the image with one or two masks drawn on top
        overlay_columns = (
            ('Ground Truth', (('Reds', gt_slice),)),
            ('Predicted Mask', (('Blues', pred_slice),)),
            ('Combined Overlay', (('Reds', gt_slice), ('Blues', pred_slice))),
        )
        for offset, (label, layers) in enumerate(overlay_columns, start=1):
            plt.subplot(3, 4, first_cell + offset)
            plt.imshow(image_slice, cmap='gray', alpha=0.8)
            for colormap, mask in layers:
                plt.imshow(mask, cmap=colormap, alpha=0.2)
            plt.title(f'{label} - Slice {slice_idx}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Pick the first entry of the MRI directory that ends with ".nii.gz".
# os.listdir returns entries in arbitrary order, so which file is picked is
# not guaranteed to be stable between runs.
basename = next((f for f in os.listdir(test_mri_dir) if f.endswith(".nii.gz")), None)
if not basename:
    raise FileNotFoundError("No NIfTI files found in the MRI directory.")

# Generate file paths: label and prediction files use the image name with
# the "_0000" part removed.
image_file_path = os.path.join(test_mri_dir, basename)
gt_file_path = os.path.join(ground_truth_dir, basename.replace("_0000", ""))
pred_file_path = os.path.join(predicted_masks_dir, basename.replace("_0000", ""))

# Load the images, ground truth, and predicted masks as numpy arrays
# NOTE(review): the ground truth is read from a path in which 'labelsTs' is
# replaced by 'labelsTs_transformed'; if ground_truth_dir does not contain
# 'labelsTs' the replacement is a no-op. Confirm which folder is intended.
mri_nifti = load_nifti_image(image_file_path)
ground_truth_nifti = load_nifti_image(gt_file_path.replace('labelsTs', 'labelsTs_transformed'))
predicted_nifti = load_nifti_image(pred_file_path)

print(f"Image shape: {mri_nifti.shape}")
print(f"Ground Truth shape: {ground_truth_nifti.shape}")
print(f"Prediction shape: {predicted_nifti.shape}")

# Apply the validation transforms to both the ground truth and predicted data
# NOTE(review): `test_transforms` is not defined in any cell shown in this
# notebook; it must be provided beforehand or these calls raise NameError.
img_transformed, gt_transformed = apply_test_transforms(image_file_path, gt_file_path, test_transforms)
_, pred_transformed = apply_test_transforms(image_file_path, pred_file_path, test_transforms)

# Ensure that the images and masks have the same shape
# NOTE(review): the shape check uses the transformed outputs, but the plot
# below uses the untransformed arrays loaded above — confirm this is intended.
if gt_transformed.shape != pred_transformed.shape:
    print(f"Shape mismatch: GT shape: {gt_transformed.shape}, Pred shape: {pred_transformed.shape}")
else:
    # Plot the individual slices (original image, ground truth, predicted) with original image as background
    visualize_slices_grid(mri_nifti, ground_truth_nifti, predicted_nifti)