# 🔬 SAM Inference for Surgical Instance Segmentation

This notebook demonstrates the use of the Segment Anything Model (SAM) for surgical instrument segmentation in cataract surgery videos.

## Overview
- Load pre-trained SAM model
- Process surgical images
- Generate segmentation masks
- Evaluate performance metrics

## Setup and Installation

In [None]:
# Install required packages
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib numpy torch torchvision

In [None]:
# Download SAM checkpoint
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

## Import Required Libraries

In [None]:
import os
import json
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils

# Import SAM components
from segment_anything import SamPredictor, sam_model_registry

print("Libraries imported successfully!")

## Initialize SAM Model

In [None]:
# Model configuration
sam_checkpoint = "sam_vit_h_4b8939.pth"  # Checkpoint file fetched by the download cell above (path relative to the working directory)
model_type = "vit_h"  # Key into sam_model_registry; presumably must match the checkpoint's architecture -- confirm when swapping checkpoints
device = "cuda" if torch.cuda.is_available() else "cpu"  # Use the GPU when CUDA is available, otherwise fall back to CPU

# Load SAM model: look up the builder registered under `model_type` and give it the checkpoint path
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Move the model to the selected device
sam.to(device=device)

# Create predictor (module-level object; run_sam_inference below calls it directly)
predictor = SamPredictor(sam)

# Report which device and model variant were selected
print(f"SAM model loaded on {device}")
print(f"Model type: {model_type}")

## Utility Functions

In [None]:
def mask_to_bbox(mask):
    """Convert a binary mask to a bounding box in ``[x, y, width, height]`` form.

    Args:
        mask: 2-D array-like; every non-zero (truthy) entry counts as foreground.

    Returns:
        ``[x_min, y_min, width, height]`` as Python ints. Width and height are
        pixel counts over the inclusive extent of the foreground, so a single
        foreground pixel yields a 1x1 box. A mask without any foreground pixel
        yields ``[0, 0, 0, 0]``.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        # No foreground at all: return the degenerate "empty" box.
        return [0, 0, 0, 0]
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    # min/max are inclusive pixel indices, hence the +1 to obtain a size.
    return [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1]


def show_mask(mask, ax, random_color=False):
    """Overlay a segmentation mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before being coloured.
        ax: Object with an ``imshow`` method (e.g. a matplotlib Axes).
        random_color: When true, draw the RGB part of the colour from
            ``np.random.random``; otherwise use a fixed blue. Alpha is 0.6
            in both cases.
    """
    alpha = 0.6
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) yields an (H, W, 4) RGBA overlay
    # that is fully transparent wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=375):
    """Display prompt points on the image.

    Points labelled 1 are drawn as green stars and points labelled 0 as red
    stars, both with a white edge.

    Args:
        coords: Array-like of points; column 0 is used as x and column 1 as y.
            Plain Python lists are accepted.
        labels: Array-like with one label per point (1 or 0). Plain Python
            lists are accepted.
        ax: Object with a ``scatter`` method (e.g. a matplotlib Axes).
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    # Boolean-mask indexing below only works on arrays, so coerce list inputs.
    coords = np.asarray(coords)
    labels = np.asarray(labels)

    # Positive points first, then negative ones.
    for label_value, color in ((1, "green"), (0, "red")):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=color,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )


