# Task 7.6 Solution: SAM Integration

**Module:** 7 - Computer Vision  
**Type:** Solution Notebook

---

This notebook contains solutions for Segment Anything Model (SAM) exercises, including an interactive Magic Wand tool.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional, Dict

# Environment sanity check: installed PyTorch build and whether a GPU is visible.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

## Exercise Solution: Magic Wand Tool

An interactive segmentation tool using SAM that supports:
- Positive clicks (include region)
- Negative clicks (exclude region)
- Undo (step back one click at a time)
- Grow/Shrink: switch to the largest or smallest of SAM's candidate masks

In [None]:
class MagicWand:
    """
    Interactive "Magic Wand" tool using SAM.
    
    Allows iterative refinement with positive and negative clicks,
    similar to Photoshop's magic wand tool but with SAM's power.
    
    Usage:
        wand = MagicWand(sam_predictor)
        wand.set_image(image)
        
        # Click to select
        wand.click((x, y), is_positive=True)   # Include this region
        wand.click((x, y), is_positive=False)  # Exclude this region
        
        # Refine
        wand.grow()    # Switch to the largest candidate mask
        wand.shrink()  # Switch to the smallest candidate mask
        wand.undo()    # Remove last click
        
        # Visualize
        wand.visualize()
    """
    
    def __init__(self, sam_predictor):
        """
        Args:
            sam_predictor: SamPredictor instance from segment-anything
        """
        self.predictor = sam_predictor
        # No image until set_image() is called; visualize() checks for this.
        self.image = None
        self.reset()
    
    def reset(self):
        """Clear all clicks, candidate masks and undo history (the image is kept)."""
        self.positive_points = []
        self.negative_points = []
        self.current_mask = None
        self.all_masks = []  # Candidate masks from the latest prediction (for grow/shrink)
        self.mask_history = []  # Snapshots taken before each click (for undo)
    
    def set_image(self, image: np.ndarray):
        """
        Set the image to segment.
        
        Args:
            image: RGB image as numpy array [H, W, 3]
        """
        self.predictor.set_image(image)
        self.image = image
        self.reset()
        print(f"Image set: {image.shape}")
    
    def _snapshot(self) -> Dict:
        """Capture every piece of state a click changes, so undo() can restore it."""
        return {
            'positive': self.positive_points.copy(),
            'negative': self.negative_points.copy(),
            'mask': self.current_mask,
            'all_masks': self.all_masks.copy(),
        }
    
    def click(self, point: Tuple[int, int], is_positive: bool = True) -> np.ndarray:
        """
        Add a click point and update the mask.
        
        Args:
            point: (x, y) coordinates in image space
            is_positive: True for "include", False for "exclude"
        
        Returns:
            Updated binary mask
        """
        # Save state for undo
        self.mask_history.append(self._snapshot())
        
        if is_positive:
            self.positive_points.append(point)
            print(f"Added positive point at {point}")
        else:
            self.negative_points.append(point)
            print(f"Added negative point at {point}")
        
        self._update_mask()
        return self.current_mask
    
    def undo(self) -> Optional[np.ndarray]:
        """Undo last click, restoring points, current mask and candidate masks."""
        if self.mask_history:
            state = self.mask_history.pop()
            self.positive_points = state['positive']
            self.negative_points = state['negative']
            self.current_mask = state['mask']
            # Restore the candidates too; otherwise grow()/shrink() would pick
            # from the prediction belonging to the click that was just undone.
            self.all_masks = state['all_masks']
            print("Undone last action")
        else:
            print("Nothing to undo")
        
        return self.current_mask
    
    def _select_by_area(self, pick, verb: str) -> Optional[np.ndarray]:
        """Make the candidate whose pixel area is pick(areas) current (pick is min or max)."""
        if not self.all_masks:
            print("No masks available - click first")
        elif len(self.all_masks) > 1:
            areas = [int(m.sum()) for m in self.all_masks]
            self.current_mask = self.all_masks[areas.index(pick(areas))]
            print(f"{verb} mask to size {self.current_mask.sum()} pixels")
        else:
            print("Only one mask available")
        return self.current_mask
    
    def grow(self) -> Optional[np.ndarray]:
        """
        Get a larger mask (if available from multi-mask output).
        
        SAM outputs multiple masks at different granularities.
        This selects the largest one.
        """
        return self._select_by_area(max, "Grew")
    
    def shrink(self) -> Optional[np.ndarray]:
        """
        Get a smaller mask (if available from multi-mask output).
        
        SAM outputs multiple masks at different granularities.
        This selects the smallest one.
        """
        return self._select_by_area(min, "Shrunk")
    
    def _update_mask(self):
        """Re-run SAM on all accumulated points and store its candidate masks."""
        if not self.positive_points and not self.negative_points:
            self.current_mask = None
            return
        
        # Combine all points with labels (1 = include, 0 = exclude)
        all_points = self.positive_points + self.negative_points
        labels = [1] * len(self.positive_points) + [0] * len(self.negative_points)
        
        # Run SAM prediction; multimask_output=True yields several candidates
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array(all_points),
            point_labels=np.array(labels),
            multimask_output=True
        )
        
        # Store all masks for grow/shrink functionality
        self.all_masks = [masks[i] for i in range(len(masks))]
        
        # Use highest scoring mask as default
        best_idx = scores.argmax()
        self.current_mask = masks[best_idx]
        
        print(f"Updated mask (score: {scores[best_idx]:.3f}, area: {self.current_mask.sum()} pixels)")
    
    def get_mask(self) -> Optional[np.ndarray]:
        """Get the current mask."""
        return self.current_mask
    
    def visualize(self, figsize: Tuple[int, int] = (14, 6)):
        """Visualize current state with clicks and mask overlay."""
        if self.image is None:
            print("No image set - call set_image() first")
            return
        
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        
        # Image with click points
        axes[0].imshow(self.image)
        
        # One scatter call per click type so each appears once in the legend.
        if self.positive_points:
            xs, ys = zip(*self.positive_points)
            axes[0].scatter(xs, ys, c='green', s=200, marker='*',
                            edgecolors='white', linewidths=2, label='Positive')
        if self.negative_points:
            xs, ys = zip(*self.negative_points)
            axes[0].scatter(xs, ys, c='red', s=200, marker='x',
                            linewidths=3, label='Negative')
        if self.positive_points or self.negative_points:
            axes[0].legend(loc='upper right')
        
        axes[0].set_title(f'Clicks: {len(self.positive_points)} positive, {len(self.negative_points)} negative')
        axes[0].axis('off')
        
        # Image with mask overlay
        axes[1].imshow(self.image)
        
        if self.current_mask is not None:
            mask = self.current_mask.astype(bool)
            # Create semi-transparent colored mask
            colored_mask = np.zeros((*mask.shape, 4))
            colored_mask[mask] = [0.3, 0.7, 0.3, 0.6]  # Green overlay
            axes[1].imshow(colored_mask)
            
            # The boundary outline is optional: skip it quietly if scipy is missing.
            try:
                from scipy import ndimage
            except ImportError:
                ndimage = None
            if ndimage is not None:
                boundary = ndimage.binary_dilation(mask) ^ mask
                boundary_mask = np.zeros((*boundary.shape, 4))
                boundary_mask[boundary] = [1, 1, 1, 1]  # White boundary
                axes[1].imshow(boundary_mask)
        
        axes[1].set_title('Current Mask')
        axes[1].axis('off')
        
        plt.tight_layout()
        plt.show()
    
    def get_stats(self) -> Dict:
        """Get statistics about current selection (counts, area, history depth)."""
        stats = {
            'positive_clicks': len(self.positive_points),
            'negative_clicks': len(self.negative_points),
            'mask_area': int(self.current_mask.sum()) if self.current_mask is not None else 0,
            'available_masks': len(self.all_masks),
            'undo_available': len(self.mask_history)
        }
        
        if self.current_mask is not None:
            total_pixels = self.current_mask.size
            stats['mask_percentage'] = 100 * stats['mask_area'] / total_pixels
        
        return stats


