# SAM3 Image Segmentation - Interactive Notebook

This notebook provides an interactive interface for SAM3 image segmentation using text prompts.

Features:
- Upload images or load from URL
- Text prompt-based segmentation
- Adjustable score threshold
- Limit number of instances displayed
- Colored instance overlays

## Setup

Install `ipympl`, which provides the interactive matplotlib backend (`%matplotlib widget`) used below:

In [None]:
pip install ipympl

In [7]:
pip install -e c:\dev\sam3\sam3

Obtaining file:///C:/dev/sam3/sam3
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Collecting numpy==1.26 (from sam3==0.1.0)
  Using cached numpy-1.26.0-cp311-cp311-win_amd64.whl.metadata (61 kB)
Using cached numpy-1.26.0-cp311-cp311-win_amd64.whl (15.8 MB)
Building wheels for collected packages: sam3
  Building editable for sam3 (pyproject.toml): started
  Building editable for sam3 (pyproject.toml): finished with status 'done'
  Created wheel for sam3: filename=sam3-0.1.0-0.editable-py3-none-any.whl size=15274 sha256=3b5557b47c00580e790a94473ee

  You can safely remove it manually.
  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
buildingregulariser 0.2.4 requires numpy>=2.0.0, but you have numpy 1.26.0 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.0 which is incompatible.


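**Restart the kernel** after the install above (it changes installed packages, e.g. it pins `numpy` to 1.26), then continue with the cells below.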

In [None]:
# Switch matplotlib to the interactive ipympl backend (requires `ipympl`).
%matplotlib widget

In [1]:
# Sanity check: show which sam3 package the kernel actually resolves.
import sam3

print(f"sam3 path: {sam3.__file__}")

sam3 path: C:\dev\sam3\sam3\sam3\__init__.py


## Import Libraries

In [2]:
import os
import sys

# Prefer the local sam3 checkout over an installed copy by putting the
# repository root (two levels above the notebook's working directory) first
# on sys.path. Notebooks have no `__file__`, so the kernel's working directory
# stands in for the notebook location. (The previous version called
# os.path.abspath('__file__') on the *string* '__file__', which only worked
# because its dirname happens to be the cwd.)
# NOTE: this only affects modules imported after this cell; `sam3` may
# already be cached in sys.modules from the earlier sanity-check cell.
notebook_dir = os.getcwd()
sam3_root = os.path.abspath(os.path.join(notebook_dir, '..', '..'))
if sam3_root not in sys.path:
    sys.path.insert(0, sam3_root)

import io
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import requests

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

## Load SAM3 Model

This may take a moment on first run as it downloads the model weights.

In [3]:
print("[SAM3] Building image model...")
# Build the model once and wrap it in the processor; `run_sam3_segmentation`
# reads the global name `processor` created here.
processor = Sam3Processor(model := build_sam3_image_model())
print("[SAM3] Image model ready.")

[SAM3] Building image model...
[SAM3] Image model ready.


## Helper Functions

In [4]:
def random_color_for_idx(idx: int) -> np.ndarray:
    """Return a deterministic pseudo-random RGB color for an instance index.

    The same ``idx`` always maps to the same color, so instance colors are
    stable across runs.

    Args:
        idx: Instance index used to seed the generator.

    Returns:
        ``np.ndarray`` of shape ``(3,)`` and dtype ``uint8`` with each channel
        drawn from the full 0-255 range.
    """
    rng = np.random.RandomState(idx + 12345)
    # `high` is exclusive, so 256 is required for 255 to be reachable.
    return rng.randint(0, 256, size=3, dtype=np.uint8)


def overlay_instance_masks(
    image: Image.Image,
    masks: np.ndarray,
    scores: np.ndarray,
    score_thresh: float = 0.0,
    max_instances: int | None = None,
    alpha: float = 0.6,
) -> Image.Image:
    """Blend colored instance masks on top of an image.

    Instances scoring below ``score_thresh`` are dropped, the rest are sorted
    by score (highest first) and optionally truncated to ``max_instances``.
    Each drawn instance gets a deterministic color from
    ``random_color_for_idx`` keyed by its rank after sorting.

    Args:
        image: Input PIL image; blending is done on its RGB conversion.
        masks: Per-instance masks, one per score. Each mask may carry extra
            singleton axes (e.g. ``[1, H, W]`` or ``[H, W, 1]``); values
            ``> 0.5`` count as foreground. Masks whose size differs from the
            image are resized to the image size.
        scores: Per-instance confidence scores, one per mask.
        score_thresh: Minimum score (inclusive) for an instance to be drawn.
        max_instances: Draw at most this many instances; ``None`` or a value
            ``<= 0`` means no limit.
        alpha: Opacity of the color tint, clipped to ``[0, 1]``.

    Returns:
        A new RGB image with the overlay, or the original ``image`` unchanged
        when there are no masks or no instance passes the threshold.
    """
    if masks is None or len(masks) == 0:
        return image

    masks = np.asarray(masks)
    # Flatten so scalar or (N, 1) scores line up with the instance axis.
    scores = np.asarray(scores).reshape(-1)

    keep = scores >= score_thresh
    if not keep.any():
        return image

    # Passing instances, highest score first, optionally truncated.
    order = np.argsort(-scores[keep])
    masks = masks[keep][order]
    if max_instances is not None and max_instances > 0:
        masks = masks[:max_instances]

    # Blend in RGB so a 3-channel color also works for L / RGBA / P inputs.
    overlay = np.array(image.convert("RGB"), dtype=np.uint8)
    target_h, target_w = overlay.shape[:2]
    alpha = float(np.clip(alpha, 0.0, 1.0))

    for i, mask in enumerate(masks):
        m_bin = _mask_to_binary(mask, target_h, target_w)
        if not m_bin.any():
            continue

        color = random_color_for_idx(i).astype(np.float32)
        # Alpha-blend the instance color into the pixels covered by the mask.
        overlay[m_bin] = (
            (1.0 - alpha) * overlay[m_bin].astype(np.float32) + alpha * color
        ).astype(np.uint8)

    return Image.fromarray(overlay)


def _mask_to_binary(mask: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Reduce one instance mask to a boolean ``(target_h, target_w)`` array.

    Singleton axes (e.g. ``[1, H, W]`` or ``[H, W, 1]``) are removed one at a
    time, never going below 2-D. Any remaining extra leading axes are merged
    by taking their union. Values ``> 0.5`` count as foreground. Masks of a
    different size are resized to the target size. Degenerate (0-D/1-D)
    masks yield an all-False mask.
    """
    m = np.asarray(mask)
    while m.ndim > 2:
        singleton_axes = [ax for ax, size in enumerate(m.shape) if size == 1]
        if singleton_axes:
            m = np.squeeze(m, axis=singleton_axes[0])
        else:
            # Non-singleton extra axes (e.g. several candidate masks) cannot
            # be squeezed away; merge them instead of looping forever.
            m = m.reshape(-1, *m.shape[-2:]).max(axis=0)
    if m.ndim != 2:
        return np.zeros((target_h, target_w), dtype=bool)

    m_bin = m > 0.5
    if m_bin.shape != (target_h, target_w):
        # PIL sizes are (width, height); a 2-D uint8 array becomes mode "L".
        mask_img = Image.fromarray((m_bin * 255).astype(np.uint8))
        mask_img = mask_img.resize((target_w, target_h), Image.BILINEAR)
        m_bin = np.asarray(mask_img) > 127
    return m_bin


def run_sam3_segmentation(
    image: Image.Image,
    prompt: str,
    score_thresh: float = 0.0,
    max_instances: int | None = None,
    sam_processor=None,
):
    """Segment ``image`` with a free-text ``prompt`` and render the overlay.

    Args:
        image: PIL image to segment; converted to RGB if needed.
        prompt: Text prompt; surrounding whitespace is stripped.
        score_thresh: Minimum score for an instance to be drawn / counted as
            kept.
        max_instances: Draw at most this many instances; ``None`` or a value
            ``<= 0`` means no limit.
        sam_processor: Object providing ``set_image`` and ``set_text_prompt``.
            Defaults to the notebook-global ``processor``.

    Returns:
        ``(overlaid_image, summary)``. On success ``summary`` is a dict with
        keys ``total``, ``kept``, ``threshold`` and ``max_shown``. Otherwise
        it is a message string and ``overlaid_image`` is ``None`` (missing
        image/prompt) or the RGB input image (model returned no masks/scores).
    """
    if image is None or prompt is None or prompt.strip() == "":
        return None, "Please provide an image and a non-empty text prompt."

    if sam_processor is None:
        sam_processor = processor  # global built in the model-loading cell

    if image.mode != "RGB":
        image = image.convert("RGB")

    state = sam_processor.set_image(image)
    output = sam_processor.set_text_prompt(state=state, prompt=prompt.strip())

    masks = output.get("masks", None)
    scores = output.get("scores", None)
    if masks is None or scores is None:
        return image, "Model returned no masks or scores."

    masks = _to_numpy(masks)
    # Flatten so a scalar or (N, 1) score tensor still supports len().
    scores = _to_numpy(scores).reshape(-1)

    overlaid = overlay_instance_masks(
        image,
        masks,
        scores,
        score_thresh=score_thresh,
        max_instances=max_instances,
    )

    num_total = len(scores)
    num_kept = int((scores >= score_thresh).sum())
    limited = max_instances is not None and max_instances > 0
    summary = {
        "total": num_total,
        "kept": num_kept,
        "threshold": score_thresh,
        # Never more than passed the threshold. Upper bound: instances whose
        # mask is empty are skipped when drawing.
        "max_shown": min(max_instances, num_kept) if limited else num_kept,
    }

    return overlaid, summary


def _to_numpy(x) -> np.ndarray:
    """Convert a torch tensor (or array-like) to a NumPy array on the CPU."""
    if torch.is_tensor(x):
        x = x.detach().cpu()
        # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise.
        if x.dtype == torch.bfloat16:
            x = x.float()
        return x.numpy()
    return np.asarray(x)

## Interactive Widget

Run the cell below to launch the interactive segmentation interface!

In [None]:
class ImageSegmentationWidget:
    """Interactive ipywidgets UI for SAM3 text-prompted image segmentation.

    The user loads an image (file upload or URL), enters a text prompt,
    chooses a score threshold and an instance limit, then clicks the run
    button. The overlay returned by ``run_sam3_segmentation`` is drawn with
    matplotlib inside an ``Output`` widget, and an HTML status line reports
    progress and errors.
    """

    def __init__(self, processor):
        # Stored on the instance; inference itself is delegated to
        # run_sam3_segmentation().
        self.processor = processor
        self.current_image = None   # last successfully loaded RGB image
        self.current_result = None  # image returned by the last run
        self._fig = None            # figure currently shown in output_area

        self._setup_ui()

    def _setup_ui(self):
        """Create the UI components and wire up their callbacks."""
        # Image upload widget
        self.upload_widget = widgets.FileUpload(
            accept='image/*',
            multiple=False,
            description='Upload Image'
        )
        self.upload_widget.observe(self._on_image_upload, names='value')

        # URL input
        self.url_input = widgets.Text(
            placeholder='Or enter image URL',
            description='Image URL:',
            layout=widgets.Layout(width='500px')
        )
        self.url_button = widgets.Button(
            description='Load URL',
            button_style='info'
        )
        self.url_button.on_click(self._on_load_url)

        # Text prompt input
        self.prompt_input = widgets.Text(
            placeholder='e.g. "a person", "a dog", "white truck"',
            description='Text Prompt:',
            layout=widgets.Layout(width='500px')
        )

        # Score threshold slider
        self.score_slider = widgets.FloatSlider(
            value=0.0,
            min=0.0,
            max=1.0,
            step=0.01,
            description='Score Threshold:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='500px')
        )

        # Max instances input (0 means "show all")
        self.max_instances_input = widgets.IntText(
            value=0,
            description='Max Instances (0=all):',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='250px')
        )

        # Run button
        self.run_button = widgets.Button(
            description='Run SAM3 Segmentation',
            button_style='success',
            icon='play'
        )
        self.run_button.on_click(self._on_run_segmentation)

        # Status label
        self.status_label = widgets.HTML(
            value='<b>Upload an image to begin</b>'
        )

        # Output area for the matplotlib figure
        self.output_area = widgets.Output()

        # Layout
        self.container = widgets.VBox([
            widgets.HTML('<h2>SAM3 Image Segmentation - Interactive</h2>'),
            widgets.HTML('<hr>'),
            widgets.HTML('<h3>1. Load Image</h3>'),
            self.upload_widget,
            widgets.HBox([self.url_input, self.url_button]),
            widgets.HTML('<h3>2. Configure Segmentation</h3>'),
            self.prompt_input,
            self.score_slider,
            self.max_instances_input,
            self.run_button,
            widgets.HTML('<hr>'),
            self.status_label,
            self.output_area,
        ])

    @staticmethod
    def _escape(text):
        """HTML-escape ``text`` so prompts / error messages render literally."""
        import html
        return html.escape(str(text))

    def _set_status(self, message, color=''):
        """Show ``message`` (HTML-escaped, bold, optionally colored) as status."""
        style = f' style="color: {color};"' if color else ''
        self.status_label.value = f'<b{style}>{self._escape(message)}</b>'

    def _report_loaded(self):
        """Report the size of the currently loaded image in the status line."""
        width, height = self.current_image.size
        self._set_status(f'Image loaded: {width}x{height} pixels', 'green')

    @staticmethod
    def _extract_upload_bytes(value):
        """Return the bytes of the first uploaded file, or None if there is none.

        ``FileUpload.value`` is a tuple of dicts in ipywidgets 8 and a
        ``{filename: {...}}`` dict in ipywidgets 7. Both carry the file data
        under a ``'content'`` key (``memoryview`` in 8, ``bytes`` in 7).
        """
        if not value:
            return None
        files = list(value.values()) if isinstance(value, dict) else list(value)
        if not files:
            return None
        return bytes(files[0]['content'])

    def _format_summary(self, prompt, summary):
        """Build the HTML status shown after a successful segmentation run."""
        return (
            '<b style="color: green;">Segmentation Complete!</b><br>'
            f'<b>Prompt:</b> "{self._escape(prompt)}"<br>'
            f"<b>Total instances found:</b> {summary['total']}<br>"
            f"<b>Above threshold ({summary['threshold']:.2f}):</b> {summary['kept']}<br>"
            f"<b>Instances shown:</b> {summary['max_shown']}"
        )

    def _on_image_upload(self, change):
        """Load the uploaded file as the current RGB image and preview it."""
        content = self._extract_upload_bytes(change['new'])
        if content is None:
            return
        try:
            self.current_image = Image.open(io.BytesIO(content)).convert('RGB')
            self._report_loaded()
            self._display_image(self.current_image)
        except Exception as e:
            self._set_status(f'Error loading image: {e}', 'red')

    def _on_load_url(self, button):
        """Download the image at the entered URL and preview it."""
        url = self.url_input.value.strip()
        if not url:
            self._set_status('Please enter a URL', 'red')
            return

        self._set_status('Loading image from URL...', 'blue')

        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            self.current_image = Image.open(io.BytesIO(response.content)).convert('RGB')
            self._report_loaded()
            self._display_image(self.current_image)
        except Exception as e:
            self._set_status(f'Error loading image: {e}', 'red')

    def _on_run_segmentation(self, button):
        """Run segmentation on the current image with the UI's settings."""
        if self.current_image is None:
            self._set_status('Please load an image first', 'red')
            return

        prompt = self.prompt_input.value.strip()
        if not prompt:
            self._set_status('Please enter a text prompt', 'red')
            return

        self._set_status(f'Segmenting with prompt: "{prompt}"...', 'blue')

        max_inst = self.max_instances_input.value
        max_inst = max_inst if max_inst > 0 else None

        # Block re-clicks while the (slow) model call is running.
        self.run_button.disabled = True
        try:
            result_img, summary = run_sam3_segmentation(
                self.current_image,
                prompt,
                score_thresh=self.score_slider.value,
                max_instances=max_inst,
            )
            self.current_result = result_img

            if isinstance(summary, dict):
                self.status_label.value = self._format_summary(prompt, summary)
            else:
                self._set_status(summary, 'orange')

            if result_img is not None:
                self._display_image(result_img)
        except Exception as e:
            self._set_status(f'Error: {e}', 'red')
        finally:
            self.run_button.disabled = False

    def _display_image(self, image):
        """Render ``image`` in the output area, replacing the previous figure."""
        # pyplot keeps every figure it creates open until plt.close(). The
        # inline backend closes figures on show, but the ipympl
        # (`%matplotlib widget`) backend does not, so close the old one first
        # to avoid accumulating figures on every redraw.
        if self._fig is not None:
            plt.close(self._fig)
            self._fig = None
        with self.output_area:
            clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(12, 8))
            ax.imshow(image)
            ax.axis('off')
            fig.tight_layout()
            plt.show()
        self._fig = fig

    def display(self):
        """Render the widget container in the notebook."""
        display(self.container)


# Create and display the widget
# Instantiate the UI around the globally loaded SAM3 processor and render it.
(widget := ImageSegmentationWidget(processor)).display()

VBox(children=(HTML(value='<h2>SAM3 Image Segmentation - Interactive</h2>'), HTML(value='<hr>'), HTML(value='<…