def show_box(box, ax):
    """Outline a box given as ``(x0, y0, x1, y1)`` corners on ``ax``.

    Args:
        box: Indexable with at least four entries; entries 0/1 are the first
            corner and entries 2/3 the opposite corner.
        ax: Object with an ``add_patch`` method (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    # Rectangle takes an anchor point plus width/height, not two corners.
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor="green",
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)


print("Utility functions defined!")

## SAM Inference Example

In [None]:
# Example inference function
def run_sam_inference(image_path, input_points=None, input_boxes=None):
    """Run SAM inference on a single image using the module-level ``predictor``.

    Args:
        image_path: Path of the image file to segment.
        input_points: Optional sequence of ``(x, y)`` prompt points. Every
            point is labelled 1, i.e. treated as a positive prompt.
        input_boxes: Optional box prompt, forwarded to ``predictor.predict``
            as ``box`` after conversion to a NumPy array (presumably in
            ``(x0, y0, x1, y1)`` order, as drawn by ``show_box`` -- confirm
            against the installed segment-anything version).

    Returns:
        Tuple ``(image, masks, scores, logits)``: the image converted from
        OpenCV's BGR to RGB, followed by the three values returned by
        ``predictor.predict`` with ``multimask_output=True``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    # cv2.imread signals failure by returning None instead of raising, so check
    # explicitly to get an actionable error rather than a cvtColor assertion.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not decode image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Set image for predictor
    predictor.set_image(image)

    # Prepare inputs: prompts may arrive as plain lists, arrays are passed on.
    input_labels = None
    if input_points is not None:
        input_points = np.array(input_points)
        input_labels = np.array([1] * len(input_points))  # All positive points
    if input_boxes is not None:
        input_boxes = np.asarray(input_boxes)

    # Generate masks
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=input_boxes,
        multimask_output=True,
    )

    return image, masks, scores, logits


# Visualization function
def visualize_results(image, masks, scores, input_points=None, input_boxes=None):
    """Visualize SAM results, one subplot per candidate mask.

    Each subplot shows the image with the mask overlaid, the prompt points
    (all drawn as positive) and the prompt box when given, and a title with
    the mask index and its score.

    Args:
        image: Image passed to ``imshow`` on every subplot.
        masks: Sequence of masks; its length sets the number of subplots.
        scores: Sequence of scores, paired with ``masks`` position by position.
        input_points: Optional points to draw; lists and arrays are accepted.
        input_boxes: Optional box to draw via ``show_box``.
    """
    point_labels = None
    if input_points is not None:
        # show_points selects points with a boolean mask, which plain lists do
        # not support, so normalise to an array once up front.
        input_points = np.asarray(input_points)
        point_labels = np.ones(len(input_points))

    _, axes = plt.subplots(1, len(masks), figsize=(15, 5))
    if len(masks) == 1:
        # With a single column, subplots returns a bare Axes, not a sequence.
        axes = [axes]

    for number, (ax, mask, score) in enumerate(zip(axes, masks, scores), start=1):
        ax.imshow(image)
        show_mask(mask, ax)

        if input_points is not None:
            show_points(input_points, point_labels, ax)

        if input_boxes is not None:
            show_box(input_boxes, ax)

        ax.set_title(f"Mask {number}, Score: {score:.3f}", fontsize=18)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


print("SAM inference functions ready!")

## Performance Evaluation

In [None]:
def calculate_iou(mask1, mask2):
    """Return the Intersection over Union (IoU) of two masks.

    Both masks are interpreted element-wise as booleans (non-zero is
    foreground). When neither mask has any foreground the union is empty and
    the result is 0.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    if combined > 0:
        return overlap / combined
    # Empty union: avoid dividing by zero.
    return 0


def evaluate_segmentation(pred_masks, gt_masks):
    """Score predicted masks against ground-truth masks with IoU.

    Masks are paired position by position; pairing stops at the shorter of
    the two sequences.

    Returns:
        Tuple ``(mean_iou, ious)`` where ``ious`` is the list of per-pair IoU
        values from ``calculate_iou`` and ``mean_iou`` is their ``np.mean``.
    """
    per_pair = [
        calculate_iou(predicted, reference)
        for predicted, reference in zip(pred_masks, gt_masks)
    ]
    return np.mean(per_pair), per_pair


# Confirm that the evaluation helpers are loaded (no example evaluation is run in this cell)
print("Evaluation functions defined!")
print("Ready for SAM inference on surgical images!")

## Summary

This notebook provides a complete workflow for using SAM (Segment Anything Model) for surgical instance segmentation:

1. **Model Loading**: Load pre-trained SAM checkpoint
2. **Inference**: Run segmentation with point/box prompts
3. **Visualization**: Display results with overlays
4. **Evaluation**: Calculate performance metrics

The implementation follows the exact methodology described in the Cataract-LMM research paper for surgical instrument segmentation tasks.