# Quick-start snippet for the MagicWand tool, echoed so it appears in the cell output.
_magic_wand_usage_lines = [
    "  from segment_anything import sam_model_registry, SamPredictor",
    "  sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b.pth')",
    "  predictor = SamPredictor(sam)",
    "  ",
    "  wand = MagicWand(predictor)",
    "  wand.set_image(image)",
    "  wand.click((100, 150), is_positive=True)",
    "  wand.click((200, 250), is_positive=False)",
    "  wand.visualize()",
]
print("MagicWand class defined successfully!")
print("\nUsage:")
print(*_magic_wand_usage_lines, sep="\n")

## Exercise Solution: Box Prompt Interface

SAM can also segment objects using bounding box prompts.

In [None]:
class BoxSelector:
    """
    Box-prompted object segmentation with SAM.
    
    Every call to ``add_box`` sends one (x1, y1, x2, y2) rectangle to the
    predictor and keeps the first returned mask next to its box, so several
    objects can be collected and drawn together.
    """
    
    def __init__(self, sam_predictor):
        self.predictor = sam_predictor
        self.image = None
        self.boxes = []
        self.masks = []
    
    def set_image(self, image: np.ndarray):
        """Hand ``image`` to the predictor and forget all earlier boxes/masks."""
        self.predictor.set_image(image)
        self.image = image
        self.boxes, self.masks = [], []
    
    def add_box(self, box: Tuple[int, int, int, int]) -> np.ndarray:
        """
        Segment the object inside ``box``.
        
        Args:
            box: (x1, y1, x2, y2) coordinates
        
        Returns:
            The mask predicted for this box (also appended to ``self.masks``)
        """
        prompt = np.array(box)
        predicted, _scores, _logits = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=prompt,
            multimask_output=False,
        )
        # A single-mask request: only the first mask is kept.
        chosen = predicted[0]
        self.boxes.append(box)
        self.masks.append(chosen)
        return chosen
    
    def visualize(self):
        """Left: every box with its 1-based index. Right: every mask as a translucent layer."""
        fig, (box_ax, mask_ax) = plt.subplots(1, 2, figsize=(14, 6))
        
        box_ax.imshow(self.image)
        for number, (x1, y1, x2, y2) in enumerate(self.boxes, start=1):
            outline = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                    fill=False, edgecolor='red', linewidth=2)
            box_ax.add_patch(outline)
            box_ax.text(x1, y1 - 5, f'Box {number}', color='red', fontsize=10)
        box_ax.set_title('Bounding Boxes')
        box_ax.axis('off')
        
        mask_ax.imshow(self.image)
        palette = plt.cm.Set3(np.linspace(0, 1, len(self.masks)))
        for mask, rgba in zip(self.masks, palette):
            layer = np.zeros(mask.shape + (4,))
            layer[mask] = [*rgba[:3], 0.5]
            mask_ax.imshow(layer)
        mask_ax.set_title('Segmented Objects')
        mask_ax.axis('off')
        
        plt.tight_layout()
        plt.show()


