In [None]:
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - SAM 3 (Segment Anything Model 3)
<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_pytorch_sam3.ipynb">
      <img alt="Workbench logo" src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" width="32px"><br> Run in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_sam3.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_sam3.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates deploying **SAM 3** on Vertex AI for:
- **Image Segmentation** (text-prompted)
- **Point-Click Segmentation** (coordinate-based)
- **Video Segmentation** (object tracking)

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

### File a bug

File a bug on [GitHub](https://github.com/GoogleCloudPlatform/vertex-ai-samples/issues/new) if you encounter any issue with the notebook.

## Setup

In [None]:
# Install the Vertex AI SDK, the Cloud Storage client, pycocotools (RLE mask decoding), OpenCV (video read/write) and the plotting/imaging libraries used below.
%pip install --upgrade --quiet 'google-cloud-aiplatform>=1.106.0' 'google-cloud-storage' 'pycocotools' 'numpy<2.0' 'opencv-python-headless' 'matplotlib' 'Pillow==11.3.0'

In [None]:
import os
import sys

# In Colab, run the interactive user-authentication flow; elsewhere this is a no-op.
_running_in_colab = "google.colab" in sys.modules
if _running_in_colab:
    from google.colab import auth

    auth.authenticate_user()

In [None]:
import base64
import io
import re
import subprocess
import tempfile
import uuid

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import requests
from google.cloud import aiplatform, storage
from PIL import Image, ImageDraw
from pycocotools import mask as mask_utils

In [None]:
# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 3. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

REGION = ""  # @param {type:"string"}


# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    if not os.environ.get("GOOGLE_CLOUD_REGION"):
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION)
! gcloud config set project $PROJECT_ID

models, endpoints = {}, {}

In [None]:
# GCS bucket for large images/videos (required for images > ~1.1MB and for all videos).
# Normalize the form value: an empty string or the untouched "gs://" placeholder
# means "no bucket", and a trailing "/" is dropped. Anything else must be a
# gs:// URI because the upload/delete helpers slice off the "gs://" scheme.
_bucket_uri = BUCKET_URI.strip().rstrip("/")
if _bucket_uri in ("", "gs:"):
    GCS_BUCKET = None
elif _bucket_uri.startswith("gs://"):
    GCS_BUCKET = _bucket_uri
else:
    raise ValueError(f"BUCKET_URI must start with 'gs://', got {BUCKET_URI!r}.")

# Hugging Face token (required for gated model facebook/sam3). When the form field
# is left empty, fall back to the HF_TOKEN environment variable (if any).
HF_TOKEN = ""  # @param {type:"string"}
HF_TOKEN = HF_TOKEN or os.environ.get("HF_TOKEN", "")

In [None]:
# SAM3 serving container and default server-side mask smoothing.
SAM3_DOCKER_URI = "us-docker.pkg.dev/deeplearning-platform-release/vertex-model-garden/pytorch-inference.cu125.0-4.ubuntu2204.py310"
DEFAULT_MASK_BLUR_SIGMA = 3.5

# Machine type for each supported (accelerator_type, gpu_count) pair.
# L4 machines follow g2-standard-{12 * n}; H100 machines follow a3-highgpu-{n}g.
_SUPPORTED_GPU_COUNTS = (1, 2, 4, 8)
GPU_MACHINE_TYPE_MAP = {
    **{("NVIDIA_L4", n): f"g2-standard-{12 * n}" for n in _SUPPORTED_GPU_COUNTS},
    **{("NVIDIA_H100_80GB", n): f"a3-highgpu-{n}g" for n in _SUPPORTED_GPU_COUNTS},
}


def deploy_sam3_model(
    model_id: str,
    accelerator_type: str = "NVIDIA_L4",
    accelerator_count: int = 1,
    use_dedicated_endpoint: bool = True,
    hf_token: str = None,
) -> tuple:
    """Upload SAM3 as a Vertex AI model and deploy it to a new endpoint.

    Args:
        model_id: Hugging Face model id, passed to the container as MODEL_ID.
        accelerator_type: GPU type, e.g. "NVIDIA_L4" or "NVIDIA_H100_80GB".
        accelerator_count: number of GPUs per replica.
        use_dedicated_endpoint: create the endpoint with
            ``dedicated_endpoint_enabled``.
        hf_token: optional Hugging Face token, forwarded as the HF_TOKEN env var.

    Returns:
        (model, endpoint): the uploaded ``aiplatform.Model`` and the
        ``aiplatform.Endpoint`` it was deployed to (exactly one replica).

    Raises:
        ValueError: if (accelerator_type, accelerator_count) is not a key of
            GPU_MACHINE_TYPE_MAP. This is checked before any cloud resource is
            created.
    """
    machine_type = GPU_MACHINE_TYPE_MAP.get((accelerator_type, accelerator_count))
    if machine_type is None:
        supported = ", ".join(f"{gpu} x{count}" for gpu, count in GPU_MACHINE_TYPE_MAP)
        raise ValueError(
            f"Unsupported accelerator configuration {accelerator_type} "
            f"x{accelerator_count}. Supported: {supported}."
        )

    env_vars = {
        "MODEL_ID": model_id,
        "TASK": "mask-generation",
        "DEPLOY_SOURCE": "notebook",
    }
    if hf_token:
        env_vars["HF_TOKEN"] = hf_token

    model = aiplatform.Model.upload(
        display_name="sam3-notebook",
        serving_container_image_uri=SAM3_DOCKER_URI,
        serving_container_ports=[8080],
        serving_container_predict_route="/predict",
        serving_container_health_route="/health",
        serving_container_environment_variables=env_vars,
        serving_container_shared_memory_size_mb=(16 * 1024),
        serving_container_deployment_timeout=7200,
    )

    # Created after the model upload so a failed upload leaves no orphan endpoint.
    endpoint = aiplatform.Endpoint.create(
        display_name="sam3-endpoint-notebook",
        dedicated_endpoint_enabled=use_dedicated_endpoint,
    )

    try:
        model.deploy(
            endpoint=endpoint,
            machine_type=machine_type,
            accelerator_type=accelerator_type,
            accelerator_count=accelerator_count,
            min_replica_count=1,
            max_replica_count=1,
            deploy_request_timeout=1800,
        )
    except Exception:
        # The caller never receives these handles on failure, so surface them
        # for manual cleanup before re-raising.
        print(
            "Deployment failed; delete these resources manually if needed:\n"
            f"  model:    {model.resource_name}\n"
            f"  endpoint: {endpoint.resource_name}"
        )
        raise
    return model, endpoint


def _get_auth_headers():
    token = (
        subprocess.check_output(["gcloud", "auth", "print-access-token"])
        .decode()
        .strip()
    )
    return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}


def _get_dedicated_dns(endpoint):
    try:
        if (
            hasattr(endpoint, "gca_resource")
            and endpoint.gca_resource.dedicated_endpoint_dns
        ):
            return endpoint.gca_resource.dedicated_endpoint_dns
    except:
        pass
    return None


def call_sam3_endpoint(payload, endpoint, timeout=180):
    """POST ``payload`` to the endpoint's ``:predict`` REST method and return JSON.

    The dedicated DNS name is used when the endpoint exposes one. Otherwise
    the regional host ``{REGION}-aiplatform.googleapis.com`` is used. If that
    request returns 400 and the message names a dedicated domain, the request
    is retried once against that domain.

    Args:
        payload: JSON-serializable request body (``instances``/``parameters``).
        endpoint: ``aiplatform.Endpoint`` to call.
        timeout: per-request timeout in seconds.

    Returns:
        The decoded JSON response.

    Raises:
        RuntimeError: if the final response status is not 200.
    """
    endpoint_id = endpoint.name.split("/")[-1]

    def _predict_url(host):
        return (
            f"https://{host}/v1/projects/{PROJECT_ID}/locations/{REGION}"
            f"/endpoints/{endpoint_id}:predict"
        )

    def _post(url):
        return requests.post(
            url, json=payload, headers=_get_auth_headers(), timeout=timeout
        )

    host = _get_dedicated_dns(endpoint) or f"{REGION}-aiplatform.googleapis.com"
    response = _post(_predict_url(host))

    if response.status_code == 400 and "dedicated domain name" in response.text:
        match = re.search(r"dedicated domain name '([^']+)'", response.text)
        # If the domain cannot be parsed, keep the 400 and report it below
        # instead of crashing with AttributeError on a None match.
        if match:
            response = _post(_predict_url(match.group(1)))

    if response.status_code != 200:
        raise RuntimeError(
            f"Vertex AI error {response.status_code}: {response.text[:500]}"
        )
    return response.json()


def decode_rle_mask(rle):
    """Decode an RLE dict (``size`` + ``counts``) into a uint8 mask array.

    ``counts`` may arrive as a str from JSON. In that case it is UTF-8 encoded
    to bytes before being handed to ``mask_utils.decode``.
    """
    raw_counts = rle["counts"]
    if isinstance(raw_counts, str):
        raw_counts = raw_counts.encode("utf-8")
    decoded = mask_utils.decode({"size": rle["size"], "counts": raw_counts})
    return decoded.astype(np.uint8)


def image_to_base64(img):
    """Serialize ``img`` as PNG and return the bytes base64-encoded as a str."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()


def load_image(path_or_url):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Raises:
        requests.HTTPError: if a URL responds with an error status. Without
            this check, the error page body would be handed to PIL and fail
            with an unhelpful "cannot identify image" error.
    """
    if path_or_url.startswith(("http://", "https://")):
        response = requests.get(path_or_url, timeout=30)
        response.raise_for_status()
        source = io.BytesIO(response.content)
    else:
        source = path_or_url
    # The context manager closes the underlying file once the RGB copy exists.
    with Image.open(source) as img:
        return img.convert("RGB")


def upload_to_gcs(data, gcs_uri, prefix, content_type, is_file=False):
    """Upload bytes or a local file to a uniquely named object under ``gcs_uri``.

    Args:
        data: bytes/str to upload, or a local file path when ``is_file`` is True.
        gcs_uri: destination ``gs://bucket`` or ``gs://bucket/some/path``.
            Leading and trailing slashes on the path are ignored.
        prefix: object-name prefix; a random 8-hex-digit suffix is appended.
        content_type: MIME type stored on the object. It also selects the
            extension: ``.mp4`` when it contains "video", else ``.png``.
        is_file: treat ``data`` as a filename instead of raw content.

    Returns:
        The ``gs://`` URI of the created object.

    Raises:
        ValueError: if ``gcs_uri`` does not start with ``gs://``.
    """
    if not gcs_uri.startswith("gs://"):
        raise ValueError(f"Expected a gs:// URI, got {gcs_uri!r}")
    bucket_name, _, path = gcs_uri[len("gs://"):].partition("/")
    # Strip slashes so "gs://b/dir/" yields "dir/<name>", not "dir//<name>".
    path = path.strip("/")

    ext = ".mp4" if "video" in content_type else ".png"
    object_name = f"{prefix}-{uuid.uuid4().hex[:8]}{ext}"
    blob_path = f"{path}/{object_name}" if path else object_name

    blob = storage.Client().bucket(bucket_name).blob(blob_path)
    if is_file:
        blob.upload_from_filename(data, content_type=content_type)
    else:
        blob.upload_from_string(data, content_type=content_type)
    return f"gs://{bucket_name}/{blob_path}"


def delete_gcs_file(gcs_uri):
    """Best-effort delete of a single ``gs://bucket/object`` URI.

    This is called from ``finally`` blocks, so it must never raise and mask
    the original error. Failures (malformed URI, missing object, permission
    errors) are reported as a warning instead of being silently ignored.
    """
    try:
        bucket_name, blob_path = gcs_uri[len("gs://"):].split("/", 1)
        storage.Client().bucket(bucket_name).blob(blob_path).delete()
    except Exception as exc:  # deliberate best-effort cleanup
        print(f"Warning: could not delete {gcs_uri}: {exc}")


def apply_mask_overlay(image, masks, opacity=0.5):
    """Composite each mask over ``image`` in its own color and return RGB.

    Args:
        image: PIL image, or a numpy array converted with ``Image.fromarray``.
        masks: sized sequence of 2-D mask arrays. These are assumed to be 0/1
            masks as produced by ``decode_rle_mask``; each is scaled by 255
            before use. Masks whose size differs from the image are resized
            with nearest-neighbour sampling.
        opacity: multiplier applied to the mask intensity to get alpha.

    Returns:
        An RGB PIL image. Colors are evenly sampled from matplotlib's
        "rainbow" colormap, one per mask.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    base = pil_image.convert("RGBA")
    if not masks:
        return base.convert("RGB")

    palette = matplotlib.colormaps["rainbow"].resampled(len(masks))
    layer = Image.new("RGBA", base.size, (0, 0, 0, 0))

    def _alpha_for(value):
        # Lookup-table entry: scale mask intensity; background stays transparent.
        return int(value * opacity) if value > 0 else 0

    for idx, mask in enumerate(masks):
        rgb = tuple(int(channel * 255) for channel in palette(idx)[:3])
        mask_layer = Image.fromarray((mask * 255).astype(np.uint8))
        if mask_layer.size != base.size:
            mask_layer = mask_layer.resize(base.size, Image.NEAREST)
        tinted = Image.new("RGBA", base.size, rgb + (0,))
        tinted.putalpha(mask_layer.point(_alpha_for))
        layer = Image.alpha_composite(layer, tinted)

    return Image.alpha_composite(base, layer).convert("RGB")


def draw_points(image, points):
    """Return a copy of ``image`` with a red, white-outlined dot at each (x, y).

    ``image`` may be a PIL image or a numpy array; the input is not modified.
    """
    source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    canvas = source.copy()
    pen = ImageDraw.Draw(canvas)
    radius = 8
    for px, py in points:
        bbox = (px - radius, py - radius, px + radius, py + radius)
        pen.ellipse(bbox, fill="red", outline="white", width=4)
    return canvas


def _predict_image_instance(img, instance_fields, endpoint, gcs_bucket, gcs_prefix):
    """Send one image to the SAM3 endpoint and return its decoded masks.

    Shared by ``segment_image`` and ``segment_by_points``.

    Args:
        img: PIL image to segment.
        instance_fields: prompt fields placed after ``image`` in the instance
            (e.g. ``text`` or ``input_points``, plus ``mask_blur_sigma``).
        endpoint: deployed SAM3 ``aiplatform.Endpoint``.
        gcs_bucket: ``gs://`` URI used for images too large to send inline.
        gcs_prefix: object-name prefix for the temporary GCS upload.

    Returns:
        List of uint8 masks decoded from ``masks_rle`` (empty if none).

    Raises:
        RuntimeError: if the image is too large and no bucket is configured,
            or if the endpoint call fails.
    """
    img_b64 = image_to_base64(img)
    gcs_uri = None

    # len * 0.75 approximates the raw PNG size. Above ~1.1 MB the image is
    # uploaded to GCS and referenced by URI instead of being sent inline.
    if len(img_b64) * 0.75 > 1.1 * 1024 * 1024:
        if not gcs_bucket:
            raise RuntimeError("Image too large. Set GCS_BUCKET for large images.")
        # Reuse the PNG bytes already encoded above instead of re-encoding.
        gcs_uri = upload_to_gcs(
            base64.b64decode(img_b64), gcs_bucket, gcs_prefix, "image/png"
        )
        image_field = gcs_uri
    else:
        image_field = img_b64

    payload = {
        "instances": [{"image": image_field, **instance_fields}],
        "parameters": {"mask_format": "rle"},
    }
    try:
        result = call_sam3_endpoint(payload, endpoint)
    finally:
        # The upload is temporary; remove it whether or not the call succeeded.
        if gcs_uri:
            delete_gcs_file(gcs_uri)

    # A missing or empty "predictions" list means "no masks", not an IndexError.
    predictions = result.get("predictions") or [{}]
    return [decode_rle_mask(rle) for rle in predictions[0].get("masks_rle", [])]


def segment_image(
    image_path_or_url,
    text_prompt,
    endpoint,
    gcs_bucket=None,
    blur_sigma=DEFAULT_MASK_BLUR_SIGMA,
):
    """Text-prompted image segmentation. Returns (original, overlay, masks).

    Args:
        image_path_or_url: local path or http(s) URL of the image.
        text_prompt: text describing the objects to segment.
        endpoint: deployed SAM3 ``aiplatform.Endpoint``.
        gcs_bucket: optional ``gs://`` URI, required for images over ~1.1 MB.
        blur_sigma: ``mask_blur_sigma`` forwarded to the server.
    """
    img = load_image(image_path_or_url)
    masks = _predict_image_instance(
        img,
        {"text": text_prompt, "mask_blur_sigma": blur_sigma},
        endpoint,
        gcs_bucket,
        "sam3-img",
    )
    return img, apply_mask_overlay(img, masks), masks


def segment_by_points(
    image_path_or_url,
    points,
    endpoint,
    gcs_bucket=None,
    blur_sigma=DEFAULT_MASK_BLUR_SIGMA,
):
    """Point-click segmentation. Returns (original, overlay with points, masks).

    Args:
        image_path_or_url: local path or http(s) URL of the image.
        points: list of [x, y] pixel coordinates, sent as ``input_points``.
        endpoint: deployed SAM3 ``aiplatform.Endpoint``.
        gcs_bucket: optional ``gs://`` URI, required for images over ~1.1 MB.
        blur_sigma: ``mask_blur_sigma`` forwarded to the server.
    """
    img = load_image(image_path_or_url)
    masks = _predict_image_instance(
        img,
        {"input_points": points, "mask_blur_sigma": blur_sigma},
        endpoint,
        gcs_bucket,
        "sam3-click",
    )
    overlay = draw_points(apply_mask_overlay(img, masks), points)
    return img, overlay, masks


def _temp_mp4_path():
    """Create an empty temporary ``.mp4`` file and return its path.

    Uses ``tempfile.mkstemp``, which creates the file atomically, instead of
    the deprecated and race-prone ``tempfile.mktemp``.
    """
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    return path


def _read_rgb_frames(video_path, frame_limit):
    """Read up to ``frame_limit`` frames (all if None) as RGB arrays, plus the FPS."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened() and (frame_limit is None or len(frames) < frame_limit):
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    return frames, fps


def _write_mp4(path, rgb_frames, fps, size):
    """Encode RGB frames (numpy arrays or PIL images) into an mp4v video at ``path``.

    Raises:
        RuntimeError: if OpenCV cannot open a writer for ``path``.
    """
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
    if not writer.isOpened():
        raise RuntimeError(f"Could not open a video writer for {path}")
    try:
        for frame in rgb_frames:
            writer.write(cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def segment_video(
    video_path,
    text_prompt,
    endpoint,
    gcs_bucket,
    frame_limit=60,
    timeout=1600,
    blur_sigma=DEFAULT_MASK_BLUR_SIGMA,
):
    """Text-prompted video segmentation. Returns (output_path, overlay_frames, status).

    The first ``frame_limit`` frames are re-encoded and uploaded temporarily to
    ``gcs_bucket``. The endpoint is then called, and the uploaded object is
    always deleted afterwards.

    Args:
        video_path: local video file readable by OpenCV.
        text_prompt: text describing the objects to track.
        endpoint: deployed SAM3 ``aiplatform.Endpoint``.
        gcs_bucket: ``gs://`` URI (required; videos are always sent via GCS).
        frame_limit: maximum number of frames to process; None processes all.
        timeout: request timeout in seconds; values below 1600 are raised to 1600.
        blur_sigma: ``mask_blur_sigma`` forwarded to the server.

    Returns:
        output_path: temporary .mp4 file with the mask overlays (caller owns it).
        overlay_frames: one RGB PIL image per frame that received masks.
        status: human-readable summary string.

    Raises:
        RuntimeError: if no bucket is given, the writer cannot be opened, or
            the endpoint call fails.
        ValueError: if no frames can be read from ``video_path``.
    """
    if not gcs_bucket:
        raise RuntimeError("Video segmentation requires a GCS bucket. Set GCS_BUCKET.")

    frames, fps = _read_rgb_frames(video_path, frame_limit)
    if not frames:
        raise ValueError(f"Could not read any frames from {video_path!r}.")
    # CAP_PROP_FPS can be 0 when OpenCV cannot determine it; the writer needs > 0.
    if not fps or fps <= 0:
        fps = 30.0
    # Size the writer from the decoded frames so it matches what is written.
    height, width = frames[0].shape[:2]
    size = (width, height)

    # Write the (possibly truncated) clip to a temp file and upload it to GCS.
    tmp_path = _temp_mp4_path()
    try:
        _write_mp4(tmp_path, frames, fps, size)
        gcs_uri = upload_to_gcs(
            tmp_path, gcs_bucket, "sam3-vid", "video/mp4", is_file=True
        )
    finally:
        os.unlink(tmp_path)

    try:
        payload = {
            "instances": [
                {"video": gcs_uri, "text": text_prompt, "mask_blur_sigma": blur_sigma}
            ],
            "parameters": {"mask_format": "rle"},
        }
        result = call_sam3_endpoint(payload, endpoint, timeout=max(timeout, 1600))
    finally:
        delete_gcs_file(gcs_uri)

    # A missing or empty "predictions" list means "no masks", not an IndexError.
    predictions = result.get("predictions") or [{}]
    masks_video = predictions[0].get("masks_rle_video", [])

    # zip() stops at the shorter sequence, ignoring any extra mask frames.
    overlay_frames = []
    for frame, frame_masks in zip(frames, masks_video):
        decoded = [decode_rle_mask(rle) for rle in frame_masks] if frame_masks else []
        overlay_frames.append(apply_mask_overlay(Image.fromarray(frame), decoded))

    out_path = _temp_mp4_path()
    _write_mp4(out_path, overlay_frames, fps, size)
    return out_path, overlay_frames, f"Processed {len(masks_video)} frames"

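## Deploy SAM3 Model

`facebook/sam3` is a gated Hugging Face model. Request access on Hugging Face and set `HF_TOKEN` in the Setup section before deploying.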

In [None]:
# Deployment settings. The (ACCELERATOR_TYPE, ACCELERATOR_COUNT) pair is looked up in
# GPU_MACHINE_TYPE_MAP to choose the Vertex AI machine type.
MODEL_ID = "facebook/sam3"  # @param ["facebook/sam3"] {isTemplate:true}
ACCELERATOR_TYPE = "NVIDIA_L4"  # @param ["NVIDIA_L4", "NVIDIA_H100_80GB"] {isTemplate:true}
ACCELERATOR_COUNT = 1  # @param [1, 2, 4, 8]

In [None]:
# @markdown Set `use_dedicated_endpoint` to False if you don't want to use [dedicated endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#create-dedicated-endpoint).
# Passed to deploy_sam3_model, which creates the endpoint with dedicated_endpoint_enabled set to this value.
use_dedicated_endpoint = True  # @param {type:"boolean"}

In [None]:
# Deploy once and register the resources so the Clean Up section can delete them.
sam3_model, sam3_endpoint = deploy_sam3_model(
    model_id=MODEL_ID,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
    use_dedicated_endpoint=use_dedicated_endpoint,
    hf_token=HF_TOKEN,
)
models["sam3"] = sam3_model
endpoints["sam3"] = sam3_endpoint
endpoint = sam3_endpoint
print(f"Endpoint deployed: {endpoint.resource_name}")

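### Connect to Existing Endpoint (Optional)

To reuse an endpoint that is already deployed instead of creating a new one, skip the deployment cell above, then uncomment the lines below and fill in your endpoint ID.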

In [None]:
# ENDPOINT_ID = "YOUR_ENDPOINT_ID"  # Uncomment to use existing endpoint
# endpoint = aiplatform.Endpoint(f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{ENDPOINT_ID}")
# endpoints["sam3"] = endpoint

## Image Segmentation (Text-Prompted)

In [None]:
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # @param {type:"string"}
text_prompt = "cat"  # @param {type:"string"}

original, segmented, masks = segment_image(image_url, text_prompt, endpoint, GCS_BUCKET)
print(f"Found {len(masks)} mask(s)")

# Side-by-side view: the input image and the color-coded mask overlay.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = [
    (original, "Original"),
    (segmented, f"'{text_prompt}' ({len(masks)} masks)"),
]
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

## Point-Click Segmentation

In [None]:
click_image = "http://images.cocodataset.org/val2017/000000039769.jpg"  # @param {type:"string"}
click_points = [[220, 300], [400, 350]]  # @param {type:"raw"}

original, segmented, masks = segment_by_points(
    click_image, click_points, endpoint, GCS_BUCKET
)
print(f"Found {len(masks)} mask(s) for {len(click_points)} point(s)")

# Left: the clicked points on the raw image. Right: masks with the points redrawn on top.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = [
    (draw_points(original, click_points), "Click Points"),
    (segmented, f"Segmentation ({len(masks)} masks)"),
]
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

## Video Segmentation

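> **Note:** Video segmentation requires a GCS bucket: set `BUCKET_URI` in the Setup section. The clip is uploaded to the bucket temporarily and deleted after the prediction request.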

In [None]:
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerMeltdowns.mp4"  # @param {type:"string"}
video_prompt = "person"  # @param {type:"string"}
frame_limit = 60  # @param {type:"integer"}


def _download_to_temp_mp4(url, timeout=400, chunk_size=8192):
    """Stream ``url`` into a new temporary ``.mp4`` file and return its path.

    Uses ``tempfile.mkstemp`` rather than the race-prone ``mktemp``. The
    HTTP response is closed when done, and the partial file is removed if
    the download fails.

    Raises:
        requests.HTTPError: if the server returns an error status.
    """
    fd, path = tempfile.mkstemp(suffix=".mp4")
    try:
        with os.fdopen(fd, "wb") as out_file, requests.get(
            url, timeout=timeout, stream=True
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=chunk_size):
                out_file.write(chunk)
    except BaseException:
        os.unlink(path)
        raise
    return path


if video_url and GCS_BUCKET:
    # Only files this cell downloads are deleted afterwards; local paths are left alone.
    cleanup_temp_video = video_url.startswith(("http://", "https://"))
    if cleanup_temp_video:
        print(f"Downloading video from {video_url}...")
        video_path = _download_to_temp_mp4(video_url)
        print(f"Downloaded to temporary file: {video_path}")
    else:
        video_path = video_url

    try:
        out_path, frames, status = segment_video(
            video_path, video_prompt, endpoint, GCS_BUCKET, frame_limit
        )
        print(f"{status}. Output: {out_path}")

        # Display up to 6 evenly spaced overlay frames.
        if frames:
            indices = np.linspace(0, len(frames) - 1, min(6, len(frames)), dtype=int)
            fig, axes = plt.subplots(2, 3, figsize=(15, 8))
            for i, ax in enumerate(axes.flat):
                if i < len(indices):
                    ax.imshow(frames[indices[i]])
                    ax.set_title(f"Frame {indices[i]}")
                ax.axis("off")
            plt.suptitle(f"Video: '{video_prompt}'")
            plt.tight_layout()
            plt.show()
    finally:
        if cleanup_temp_video and os.path.exists(video_path):
            os.unlink(video_path)
else:
    print("Set video_url and ensure GCS_BUCKET (BUCKET_URI) is configured.")

## Clean Up

In [None]:
# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may incur.

# Undeploy models and delete endpoints. Note: this also deletes an endpoint
# registered via "Connect to Existing Endpoint". Each entry is removed from its
# dict only after a successful delete, so re-running this cell skips resources
# that are already gone.
for key, ep in list(endpoints.items()):
    ep.delete(force=True)
    del endpoints[key]

# Delete models.
for key, mdl in list(models.items()):
    mdl.delete()
    del models[key]

In [None]:
# @markdown Delete temporary GCS buckets.

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME