[![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Labellerr%20template/notebook.webp)](https://www.labellerr.com)

# Segment Anything Model 2 (SAM2)

---

[![labellerr](https://img.shields.io/badge/Labellerr-BLOG-black.svg)](https://www.labellerr.com/blog/<BLOG_NAME>)
[![Youtube](https://img.shields.io/badge/Labellerr-YouTube-b31b1b.svg)](https://www.youtube.com/@Labellerr)
[![Github](https://img.shields.io/badge/Labellerr-GitHub-green.svg)](https://github.com/Labellerr/Hands-On-Learning-in-Computer-Vision)
[![Scientific Paper](https://img.shields.io/badge/Official-Paper-blue.svg)](https://arxiv.org/abs/2408.00714)

# About SAM2
Segment Anything Model 2 (SAM2) is Meta AI's successor to the Segment Anything Model (SAM), extending promptable segmentation from still images to video. It is a transformer-based model with a streaming memory. Video frames are processed one at a time, and memory attention conditions the current frame on stored memories of earlier frames and prompts, which lets SAM2 track objects across a video in real time. Its hierarchical (Hiera) image encoder produces multi-scale features, and the mask decoder uses the high-resolution levels to recover fine details and small objects. On images, SAM2 is reported to be more accurate than SAM while running about 6x faster. Smaller variants (tiny, small, base+) trade some accuracy for lower compute. Training relied on a model-in-the-loop data engine that produced the SA-V dataset (roughly 51K videos with 35.5M masks). SAM2 accepts point, box, and mask prompts; this notebook uses box and point prompts. Applications span real-time video analytics, medical imaging, and robotics.

In [None]:
! uv pip install -q torch torchvision numpy matplotlib opencv-python pillow

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import torch
import torchvision
import requests
# Report the library versions and whether a CUDA GPU is visible to PyTorch.
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"CUDA is available: {torch.cuda.is_available()}")

In [None]:
!uv pip install 'git+https://github.com/facebookresearch/sam2.git'

In [None]:
!mkdir -p ../checkpoints/
!wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# SAM2.1 Hiera-Large: model config plus the checkpoint downloaded above.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"

# Build the network on the selected device and wrap it in the image-prompt predictor.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
# !wget -O elephant.jpg http://farm8.staticflickr.com/7193/6956100130_8bfc1afaa1_z.jpg
# masks, scores = box_segmentor(
#     image_path="elephant.jpg",
#     box_coords=[ 95.1795, 156.5467, 255.3454, 307.8629]
# )

# masks, scores

In [None]:
import requests
def _load_rgb_image(source):
    """Return ``source`` (http(s) URL, local path, or PIL image) as an RGB numpy array."""
    import io

    if isinstance(source, Image.Image):
        return np.array(source.convert("RGB"))
    if str(source).startswith(("http://", "https://")):
        # A timeout prevents a stalled download from hanging forever.
        # raise_for_status() turns HTTP errors into exceptions instead of
        # feeding an error page to PIL.
        response = requests.get(str(source), timeout=30)
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as pil_image:
            return np.array(pil_image.convert("RGB"))
    # ``with`` closes the file handle once convert() has copied the pixels.
    with Image.open(source) as pil_image:
        return np.array(pil_image.convert("RGB"))


def box_segmentor(image_path, box_coords, model_checkpoint=None, device='auto',
                  model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml"):
    """
    Segment objects in an image with SAM2, using one bounding-box prompt per object.

    The SAM2 predictor is built lazily on the first call and cached on the
    function object as ``box_segmentor.predictor``. It is rebuilt only when a
    later call asks for a different checkpoint, config or device.

    NOTE(review): this builds its own SAM2 model, separate from any predictor
    created elsewhere in the notebook, so two large models may be resident at once.

    Args:
        image_path (str | os.PathLike | PIL.Image.Image): An http(s) URL, a local
            file path, or an already-loaded PIL image.
        box_coords (list): List of [x_min, y_min, x_max, y_max] boxes.
        model_checkpoint (str): Optional path to a SAM2 checkpoint
            (default "../checkpoints/sam2.1_hiera_large.pt").
        device (str | torch.device): 'cuda', 'cpu', a torch.device, or 'auto'
            (default) to use CUDA when available.
        model_cfg (str): SAM2 config matching the checkpoint
            (default: the Hiera-Large SAM2.1 config).

    Returns:
        tuple: (masks_list, scores_list), with one mask and one score per box,
        in the same order as ``box_coords``.

    Raises:
        ValueError: If a box does not contain exactly 4 values.
        requests.HTTPError: If a URL download returns an error status.
    """
    # Normalise to a torch.device so ``device.type`` below also works for strings.
    if isinstance(device, str) and device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    checkpoint = model_checkpoint or "../checkpoints/sam2.1_hiera_large.pt"
    cache_key = (checkpoint, model_cfg, str(device))
    # Reuse the cached predictor unless the requested model/device changed.
    if (not hasattr(box_segmentor, 'predictor')
            or getattr(box_segmentor, '_predictor_key', None) != cache_key):
        sam2_model = build_sam2(model_cfg, checkpoint, device=device)
        box_segmentor.predictor = SAM2ImagePredictor(sam2_model)
        box_segmentor._predictor_key = cache_key
    predictor = box_segmentor.predictor

    image = _load_rgb_image(image_path)

    masks_list = []
    scores_list = []
    with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
        # Encode the image once; every box prompt below reuses that embedding.
        predictor.set_image(image)
        for box in box_coords:
            # Shape (1, 4); reshape raises ValueError if the box is not 4 numbers.
            box_array = np.asarray(box, dtype=np.float32).reshape(1, 4)
            # predict returns (masks, scores, logits); logits are not needed here.
            predicted_masks, predicted_scores, _ = predictor.predict(
                box=box_array,
                multimask_output=False,
            )
            masks_list.append(predicted_masks[0])
            scores_list.append(predicted_scores[0])

    return masks_list, scores_list


In [None]:
def display_masks(image, masks, alpha=0.5, random_color=True, borders=True):
    """
    Display multiple masks over an image, each in its own color.

    Masks are painted in the order given, so where masks overlap the later
    mask is shown on top.

    Args:
        image: PIL.Image or (H, W, 3) numpy array - Original image
        masks: Iterable of 2D arrays of shape (H, W); values > 0.5 are foreground
        alpha: float (0-1) - Opacity of the mask fills
        random_color: bool - Use a random color per mask (otherwise a fixed blue)
        borders: bool - Draw white outlines around each mask's outer contours
    """
    if isinstance(image, Image.Image):
        img_np = np.array(image.convert("RGB"))
    else:
        img_np = np.asarray(image)
    h, w = img_np.shape[:2]

    # RGBA canvas drawn over the image; alpha 0 leaves the image untouched.
    overlay = np.zeros((h, w, 4), dtype=np.float32)
    border_color = (1, 1, 1, 0.8)  # white outlines at 80% opacity

    # A single pass over ``masks``, so a generator of masks is consumed only once.
    for mask in masks:
        if random_color:
            color = np.append(np.random.random(3), alpha)
        else:
            color = np.array([30/255, 144/255, 255/255, alpha])

        # reshape also accepts a leading singleton axis such as (1, H, W).
        binary_mask = (np.asarray(mask).reshape(h, w) > 0.5).astype(np.uint8)

        mask_overlay = np.zeros((h, w, 4))
        mask_overlay[binary_mask == 1] = color

        if borders:
            contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Simplify each outline with a tolerance of 1% of its perimeter.
            contours = [cv2.approxPolyDP(c, 0.01 * cv2.arcLength(c, True), True) for c in contours]
            cv2.drawContours(mask_overlay, contours, -1, border_color, 2)

        # Copy whole RGBA pixels wherever this mask painted something. Selecting
        # on the alpha channel (rather than per channel) stops a color component
        # that happens to be 0 from letting an earlier mask show through.
        painted = mask_overlay[..., 3] > 0
        overlay[painted] = mask_overlay[painted]

    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img_np)
    ax.imshow(overlay)
    ax.axis('off')
    plt.show()

In [None]:
from PIL import Image, ImageDraw
import random

def draw_boxes_on_image(image, input_boxes, colors=None, line_width=3):
    """
    Draw bounding boxes on a copy of an image.

    Args:
        image: PIL.Image or file path (str / os.PathLike)
        input_boxes: List of boxes in [x_min, y_min, x_max, y_max] format
        colors: None (a random RGB color per box), a single color (name string
            or tuple of numbers, applied to every box), or a list/tuple with
            one color per box
        line_width: Width of box borders in pixels (default: 3)

    Returns:
        PIL.Image with the boxes drawn. The input image is not modified.

    Raises:
        ValueError: If a box does not have 4 values, or the number of colors
            does not match the number of boxes.
    """
    import numbers
    import os

    # Load from disk when a path is given; ``with`` closes the file handle
    # once convert() has copied the pixels.
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    if not all(len(box) == 4 for box in input_boxes):
        raise ValueError("Boxes must be in [x_min, y_min, x_max, y_max] format")

    # Integer pixel coordinates, reordered so x0 <= x1 and y0 <= y1
    # (recent Pillow versions reject rectangles with reversed corners).
    boxes = []
    for box in input_boxes:
        x0, y0, x1, y1 = (int(v) for v in box)
        boxes.append((min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)))

    if colors is None:
        colors = [
            (
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255)
            ) for _ in input_boxes
        ]
    # A color name or a tuple of numbers such as (255, 0, 0) is one color for
    # all boxes. A tuple of colors is treated as one color per box.
    elif isinstance(colors, str) or (
        isinstance(colors, tuple) and all(isinstance(c, numbers.Number) for c in colors)
    ):
        colors = [colors] * len(input_boxes)

    if len(colors) != len(input_boxes):
        raise ValueError("Number of colors must match number of boxes")

    # Draw on a copy so the caller's image is untouched. Single-band modes
    # (e.g. 'L') cannot take RGB colors, so they are promoted to RGB.
    if image.mode in ("RGB", "RGBA"):
        image = image.copy()
    else:
        image = image.convert("RGB")

    draw = ImageDraw.Draw(image)
    for box, color in zip(boxes, colors):
        draw.rectangle(box, outline=color, width=line_width)

    return image

In [None]:
import io

url1 = "https://farm8.staticflickr.com/7331/9239614558_26e5a50351_z.jpg"
# Download with a timeout and fail loudly on HTTP errors. Decoding the
# buffered body (not the raw socket stream) also handles compressed responses.
response1 = requests.get(url1, timeout=30)
response1.raise_for_status()
img1 = Image.open(io.BytesIO(response1.content)).convert("RGB")
img1

In [None]:
# One [x_min, y_min, x_max, y_max] box per object in img1.
input_boxes = [
    [308.9709, 116.7895, 503.7764, 258.8303],
    [172.0639, 186.9951, 297.3459, 328.9594],
]
masks, scores = box_segmentor(image_path=url1, box_coords=input_boxes)

In [None]:
# Show the box prompts that were fed to the segmentor.
draw_boxes_on_image(img1, input_boxes, line_width=2)



In [None]:
display_masks(image=img1, masks=masks, alpha=0.5, random_color=True, borders=False)

In [None]:
import io

url2 = "http://farm6.staticflickr.com/5455/9293164411_47fae6c6cb_z.jpg"
# Download the same image that box_segmentor segments below. The timeout
# keeps a stalled connection from hanging the notebook.
response2 = requests.get(url2, timeout=30)
response2.raise_for_status()
img2 = Image.open(io.BytesIO(response2.content)).convert("RGB")

input_boxes = [[178.88, 97.84, 586.23, 395.83]]

masks, scores = box_segmentor(image_path=url2, box_coords=input_boxes)


In [None]:
draw_boxes_on_image(img2, input_boxes, line_width=2)


In [None]:
display_masks(image=img2, masks=masks, alpha=0.5, random_color=True, borders=False)

---

## Segmentation using point prompts

In [None]:

def show_mask(mask, ax, random_color=False, borders=True):
    """Draw one mask on a matplotlib axis as a translucent (alpha 0.6) colored layer.

    Args:
        mask: Array whose last two dimensions are (H, W), holding 0/1 values.
        ax: Matplotlib axes to draw on.
        random_color: Use a random RGB color instead of the default blue.
        borders: Also draw white outlines around the mask's outer contours.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    mask_u8 = mask.astype(np.uint8)
    layer = mask_u8.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2
        outlines, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # NOTE: epsilon is an absolute distance (0.01), so this only very lightly
        # simplifies the outlines.
        outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        layer = cv2.drawContours(layer, outlines, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(layer)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter point prompts as stars: label 1 in green, label 0 in red.

    Args:
        coords: Array of (x, y) points, indexed by a boolean mask from ``labels``.
        labels: Array with one label per point.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to ``scatter``.
    """
    style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **style)

def show_box(box, ax):
    """Outline an [x_min, y_min, x_max, y_max] box on ``ax`` with a green, unfilled rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show each mask over ``image`` in its own figure, with optional prompts overlaid.

    A "Mask i, Score: s" title is added only when more than one score is given.
    ``input_labels`` is required whenever ``point_coords`` is given.
    """
    several = len(scores) > 1
    for idx, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axis = plt.gca()
        show_mask(mask, axis, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axis)
        if box_coords is not None:
            show_box(box_coords, axis)
        if several:
            plt.title(f"Mask {idx}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

In [None]:
def point_segmentor(
    image,
    point_sets,
    label_sets,
    predictor,
    show_individual=False,
    show_combined=True,
    multimask_output=False
):
    """
    Segment objects in an image from sets of point prompts with SAM2 and visualize them.

    Each entry of ``point_sets`` (with its matching ``label_sets`` entry) is one
    prompt and is sent to ``predictor.predict`` in a single call. All points in a
    set therefore describe ONE object; use separate sets for separate objects.

    Args:
        image (np.ndarray | PIL.Image.Image): Input image, passed to ``predictor.set_image``.
        point_sets (list of array-like): One (N_points, 2) array of (x, y) points per object.
        label_sets (list of array-like): One (N_points,) array of labels (1=fg, 0=bg) per object.
        predictor (SAM2ImagePredictor): Predictor object.
        show_individual (bool): If True, show each segmentation result individually.
        show_combined (bool): If True, show the best mask of every set overlaid on the image.
        multimask_output (bool): If True, the predictor returns multiple masks per prompt set.

    Returns:
        all_masks (list): The masks array returned for each point set.
        all_scores (list): The scores array returned for each point set.

    Raises:
        ValueError: If the number of point sets and label sets differ, or a set
            has a different number of points and labels.
    """
    # Accept plain lists too; the plotting helpers rely on numpy boolean indexing.
    point_sets = [np.asarray(pts) for pts in point_sets]
    label_sets = [np.asarray(lbls) for lbls in label_sets]

    # Validate before the (expensive) image encoding. zip() alone would
    # silently drop unmatched sets.
    if len(point_sets) != len(label_sets):
        raise ValueError(
            f"got {len(point_sets)} point sets but {len(label_sets)} label sets"
        )
    for idx, (pts, lbls) in enumerate(zip(point_sets, label_sets)):
        if len(pts) != len(lbls):
            raise ValueError(
                f"point set {idx} has {len(pts)} points but {len(lbls)} labels"
            )

    predictor.set_image(image)
    all_masks = []
    all_scores = []

    # One predict call per prompt set.
    for pts, lbls in zip(point_sets, label_sets):
        masks, scores, _ = predictor.predict(
            point_coords=pts,
            point_labels=lbls,
            multimask_output=multimask_output
        )
        all_masks.append(masks)
        all_scores.append(scores)

    if show_individual:
        for masks, scores, pts, lbls in zip(all_masks, all_scores, point_sets, label_sets):
            show_masks(
                image,
                masks,
                scores,
                point_coords=pts,
                input_labels=lbls,
                borders=True
            )

    if show_combined:
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        # Pair each masks array with its own scores directly. Looking the scores
        # up via list.index() compared numpy arrays with ==, which raises for
        # more than one set.
        for masks, scores, pts, lbls in zip(all_masks, all_scores, point_sets, label_sets):
            best_idx = int(np.argmax(scores))
            show_mask(masks[best_idx], plt.gca(), random_color=True)
            show_points(pts, lbls, plt.gca())
        plt.title("All Segmented Objects (Best Masks)")
        plt.axis('off')
        plt.show()

    return all_masks, all_scores


In [None]:
# Display the first image again to help choose (x, y) point prompts.
img1

In [None]:
# Example: ONE prompt set containing two positive (label 1) clicks.
# point_segmentor sends each set to the predictor in a single call, so these
# two points produce a single mask. To get one mask per object, give each
# object its own entry in point_sets / label_sets.
# NOTE(review): the two clicks fall inside the two boxes used for img1 above,
# so they may have been meant as two separate objects — confirm intent.
point_sets = [
    np.array([[406, 187 ], [220, 258]]) # (x, y) clicks for this single prompt set
]
label_sets = [
    np.array([1, 1])                   # 1 = foreground, one label per click
]

masks, scores = point_segmentor(img1, point_sets, label_sets, predictor, show_combined = True)


In [None]:
# Display the second image to help choose an (x, y) point prompt.
img2

In [None]:
point_sets = [np.array([[350, 246]])]  # one prompt set with a single (x, y) click
label_sets = [np.array([1])]           # 1 = foreground point

masks, scores = point_segmentor(img2, point_sets, label_sets, predictor, show_combined=True)