print("BoxSelector class defined!")

## Exercise Solution: Automatic Mask Generation

SAM can automatically segment all objects in an image.

In [None]:
def automatic_segmentation(
    sam_model,
    image: np.ndarray,
    points_per_side: int = 32,
    pred_iou_thresh: float = 0.88,
    stability_score_thresh: float = 0.95,
    min_mask_region_area: int = 100,
    **generator_kwargs
) -> List[Dict]:
    """
    Automatically segment all objects in an image.
    
    Args:
        sam_model: SAM model instance
        image: Input image [H, W, 3]
        points_per_side: Grid density for automatic point sampling
        pred_iou_thresh: Predicted IoU threshold for mask quality
        stability_score_thresh: Stability score threshold
        min_mask_region_area: Minimum mask area in pixels
        **generator_kwargs: Extra keyword arguments forwarded unchanged to
            ``SamAutomaticMaskGenerator`` (e.g. ``crop_n_layers``).
    
    Returns:
        Mask records produced by the generator, sorted by their 'area' key,
        largest first. An empty list if segment-anything is not installed.
    
    Raises:
        ImportError: propagated if raised while generating masks (for example
            by an optional dependency used during post-processing); only a
            missing segment-anything package is turned into an empty result.
    """
    # Guard only the import: an ImportError from inside generate() is a
    # different problem and must not be reported as "segment-anything missing".
    try:
        from segment_anything import SamAutomaticMaskGenerator
    except ImportError:
        print("Please install segment-anything: pip install segment-anything")
        return []
    
    mask_generator = SamAutomaticMaskGenerator(
        model=sam_model,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        min_mask_region_area=min_mask_region_area,
        **generator_kwargs
    )
    
    masks = mask_generator.generate(image)
    
    print(f"Found {len(masks)} objects")
    
    # Sort by area (largest first)
    return sorted(masks, key=lambda record: record['area'], reverse=True)


