# Lab 2.2.6: Segment Anything Model (SAM) Integration

**Module:** 2.2 - Computer Vision  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand what makes SAM a "foundation model" for segmentation
- [ ] Use SAM for automatic mask generation
- [ ] Create interactive segmentation with point/box prompts
- [ ] Leverage DGX Spark's 128GB memory for efficient SAM inference

---

## 📚 Prerequisites

- Completed: Labs 2.2.1-2.2.5
- Knowledge of: Image segmentation, transformers

---

## 🌍 Real-World Context

**SAM is revolutionizing image editing and analysis:**

- 📸 **Photo editing**: One-click background removal, object selection
- 🏥 **Medical imaging**: Quickly annotate organs, tumors, cells
- 🛰️ **Satellite imagery**: Segment buildings, roads, vegetation
- 🎬 **Video production**: Object tracking and rotoscoping
- 🤖 **Robotics**: Understanding scene for manipulation

---

## 🧒 ELI5: What is Segment Anything Model?

> **Imagine a magical magnifying glass...** 🔍
>
> When you point it at anything in a picture and say "this one", it instantly draws a perfect outline around that thing - whether it's a cat, a cup, or a cloud!
>
> **What makes SAM special:**
> 1. **Zero-shot**: Works on objects it's never seen before
> 2. **Promptable**: Tell it WHAT to segment with a point, a box, or a rough mask (the paper also explores text prompts, but the released code does not include them)
> 3. **Foundation model**: Trained on 11 million images, 1.1 billion masks
>
> It's like having an expert annotator who can instantly segment anything you point at!

### SAM Architecture Overview

```
┌─────────────────────────────────────────────────────────────┐
│                    Segment Anything Model                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Image ──► Image Encoder ──► Image Embedding                │
│            (ViT-H/16)          (256×64×64)                  │
│                                      │                      │
│                                      ▼                      │
│  Prompt ──► Prompt Encoder ──► Prompt Embedding             │
│  (points, boxes, masks)                │                    │
│                                        ▼                    │
│                              ┌─────────────────┐            │
│                              │  Mask Decoder   │            │
│                              │  (lightweight)  │            │
│                              └────────┬────────┘            │
│                                       │                     │
│                                       ▼                     │
│                              Predicted Masks + IoU Scores   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```

**Key insight**: The image encoder is heavy (runs once), but the mask decoder is lightweight (runs many times for different prompts on the same image)!
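To make the embedding size in the diagram concrete, here is a small sketch of where a 64×64 grid comes from. It assumes a 1024-pixel square model input, a patch size of 16, and 256 output channels; the helper name `sam_embedding_shape` is made up for this illustration and is not part of the `segment_anything` package.

```python
from typing import Tuple


def sam_embedding_shape(
    input_size: int = 1024,   # assumed side length of the resized, padded input
    patch_size: int = 16,     # assumed ViT patch size
    out_channels: int = 256,  # assumed embedding channels
) -> Tuple[int, int, int]:
    """Return (channels, grid_h, grid_w) for a square input split into patches."""
    if input_size % patch_size != 0:
        raise ValueError("input_size must be a multiple of patch_size")
    grid = input_size // patch_size  # patches along each side
    return (out_channels, grid, grid)


shape = sam_embedding_shape()
num_tokens = shape[1] * shape[2]  # spatial positions the mask decoder attends over
print(shape)       # (256, 64, 64)
print(num_tokens)  # 4096
```

The encoder produces this tensor once per image; every later prompt reuses it, which is why per-prompt predictions can be so much cheaper than the initial encoding.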

In [None]:
# Setup
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import urllib.request
from pathlib import Path
import time
from typing import List, Tuple, Optional, Dict

# Check cv2 installation
try:
    import cv2
except ImportError:
    print("⚠️ OpenCV not found. Installing...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "opencv-python", "-q"])
    import cv2

# Check device and memory
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"💾 GPU Memory: {gpu_mem:.1f} GB")

    # DGX Spark advantage!
    if gpu_mem > 100:
        print("🚀 DGX Spark detected! You can load the largest SAM model (ViT-H)!")

In [None]:
# Install segment-anything if needed (with ARM64/DGX Spark compatibility)
import sys
import subprocess

try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
    print("✅ SAM library already installed!")
except ImportError:
    print("⚠️ segment-anything not found.")
    print("📦 Installing segment-anything...")
    print("   (This may take a moment on first run)")

    try:
        # Install segment-anything from GitHub
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "git+https://github.com/facebookresearch/segment-anything.git", "-q"],
            capture_output=True, text=True, timeout=300
        )

        if result.returncode != 0:
            print(f"❌ Installation failed: {result.stderr}")
            print("\n🔧 For DGX Spark ARM64, try installing manually in your NGC container:")
            print("   pip install git+https://github.com/facebookresearch/segment-anything.git")
            raise ImportError("segment-anything installation failed")

        from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
        print("✅ SAM library installed successfully!")

    except subprocess.TimeoutExpired:
        print("❌ Installation timed out.")
        print("\n🔧 Please install manually before running this notebook:")
        print("   pip install git+https://github.com/facebookresearch/segment-anything.git")
        raise ImportError("segment-anything installation timed out")
    except Exception as e:
        print(f"❌ Installation error: {e}")
        print("\n🔧 Please install manually before running this notebook:")
        print("   pip install git+https://github.com/facebookresearch/segment-anything.git")
        raise

---

## Part 1: Loading SAM Models

SAM comes in three sizes:

In [None]:
# SAM model variants
sam_models = {
    'vit_b': {'checkpoint': 'sam_vit_b_01ec64.pth', 'size': '375 MB', 'speed': 'Fast'},
    'vit_l': {'checkpoint': 'sam_vit_l_0b3195.pth', 'size': '1.2 GB', 'speed': 'Medium'},
    'vit_h': {'checkpoint': 'sam_vit_h_4b8939.pth', 'size': '2.5 GB', 'speed': 'Slow (Best quality)'},
}

print("📊 SAM Model Variants:")
print("="*60)
print(f"{'Model':<12} {'Checkpoint Size':>18} {'Speed':>25}")
print("-"*60)
for name, info in sam_models.items():
    print(f"{name:<12} {info['size']:>18} {info['speed']:>25}")

print("\n💡 For DGX Spark with 128GB, we can easily use ViT-H!")

In [None]:
def download_sam_checkpoint(model_type: str = 'vit_b') -> Path:
    """
    Download SAM checkpoint if not present.
    
    Args:
        model_type: One of 'vit_b', 'vit_l', 'vit_h'
    """
    checkpoints = {
        'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
        'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
        'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    }
    
    checkpoint_dir = Path('../data/sam_checkpoints')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    url = checkpoints[model_type]
    filename = url.split('/')[-1]
    filepath = checkpoint_dir / filename
    
    if not filepath.exists():
        print(f"📥 Downloading {model_type} checkpoint...")
        urllib.request.urlretrieve(url, filepath)
        print(f"✅ Downloaded to {filepath}")
    else:
        print(f"✅ Checkpoint already exists: {filepath}")
    
    return filepath

# Download the base model (fast to download, good for demos)
# For best quality on DGX Spark, use 'vit_h'
model_type = 'vit_b'  # Change to 'vit_h' for best quality
checkpoint_path = download_sam_checkpoint(model_type)
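An interrupted download of a multi-gigabyte checkpoint can leave a truncated file behind, and a plain `filepath.exists()` check would then accept it on the next run. One possible hardening, sketched here as an optional variation rather than the notebook's method, is to download into a temporary file in the same directory and rename it only on success. `download_atomically` is a hypothetical helper name.

```python
import os
import tempfile
import urllib.request
from pathlib import Path


def download_atomically(url: str, dest: Path) -> Path:
    """Download url to dest so that dest only ever holds a complete file."""
    dest = Path(dest)
    if dest.exists():
        return dest
    dest.parent.mkdir(parents=True, exist_ok=True)

    # Temporary file in the same directory, so the final rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=dest.parent, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_name)
        os.replace(tmp_name, dest)  # Atomic rename on POSIX filesystems
    finally:
        # Only true if the download or the rename failed
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
    return dest
```

With this helper, the body of `download_sam_checkpoint` could call `download_atomically(url, filepath)` in place of `urllib.request.urlretrieve(url, filepath)`; a failed download then leaves no file for the existence check to find.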

In [None]:
# Clear buffer cache before loading large model (DGX Spark best practice)
import subprocess
import gc

print("🧹 Preparing for model load...")

# Clear Python garbage
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Try to clear system buffer cache (requires passwordless sudo, optional)
try:
    # -n makes sudo fail immediately instead of waiting for a password prompt
    subprocess.run(['sudo', '-n', 'sh', '-c', 'sync; echo 3 > /proc/sys/vm/drop_caches'],
                   check=True, capture_output=True, timeout=10)
    print("✅ System buffer cache cleared")
except Exception as e:
    print(f"ℹ️  Buffer cache clear skipped (optional): {type(e).__name__}")

# Load SAM model
print(f"\n🔧 Loading SAM {model_type}...")
start_time = time.time()

sam = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
sam.to(device=device)

load_time = time.time() - start_time
print(f"✅ Model loaded in {load_time:.1f}s")

# Memory usage
if torch.cuda.is_available():
    mem_allocated = torch.cuda.memory_allocated() / 1e9
    mem_reserved = torch.cuda.memory_reserved() / 1e9
    print(f"💾 GPU Memory Used: {mem_allocated:.1f} GB (Reserved: {mem_reserved:.1f} GB)")

---

## Part 2: Download Sample Images

In [None]:
def download_sample_images():
    """
    Download sample images for SAM demos with fallback support.
    
    Uses multiple URL sources for reliability and creates placeholder
    images if all downloads fail.
    """
    sample_dir = Path('../data/sam_samples')
    sample_dir.mkdir(parents=True, exist_ok=True)
    
    # URLs with fallbacks for reliability
    urls = {
        'dogs.jpg': [
            'https://upload.wikimedia.org/wikipedia/commons/thumb/d/d9/Collage_of_Nine_Dogs.jpg/800px-Collage_of_Nine_Dogs.jpg',
            'https://upload.wikimedia.org/wikipedia/commons/thumb/2/26/YellowLabradorLooking_new.jpg/1200px-YellowLabradorLooking_new.jpg',
        ],
        'groceries.jpg': [
            'https://images.unsplash.com/photo-1542838132-92c53300491e?w=800',
            'https://upload.wikimedia.org/wikipedia/commons/thumb/5/5a/Fruits_and_vegetables.jpg/800px-Fruits_and_vegetables.jpg',
        ],
        'room.jpg': [
            'https://images.unsplash.com/photo-1586023492125-27b2c045efd7?w=800',
            'https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/Guestroom_at_the_Westin_Seattle.jpg/1200px-Guestroom_at_the_Westin_Seattle.jpg',
        ],
    }
    
    # Custom headers to avoid blocking
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    
    images = {}
    for name, url_list in urls.items():
        filepath = sample_dir / name
        
        if filepath.exists():
            print(f"✅ {name} already exists")
            images[name] = filepath
            continue

        downloaded = False
        for url in url_list:
            try:
                print(f"📥 Downloading {name}...")
                request = urllib.request.Request(url, headers=headers)
                with urllib.request.urlopen(request, timeout=30) as response:
                    with open(filepath, 'wb') as f:
                        f.write(response.read())

                # Verify it's a valid image
                img = Image.open(filepath)
                img.verify()
                downloaded = True
                print(f"   ✅ Downloaded from {url[:50]}...")
                break
            except Exception as e:
                print(f"   ⚠️ Failed: {type(e).__name__}")
                if filepath.exists():
                    filepath.unlink()  # Remove corrupted file

        if not downloaded:
            print(f"   📝 Creating placeholder for {name}")
            # Create a colorful placeholder image
            h, w = 600, 800
            placeholder = np.zeros((h, w, 3), dtype=np.uint8)
            # Add gradient background
            for i in range(h):
                placeholder[i, :, 0] = int(255 * i / h)  # Red gradient
                placeholder[i, :, 2] = int(255 * (1 - i / h))  # Blue gradient
            placeholder[:, :, 1] = 128  # Green constant
            # Add some shapes for segmentation testing
            cv2.rectangle(placeholder, (100, 100), (300, 300), (255, 255, 255), -1)
            cv2.circle(placeholder, (500, 300), 100, (0, 255, 0), -1)
            cv2.rectangle(placeholder, (400, 400), (700, 550), (255, 0, 255), -1)
            Image.fromarray(placeholder).save(filepath)
        
        images[name] = filepath
    
    return images

sample_images = download_sample_images()
print(f"\n✅ Sample images ready: {list(sample_images.keys())}")

In [None]:
def load_image(path: Path) -> Tuple[np.ndarray, Image.Image]:
    """Load image and return both numpy array (for SAM) and PIL image (for display)."""
    pil_image = Image.open(path).convert('RGB')
    np_image = np.array(pil_image)
    return np_image, pil_image

# Load a sample image
sample_path = list(sample_images.values())[0]
image_np, image_pil = load_image(sample_path)

plt.figure(figsize=(10, 8))
plt.imshow(image_np)
plt.title(f'Sample Image: {sample_path.name}\nSize: {image_np.shape}')
plt.axis('off')
plt.show()

---

## Part 3: Automatic Mask Generation

SAM can automatically segment everything in an image!

In [None]:
# Create automatic mask generator
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,  # Grid density for point prompts
    pred_iou_thresh=0.88,  # Confidence threshold
    stability_score_thresh=0.95,
    crop_n_layers=1,  # Multi-scale
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Filter tiny masks
)

print("✅ Automatic mask generator created!")

In [None]:
# Generate masks automatically
print("🔍 Generating masks automatically...")
start_time = time.time()

masks = mask_generator.generate(image_np)

gen_time = time.time() - start_time
print(f"✅ Generated {len(masks)} masks in {gen_time:.1f}s")

# Print mask statistics
print("\n📊 Mask Statistics:")
areas = [m['area'] for m in masks]
print(f"   Area range: {min(areas):,} - {max(areas):,} pixels")
print(f"   Average area: {np.mean(areas):,.0f} pixels")
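Each entry returned by `mask_generator.generate` is a dictionary. Besides `segmentation` and `area`, it carries quality fields such as `predicted_iou` and `stability_score`. The sketch below shows one way to post-filter and rank those records. It uses hand-written stand-in records (without the `segmentation` array) so it runs without the model; `filter_masks` and the threshold values are illustrative assumptions, not part of the library.

```python
from typing import Any, Dict, List


def filter_masks(
    records: List[Dict[str, Any]],
    min_area: int = 0,
    min_iou: float = 0.0,
    min_stability: float = 0.0,
) -> List[Dict[str, Any]]:
    """Keep records that pass every threshold, best predicted IoU first."""
    kept = [
        r for r in records
        if r["area"] >= min_area
        and r["predicted_iou"] >= min_iou
        and r["stability_score"] >= min_stability
    ]
    return sorted(kept, key=lambda r: r["predicted_iou"], reverse=True)


# Stand-in records mimicking the keys used in this notebook
demo_records = [
    {"area": 12000, "predicted_iou": 0.97, "stability_score": 0.98},
    {"area": 80, "predicted_iou": 0.99, "stability_score": 0.99},    # too small
    {"area": 5000, "predicted_iou": 0.90, "stability_score": 0.96},  # IoU too low
]

best = filter_masks(demo_records, min_area=100, min_iou=0.92)
print([r["area"] for r in best])  # [12000]
```

On real output you could call `filter_masks(masks, min_area=500)` before visualizing, which is often easier than re-running the generator with stricter settings.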

In [None]:
def show_masks(image: np.ndarray, masks: List[Dict], figsize: Tuple = (12, 10)):
    """
    Visualize automatically generated masks.
    """
    plt.figure(figsize=figsize)
    plt.imshow(image)
    
    # Sort masks by area (largest first)
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    
    ax = plt.gca()
    for mask_data in sorted_masks:
        mask = mask_data['segmentation']
        color = np.random.random(3)
        
        # Create colored mask
        colored_mask = np.zeros((*mask.shape, 4))
        colored_mask[mask] = [*color, 0.5]  # RGBA with alpha
        
        ax.imshow(colored_mask)
    
    plt.title(f'🎨 Automatic Segmentation: {len(masks)} objects found')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

show_masks(image_np, masks)

In [None]:
def show_individual_masks(image: np.ndarray, masks: List[Dict], num_show: int = 9):
    """
    Show individual masks in a grid.
    """
    # Sort by area and take top N
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)[:num_show]
    
    rows = int(np.ceil(np.sqrt(num_show)))
    cols = int(np.ceil(num_show / rows))
    
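    # squeeze=False always returns a 2-D array of axes, so flatten() also works for a 1x1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(15, 12), squeeze=False)
    axes = axes.flatten()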
    
    for idx, (ax, mask_data) in enumerate(zip(axes, sorted_masks)):
        mask = mask_data['segmentation']
        
        # Create masked image
        masked_image = image.copy()
        masked_image[~mask] = 255  # White background
        
        ax.imshow(masked_image)
        ax.set_title(f"Area: {mask_data['area']:,}\nIoU: {mask_data['predicted_iou']:.2f}", fontsize=9)
        ax.axis('off')
    
    # Hide empty subplots
    for ax in axes[len(sorted_masks):]:
        ax.axis('off')
    
    plt.suptitle('🔍 Top Individual Masks (by area)', fontsize=14)
    plt.tight_layout()
    plt.show()

show_individual_masks(image_np, masks)

---

## Part 4: Interactive Segmentation with Prompts

SAM's real power: segment specific objects by pointing at them!

In [None]:
# Create predictor for interactive prompts
predictor = SamPredictor(sam)

# Set the image (encode it once, reuse for multiple prompts)
print("🔧 Encoding image...")
start_time = time.time()

predictor.set_image(image_np)

encode_time = time.time() - start_time
print(f"✅ Image encoded in {encode_time:.1f}s")
print("   Now we can make fast predictions with different prompts!")
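Prompts passed to `SamPredictor.predict` are given in the pixel coordinates of the original image; the predictor rescales them to the model's input size internally. The sketch below illustrates that kind of rescaling under the assumption that the longest image side is scaled to 1024 pixels and sizes are rounded to the nearest pixel. The function names are hypothetical, and the rounding details should be read as an assumption rather than a specification of the library.

```python
from typing import Tuple


def preprocess_shape(old_h: int, old_w: int, long_side: int = 1024) -> Tuple[int, int]:
    """Return (new_h, new_w) after scaling the longest side to long_side."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def scale_point(
    point: Tuple[float, float],
    old_hw: Tuple[int, int],
    long_side: int = 1024,
) -> Tuple[float, float]:
    """Map an (x, y) point from original-image pixels to resized-image pixels."""
    old_h, old_w = old_hw
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    x, y = point
    return (x * new_w / old_w, y * new_h / old_h)


print(preprocess_shape(600, 800))            # (768, 1024)
print(scale_point((400, 300), (600, 800)))   # (512.0, 384.0)
```

In practice this means you can pick prompt coordinates straight from the displayed image without doing any scaling yourself.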

### Point Prompts

Click on an object to segment it!

In [None]:
def segment_with_point(
    predictor: SamPredictor,
    image: np.ndarray,
    point: Tuple[int, int],
    point_label: int = 1  # 1 for foreground, 0 for background
):
    """
    Segment object at a given point.
    
    Args:
        point: (x, y) coordinates
        point_label: 1 for "include this", 0 for "exclude this"
    """
    input_point = np.array([point])
    input_label = np.array([point_label])
    
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True  # Return multiple masks with different granularity
    )
    
    return masks, scores

def visualize_point_segmentation(
    image: np.ndarray,
    point: Tuple[int, int],
    masks: np.ndarray,
    scores: np.ndarray
):
    """
    Visualize segmentation results from a point prompt.
    """
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    # Original with point
    axes[0].imshow(image)
    axes[0].scatter(point[0], point[1], c='red', s=200, marker='*', edgecolors='white', linewidths=2)
    axes[0].set_title('Input Point')
    axes[0].axis('off')
    
    # Show all 3 candidate masks. The candidates are not documented to be
    # ordered by size, so label them by index rather than small/medium/large.
    from matplotlib.colors import to_rgb  # Public color-conversion helper
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    titles = ['Mask 1', 'Mask 2', 'Mask 3']

    for idx, (mask, score, ax) in enumerate(zip(masks, scores, axes[1:])):
        ax.imshow(image)

        # Create colored mask
        colored_mask = np.zeros((*mask.shape, 4))
        color_rgb = to_rgb(colors[idx])
        colored_mask[mask] = [*color_rgb, 0.6]
        ax.imshow(colored_mask)

        ax.scatter(point[0], point[1], c='red', s=100, marker='*', edgecolors='white')
        ax.set_title(f'{titles[idx]}\nScore: {score:.2f}')
        ax.axis('off')
    
    plt.suptitle('👆 Point Prompt → Multiple Mask Options', fontsize=14)
    plt.tight_layout()
    plt.show()

# Try a point prompt (adjust coordinates based on your image)
h, w = image_np.shape[:2]
point = (w // 2, h // 2)  # Center of image

masks, scores = segment_with_point(predictor, image_np, point)
visualize_point_segmentation(image_np, point, masks, scores)

### Box Prompts

Draw a bounding box to segment the object inside!

In [None]:
def segment_with_box(
    predictor: SamPredictor,
    box: Tuple[int, int, int, int]  # x1, y1, x2, y2
):
    """
    Segment object within a bounding box.
    """
    input_box = np.array(box)
    
    masks, scores, logits = predictor.predict(
        box=input_box,
        multimask_output=False  # Single mask for boxes
    )
    
    return masks[0], scores[0]

def visualize_box_segmentation(
    image: np.ndarray,
    box: Tuple[int, int, int, int],
    mask: np.ndarray,
    score: float
):
    """
    Visualize box prompt segmentation.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    x1, y1, x2, y2 = box
    
    # Original with box
    axes[0].imshow(image)
    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, edgecolor='lime', linewidth=3)
    axes[0].add_patch(rect)
    axes[0].set_title('Input Box')
    axes[0].axis('off')
    
    # Mask overlay
    axes[1].imshow(image)
    colored_mask = np.zeros((*mask.shape, 4))
    colored_mask[mask] = [0.2, 0.8, 0.8, 0.6]
    axes[1].imshow(colored_mask)
    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, edgecolor='lime', linewidth=2)
    axes[1].add_patch(rect)
    axes[1].set_title(f'Segmentation (Score: {score:.2f})')
    axes[1].axis('off')
    
    # Extracted object
    extracted = image.copy()
    extracted[~mask] = 255
    axes[2].imshow(extracted)
    axes[2].set_title('Extracted Object')
    axes[2].axis('off')
    
    plt.suptitle('📦 Box Prompt → Precise Segmentation', fontsize=14)
    plt.tight_layout()
    plt.show()

# Try a box prompt (adjust based on your image)
h, w = image_np.shape[:2]
box = (w//4, h//4, 3*w//4, 3*h//4)  # Central box

mask, score = segment_with_box(predictor, box)
visualize_box_segmentation(image_np, box, mask, score)
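A practical detail when chaining the two APIs: the `bbox` field of each automatically generated mask is in XYWH format (top-left corner plus width and height), while the box prompt for `predictor.predict` is in XYXY format (two corners). The helpers below are a minimal sketch of the conversion, plus clipping to the image bounds; their names are made up for this example.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(bbox: Sequence[float]) -> Box:
    """Convert (x, y, width, height) to (x1, y1, x2, y2)."""
    x, y, w, h = bbox
    return (x, y, x + w, y + h)


def clip_box(box: Sequence[float], width: int, height: int) -> Box:
    """Clamp an XYXY box so both corners lie inside a width x height image."""
    x1, y1, x2, y2 = box
    return (
        min(max(x1, 0), width - 1),
        min(max(y1, 0), height - 1),
        min(max(x2, 0), width - 1),
        min(max(y2, 0), height - 1),
    )


print(xywh_to_xyxy((10, 20, 30, 40)))          # (10, 20, 40, 60)
print(clip_box((-5, 10, 900, 700), 800, 600))  # (0, 10, 799, 599)
```

For example, `segment_with_box(predictor, xywh_to_xyxy(masks[0]['bbox']))` would re-segment the region of the first automatic mask using a box prompt.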

### Combining Multiple Prompts

Use positive and negative points together for precise control!

In [None]:
def segment_with_multi_prompts(
    predictor: SamPredictor,
    positive_points: List[Tuple[int, int]],
    negative_points: Optional[List[Tuple[int, int]]] = None,
    box: Optional[Tuple[int, int, int, int]] = None
):
    """
    Segment with multiple prompt types.
    
    Args:
        positive_points: Points that should be included
        negative_points: Points that should be excluded
        box: Optional bounding box
    """
    if negative_points is None:
        negative_points = []
    
    all_points = positive_points + negative_points
    labels = [1] * len(positive_points) + [0] * len(negative_points)
    
    input_points = np.array(all_points)
    input_labels = np.array(labels)
    input_box = np.array(box) if box else None
    
    masks, scores, _ = predictor.predict(
        point_coords=input_points if len(all_points) > 0 else None,
        point_labels=input_labels if len(all_points) > 0 else None,
        box=input_box,
        multimask_output=False
    )
    
    return masks[0], scores[0]

def visualize_multi_prompt(
    image: np.ndarray,
    positive_points: List[Tuple[int, int]],
    negative_points: List[Tuple[int, int]],
    mask: np.ndarray
):
    """
    Visualize multi-prompt segmentation.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Prompts
    axes[0].imshow(image)
    for p in positive_points:
        axes[0].scatter(p[0], p[1], c='green', s=200, marker='*', edgecolors='white', linewidths=2)
    for p in negative_points:
        axes[0].scatter(p[0], p[1], c='red', s=200, marker='x', linewidths=3)
    axes[0].set_title('✓ Positive (green) ✗ Negative (red) points')
    axes[0].axis('off')
    
    # Result
    axes[1].imshow(image)
    colored_mask = np.zeros((*mask.shape, 4))
    colored_mask[mask] = [0.3, 0.7, 0.3, 0.6]
    axes[1].imshow(colored_mask)
    axes[1].set_title('Segmentation Result')
    axes[1].axis('off')
    
    plt.suptitle('🎯 Multi-Prompt Segmentation', fontsize=14)
    plt.tight_layout()
    plt.show()

# Example with positive and negative points
h, w = image_np.shape[:2]
positive = [(w//2, h//2)]  # Include center
negative = [(50, 50), (w-50, h-50)]  # Exclude corners

mask, _ = segment_with_multi_prompts(predictor, positive, negative)
visualize_multi_prompt(image_np, positive, negative, mask)

---

## Part 5: Benchmarking on DGX Spark

In [None]:
def benchmark_sam(predictor: SamPredictor, image: np.ndarray, num_runs: int = 10):
    """
    Benchmark SAM inference speed.
    """
    results = {
        'encode_time': [],
        'point_inference': [],
        'box_inference': [],
    }
    
    h, w = image.shape[:2]
    point = np.array([[w//2, h//2]])
    label = np.array([1])
    box = np.array([w//4, h//4, 3*w//4, 3*h//4])
    
    def sync():
        """Wait for queued GPU work so wall-clock timings are accurate."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def timed(fn) -> float:
        """Run fn once and return its duration in seconds."""
        sync()
        start = time.perf_counter()
        fn()
        sync()
        return time.perf_counter() - start

    # Benchmark image encoding
    print("🏎️ Benchmarking image encoding...")
    predictor.set_image(image)  # Warm-up run (not timed)
    for _ in range(num_runs):
        results['encode_time'].append(timed(lambda: predictor.set_image(image)))

    # Benchmark point prediction (the image is already encoded)
    print("🏎️ Benchmarking point prediction...")
    for _ in range(num_runs):
        results['point_inference'].append(timed(
            lambda: predictor.predict(point_coords=point, point_labels=label, multimask_output=True)
        ))

    # Benchmark box prediction
    print("🏎️ Benchmarking box prediction...")
    for _ in range(num_runs):
        results['box_inference'].append(timed(
            lambda: predictor.predict(box=box, multimask_output=False)
        ))
    
    return results

# Run benchmark
benchmark_results = benchmark_sam(predictor, image_np, num_runs=10)

# Display results
print("\n" + "="*60)
print("📊 SAM BENCHMARK RESULTS ON DGX SPARK")
print("="*60)
print(f"Image size: {image_np.shape}")
print(f"Model: SAM {model_type}")
print("-"*60)
print(f"{'Operation':<25} {'Mean (ms)':>15} {'Std (ms)':>15}")
print("-"*60)

for op, times in benchmark_results.items():
    mean_ms = np.mean(times) * 1000
    std_ms = np.std(times) * 1000
    print(f"{op:<25} {mean_ms:>14.1f} {std_ms:>14.1f}")

print("="*60)
print(f"\n💡 Key insight: After encoding ({np.mean(benchmark_results['encode_time'])*1000:.0f}ms),")
print(f"   predictions are fast ({np.mean(benchmark_results['point_inference'])*1000:.0f}ms)!")
print("   This enables interactive applications.")
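To see why encoding once matters for interactive use, it helps to add up the latencies. The sketch below compares re-encoding for every prompt against encoding once; the 1000 ms and 50 ms figures are placeholder values chosen for illustration, not measurements, so substitute the means from your own benchmark run. `total_latency_ms` is a hypothetical helper.

```python
def total_latency_ms(
    encode_ms: float,
    predict_ms: float,
    num_prompts: int,
    reencode_each_time: bool,
) -> float:
    """Total time to answer num_prompts prompts on a single image."""
    if reencode_each_time:
        return num_prompts * (encode_ms + predict_ms)
    return encode_ms + num_prompts * predict_ms


# Placeholder timings (assumptions, not measurements)
ENCODE_MS, PREDICT_MS, NUM_PROMPTS = 1000.0, 50.0, 20

slow = total_latency_ms(ENCODE_MS, PREDICT_MS, NUM_PROMPTS, reencode_each_time=True)
fast = total_latency_ms(ENCODE_MS, PREDICT_MS, NUM_PROMPTS, reencode_each_time=False)
print(slow, fast, slow / fast)  # 21000.0 2000.0 10.5
```

The gap widens as the number of prompts per image grows, because the encoding cost is paid only once in the second case.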

---

## Part 6: Practical Applications

In [None]:
def background_removal(
    image: np.ndarray, predictor: SamPredictor, point: Tuple[int, int]
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Remove background from image using SAM.
    
    Returns:
        (rgba, mask): RGBA image with transparent background, and the boolean mask used
    """
    predictor.set_image(image)
    
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    
    # Use highest scoring mask
    best_mask = masks[scores.argmax()]
    
    # Create RGBA image
    rgba = np.zeros((*image.shape[:2], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = (best_mask * 255).astype(np.uint8)
    
    return rgba, best_mask

# Demo: Background removal
h, w = image_np.shape[:2]
rgba_result, mask = background_removal(image_np, predictor, (w//2, h//2))

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image_np)
axes[0].set_title('Original')
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title('Mask')
axes[1].axis('off')

# Show the result over a checkerboard so the removed background is visible
tile = 10  # Square size in pixels
tiles = (np.indices((h, w)) // tile).sum(axis=0) % 2  # Alternating 0/1 squares
checkerboard = np.where(tiles[..., None] == 0, 200, 255).astype(np.uint8).repeat(3, axis=2)
result_vis = checkerboard.copy()
result_vis[mask] = image_np[mask]

axes[2].imshow(result_vis)
axes[2].set_title('Background Removed')
axes[2].axis('off')

plt.suptitle('✂️ One-Click Background Removal with SAM', fontsize=14)
plt.tight_layout()
plt.show()

---

## ✋ Try It Yourself

1. **Try your own images**: Load an image of your choice and segment objects
2. **Compare model sizes**: If you have time, compare ViT-B vs ViT-H quality
3. **Build an annotation tool**: Create a function that takes multiple box prompts and segments all objects

<details>
<summary>💡 Hint for annotation tool</summary>

```python
def annotate_multiple_objects(image, boxes):
    predictor.set_image(image)  # Encode once
    masks = []
    for box in boxes:
        mask, _ = segment_with_box(predictor, box)
        masks.append(mask)
    return masks
```

</details>

In [None]:
# YOUR CODE HERE



---

## ⚠️ Common Mistakes

### Mistake 1: Re-encoding for each prompt

```python
# ❌ Wrong: Encoding image every time (slow!)
for point in points:
    predictor.set_image(image)  # Expensive!
    predictor.predict(point_coords=np.array([point]), point_labels=np.array([1]))

# ✅ Right: Encode once, predict many times
predictor.set_image(image)  # Do this once
for point in points:
    predictor.predict(point_coords=np.array([point]), point_labels=np.array([1]))  # Fast!
```
**Why:** Image encoding is the expensive operation (roughly 1 s, depending on model size and hardware). Predictions on an already encoded image are much faster (roughly 50 ms).

### Mistake 2: Wrong input format

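```python
# ❌ Wrong: Normalized float tensor in channels-first (C, H, W) layout
image = transforms.ToTensor()(pil_image)  # [0, 1] range
predictor.set_image(image.numpy())

# ✅ Right: RGB numpy array in (H, W, 3) layout with 0-255 values
image = np.array(pil_image.convert('RGB'))  # uint8, [0, 255] range
predictor.set_image(image)
```
**Why:** `SamPredictor.set_image` expects a raw 8-bit RGB image in height × width × channel order and applies its own resizing and normalization.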
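A quick pre-flight check can catch this mistake before the image reaches the model. The function below is a minimal sketch that tests for the format described above (NumPy array, `uint8`, three channels last); `check_sam_image` is a hypothetical helper, not part of `segment_anything`.

```python
from typing import List

import numpy as np


def check_sam_image(image) -> List[str]:
    """Return a list of format problems; an empty list means the image looks fine."""
    if not isinstance(image, np.ndarray):
        return [f"expected numpy.ndarray, got {type(image).__name__}"]

    problems = []
    if image.dtype != np.uint8:
        problems.append(f"expected dtype uint8, got {image.dtype}")
    if image.ndim != 3 or image.shape[-1] != 3:
        problems.append(f"expected shape (H, W, 3), got {image.shape}")
    return problems


good = np.zeros((4, 5, 3), dtype=np.uint8)    # H x W x 3, 8-bit
bad = np.zeros((3, 4, 5), dtype=np.float32)   # channels-first float, like ToTensor output
print(check_sam_image(good))       # []
print(len(check_sam_image(bad)))   # 2
```

Calling `check_sam_image(image)` right before `predictor.set_image(image)` and raising if the list is non-empty turns a confusing downstream error or a silently poor mask into an explicit message.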

### Mistake 3: Not using multimask_output correctly

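```python
# For points: Use multimask_output=True (get options)
masks, scores, _ = predictor.predict(
    point_coords=np.array([point]),
    point_labels=np.array([1]),  # Labels are required whenever points are given
    multimask_output=True,
)
best_mask = masks[scores.argmax()]

# For boxes: Use multimask_output=False (single precise mask)
masks, _, _ = predictor.predict(box=np.array(box), multimask_output=False)
mask = masks[0]
```
**Why:** A single point is ambiguous (it could mean a part, the whole object, or a larger group), so request several candidates and pick by score. A box already pins down the extent, so one mask is usually enough.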

---

## 🎉 Checkpoint

You've learned:
- ✅ What makes SAM a foundation model for segmentation
- ✅ Automatic mask generation for entire images
- ✅ Interactive segmentation with points and boxes
- ✅ Combining positive and negative prompts
- ✅ Practical applications like background removal

---

## 🚀 Challenge (Optional)

**Build a simple "Magic Wand" tool:**

Create a function that:
1. Takes an image and a click position
2. Returns the best mask for that click
3. Allows iterative refinement with more clicks

Bonus: Add a "grow" and "shrink" function that returns different granularity masks.

<details>
<summary>💡 Starting Code</summary>

```python
class MagicWand:
    def __init__(self, sam_model):
        self.predictor = SamPredictor(sam_model)
        self.current_mask = None
        self.positive_points = []
        self.negative_points = []
    
    def set_image(self, image):
        self.predictor.set_image(image)
        self.current_mask = None
        self.positive_points = []
        self.negative_points = []
    
    def click(self, point, is_positive=True):
        if is_positive:
            self.positive_points.append(point)
        else:
            self.negative_points.append(point)
        self._update_mask()
        return self.current_mask
    
    def _update_mask(self):
        # ... predict with all accumulated points
        pass
```

</details>

In [None]:
# YOUR CHALLENGE CODE HERE



---

## 📖 Further Reading

- [SAM Paper](https://arxiv.org/abs/2304.02643) - "Segment Anything"
- [SAM GitHub](https://github.com/facebookresearch/segment-anything)
- [SAM Demo](https://segment-anything.com/demo)
- [SAM 2](https://ai.meta.com/sam2/) - Video segmentation

---

## 🧹 Cleanup

In [None]:
# Clear GPU memory
import gc

# Drop the Python references first so the cached GPU blocks can actually be released
del sam, predictor, mask_generator
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("✅ Cleanup complete!")
if torch.cuda.is_available():
    print(f"💾 GPU Memory Free: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB")

---

## 🎓 Module 2.2 Complete!

Congratulations! You've completed the Computer Vision module. You've learned:

1. **CNN Architectures**: From LeNet to ResNet
2. **Transfer Learning**: Fine-tuning pre-trained models
3. **Object Detection**: Using YOLOv8 for real-time detection
4. **Segmentation**: U-Net for semantic segmentation
5. **Vision Transformers**: ViT from scratch
6. **Foundation Models**: SAM for zero-shot segmentation

These skills form the backbone of modern computer vision applications!