def visualize_automatic_masks(image: np.ndarray, masks: List[Dict], seed: int = 42):
    """
    Visualize automatically generated masks.
    
    Args:
        image: Original image
        masks: List of mask dictionaries from SamAutomaticMaskGenerator
            (reads the 'segmentation', 'area', 'predicted_iou' and
            'stability_score' keys)
        seed: Seed for the overlay colours. A local generator is used, so the
            global ``np.random`` state of the notebook is left untouched.
    """
    plt.figure(figsize=(12, 8))
    plt.imshow(image)
    
    # Local RNG: reproducible colours without reseeding np.random globally
    rng = np.random.default_rng(seed)
    
    for mask_data in masks:
        mask = np.asarray(mask_data['segmentation'], dtype=bool)
        color = rng.random(3)
        
        # Create colored overlay
        colored_mask = np.zeros((*mask.shape, 4))
        colored_mask[mask] = [*color, 0.4]
        plt.imshow(colored_mask)
    
    plt.title(f'Automatic Segmentation: {len(masks)} objects found')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    
    # Print mask statistics
    print("\nMask Statistics:")
    print("="*50)
    print(f"{'#':<5} {'Area':<15} {'IoU Score':<15} {'Stability':<15}")
    print("-"*50)
    for rank, m in enumerate(masks[:10], start=1):  # Show top 10
        print(f"{rank:<5} {m['area']:<15} {m['predicted_iou']:<15.3f} {m['stability_score']:<15.3f}")


print("Automatic segmentation functions defined!")

## Exercise Solution: Mask Refinement

Post-processing techniques to refine SAM masks.

In [None]:
def refine_mask(
    mask: np.ndarray,
    remove_small_regions: bool = True,
    fill_holes: bool = True,
    smooth_boundary: bool = True,
    min_area: int = 500,
    kernel_size: int = 5
) -> np.ndarray:
    """
    Refine a binary mask using morphological operations.
    
    Args:
        mask: Binary mask [H, W] (non-zero = foreground); not modified
        remove_small_regions: Remove connected components smaller than min_area
        fill_holes: Fill holes in the mask
        smooth_boundary: Apply morphological closing to smooth boundaries
            (needs OpenCV; skipped with a message if cv2 is not installed)
        min_area: Minimum area for region removal
        kernel_size: Size of the elliptical closing kernel (kernel_size x kernel_size)
    
    Returns:
        Refined boolean mask. If scipy is missing, ``mask`` is returned unchanged.
    """
    try:
        from scipy import ndimage
    except ImportError:
        print("scipy required for mask refinement")
        return mask
    
    refined = mask.astype(np.uint8)  # astype copies, so the caller's mask is untouched
    
    # Remove small connected components
    if remove_small_regions:
        labeled, num_features = ndimage.label(refined)
        if num_features:
            # One pass over the label image: sizes[k] = pixel count of component k
            sizes = np.bincount(labeled.ravel())
            keep = sizes >= min_area
            keep[0] = False  # label 0 is the background
            refined = keep[labeled].astype(np.uint8)
    
    # Fill holes
    if fill_holes:
        refined = ndimage.binary_fill_holes(refined).astype(np.uint8)
    
    # Smooth boundaries with morphological closing; cv2 is only needed here
    if smooth_boundary:
        try:
            import cv2
        except ImportError:
            print("cv2 required for boundary smoothing; skipping it")
        else:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
            refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel)
    
    return refined.astype(bool)


def compare_masks(original: np.ndarray, refined: np.ndarray, image: np.ndarray):
    """
    Compare original and refined masks side by side.
    
    Args:
        original: Mask before refinement [H, W]; any dtype, non-zero = foreground
        refined: Mask after refinement, same shape as ``original``
        image: Background image drawn under each overlay
    
    Panels: original (red), refined (green), and their difference
    (green = pixels added by refinement, red = pixels removed).
    """
    # Coerce to bool: integer masks would otherwise act as fancy (row) indices
    # below, and `~` would be a bitwise NOT instead of a logical one.
    original = np.asarray(original, dtype=bool)
    refined = np.asarray(refined, dtype=bool)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    def _draw_panel(ax, layers, title):
        """Show ``image`` plus one RGBA overlay painted from (mask, rgba) pairs."""
        ax.imshow(image)
        overlay = np.zeros((*layers[0][0].shape, 4))
        for region, rgba in layers:
            overlay[region] = rgba
        ax.imshow(overlay)
        ax.set_title(title)
        ax.axis('off')
    
    _draw_panel(axes[0], [(original, [1, 0, 0, 0.5])],
                f'Original (area: {original.sum()})')
    _draw_panel(axes[1], [(refined, [0, 1, 0, 0.5])],
                f'Refined (area: {refined.sum()})')
    
    added = refined & ~original    # New regions
    removed = original & ~refined  # Removed regions
    _draw_panel(axes[2], [(added, [0, 1, 0, 0.7]), (removed, [1, 0, 0, 0.7])],
                'Difference (green=added, red=removed)')
    
    plt.tight_layout()
    plt.show()


print("Mask refinement functions defined!")

## Exercise Solution: SAM with Object Detection

Combining YOLO object detection with SAM segmentation.

In [None]:
def detect_and_segment(
    image: np.ndarray,
    yolo_model,
    sam_predictor,
    classes: Optional[List[str]] = None,
    conf_threshold: float = 0.5
) -> List[Dict]:
    """
    Detect objects with YOLO, then segment each detection with SAM.
    
    YOLO supplies the class label, confidence and box; each box is then used
    as a SAM box prompt to obtain a single mask for that object.
    
    Args:
        image: Input image
        yolo_model: YOLO model instance (callable, with a ``names`` id->label map)
        sam_predictor: SAM predictor instance
        classes: Optional list of class names to keep; falsy means keep all
        conf_threshold: Detection confidence threshold passed to YOLO
    
    Returns:
        One dict per kept detection with keys 'class', 'confidence', 'bbox'
        ([x1, y1, x2, y2] list), 'mask' and 'mask_score'.
    """
    # NOTE(review): the same array goes to both models; ultralytics is believed
    # to read numpy input as BGR while SamPredictor.set_image defaults to RGB —
    # confirm the channel order the caller provides.
    sam_predictor.set_image(image)
    
    detections: List[Dict] = []
    for result in yolo_model(image, conf=conf_threshold):
        for det_box in result.boxes:
            label = yolo_model.names[int(det_box.cls)]
            if classes and label not in classes:
                continue
            
            xyxy = det_box.xyxy[0].cpu().numpy()  # [x1, y1, x2, y2]
            seg_masks, seg_scores, _ = sam_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=xyxy,
                multimask_output=False
            )
            detections.append({
                'class': label,
                'confidence': float(det_box.conf),
                'bbox': xyxy.tolist(),
                'mask': seg_masks[0],
                'mask_score': float(seg_scores[0]),
            })
    
    print(f"Detected and segmented {len(detections)} objects")
    return detections


def visualize_detections(image: np.ndarray, detections: List[Dict]):
    """
    Draw detector output next to SAM output.
    
    Left panel: each detection's box with a "<class> <confidence%>" caption.
    Right panel: each detection's mask as a translucent tab10-coloured layer,
    labelled at the mask centroid (no label for an empty mask).
    """
    fig, (box_ax, seg_ax) = plt.subplots(1, 2, figsize=(14, 6))
    
    box_ax.imshow(image)
    for det in detections:
        left, top, right, bottom = det['bbox']
        frame = plt.Rectangle((left, top), right - left, bottom - top,
                              fill=False, edgecolor='lime', linewidth=2)
        box_ax.add_patch(frame)
        caption = f"{det['class']} {det['confidence']:.0%}"
        box_ax.text(left, top - 5, caption,
                    color='lime', fontsize=10, backgroundcolor='black')
    box_ax.set_title('YOLO Detections')
    box_ax.axis('off')
    
    seg_ax.imshow(image)
    palette = plt.cm.tab10(np.linspace(0, 1, len(detections)))
    for det, rgba in zip(detections, palette):
        mask = det['mask']
        layer = np.zeros(mask.shape + (4,))
        layer[mask] = [*rgba[:3], 0.5]
        seg_ax.imshow(layer)
        
        rows, cols = np.nonzero(mask)
        if cols.size:
            seg_ax.text(cols.mean(), rows.mean(), det['class'], color='white', fontsize=10,
                        ha='center', va='center', backgroundcolor='black')
    
    seg_ax.set_title('SAM Segmentation Masks')
    seg_ax.axis('off')
    
    plt.tight_layout()
    plt.show()


# Quick-start snippet for the YOLO + SAM pipeline, echoed into the cell output.
_yolo_sam_usage_lines = [
    "  from ultralytics import YOLO",
    "  from segment_anything import sam_model_registry, SamPredictor",
    "  ",
    "  yolo = YOLO('yolov8s.pt')",
    "  sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b.pth')",
    "  predictor = SamPredictor(sam)",
    "  ",
    "  detections = detect_and_segment(image, yolo, predictor)",
    "  visualize_detections(image, detections)",
]
print("YOLO + SAM integration functions defined!")
print("\nUsage:")
print("\n".join(_yolo_sam_usage_lines))

## Summary

Key concepts covered:

1. **Magic Wand Tool**: Interactive point-based segmentation with:
   - Positive/negative clicks
   - Grow/shrink mask options
   - Undo functionality

2. **Box Prompts**: Bounding box-based segmentation

3. **Automatic Segmentation**: Segment all objects without user prompts (SAM is prompted internally with a regular grid of points)

4. **Mask Refinement**: Post-processing with morphological operations

5. **YOLO + SAM**: Combining detection and segmentation

SAM provides state-of-the-art segmentation that works with various input prompts (points, boxes, masks).

In [None]:
# Cleanup
import gc

# Force a full garbage-collection pass so objects no longer referenced are reclaimed.
_collected = gc.collect()
if torch.cuda.is_available():
    # Return cached-but-unused CUDA memory; skipped entirely on CPU-only machines.
    torch.cuda.empty_cache()
print("Cleanup complete!")