# Q2: Text-Driven Image Segmentation with SAM 2

Implementation of text-prompted image segmentation using SAM 2 (Segment Anything Model 2) with GroundingDINO for text-to-region conversion.

## Setup and Dependencies

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install transformers
!pip install opencv-python
!pip install matplotlib
!pip install pillow
!pip install numpy
!pip install requests

# Install SAM 2
!pip install git+https://github.com/facebookresearch/segment-anything-2.git

# Install GroundingDINO
!pip install git+https://github.com/IDEA-Research/GroundingDINO.git

# Alternative: Install supervision for additional utilities
!pip install supervision

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Load Models

In [None]:
# Load SAM 2, falling back to the original SAM from `transformers` on failure.
import os

# Pre-declare both back-ends so later cells can always reference these names;
# whichever back-end is not loaded stays None.
predictor = None
sam_model = None
sam_processor = None

try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # SAM 2 checkpoint and matching model config
    checkpoint = "./sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"
    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"

    # Download the checkpoint if it is missing. torch.hub writes to a temporary
    # file and moves it into place only when the download completes, so an
    # interrupted download cannot leave a truncated checkpoint that would pass
    # the existence check on the next run.
    if not os.path.exists(checkpoint):
        torch.hub.download_url_to_file(checkpoint_url, checkpoint)

    sam2_model = build_sam2(model_cfg, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    print("SAM 2 loaded successfully!")

except Exception as e:
    # Deliberately broad: any import/download/build problem triggers the fallback.
    print(f"Error loading SAM 2: {e}")
    print("Falling back to alternative implementation...")

    # Alternative: Use transformers SAM model
    from transformers import SamModel, SamProcessor

    model_name = "facebook/sam-vit-huge"
    sam_model = SamModel.from_pretrained(model_name).to(device)
    sam_processor = SamProcessor.from_pretrained(model_name)
    print("Alternative SAM model loaded!")

In [None]:
# Load GroundingDINO for text-to-bbox conversion
try:
    import os

    import groundingdino
    from groundingdino.models import build_model
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from groundingdino.util.inference import annotate, load_model, predict

    # Prefer a config in the working directory (cloned-repo layout); otherwise
    # use the copy that ships inside the pip-installed package, since the
    # relative path does not exist after `pip install git+...`.
    grounding_config = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
    if not os.path.exists(grounding_config):
        grounding_config = os.path.join(
            os.path.dirname(groundingdino.__file__),
            "config",
            "GroundingDINO_SwinT_OGC.py",
        )

    # The weights are not part of the package; fetch them on first use.
    grounding_weights = "weights/groundingdino_swint_ogc.pth"
    grounding_weights_url = (
        "https://github.com/IDEA-Research/GroundingDINO/releases/download/"
        "v0.1.0-alpha/groundingdino_swint_ogc.pth"
    )
    if not os.path.exists(grounding_weights):
        os.makedirs(os.path.dirname(grounding_weights), exist_ok=True)
        torch.hub.download_url_to_file(grounding_weights_url, grounding_weights)

    # Load GroundingDINO on the selected device (load_model defaults to "cuda")
    grounding_model = load_model(grounding_config, grounding_weights, device=device.type)
    print("GroundingDINO loaded successfully!")

except Exception as e:
    # Deliberately broad: any failure switches the notebook to the CLIP fallback.
    print(f"Error loading GroundingDINO: {e}")
    print("Will use alternative text-to-region method...")
    grounding_model = None

## Alternative Text-to-Region Implementation

In [None]:
# Alternative implementation using CLIP for text-to-region
!pip install clip-by-openai

import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# Load CLIP model
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
print("CLIP model loaded for text understanding!")

def sliding_window_clip(image, text_prompt, window_size=224, stride=112, threshold=0.25):
    """Find the image window that best matches a text prompt according to CLIP.

    Square windows are slid over the image; each one is scored by the cosine
    similarity between its CLIP image embedding and the prompt's CLIP text
    embedding. Only the single best window is returned.

    Args:
        image: RGB image as an array accepted by ``Image.fromarray``
            (height and width are read from ``image.shape[:2]``).
        text_prompt: Text describing the object to look for.
        window_size: Side length of the square window, in pixels.
        stride: Step between consecutive window origins, in pixels.
        threshold: Minimum similarity a window must exceed to be reported.

    Returns:
        ``([box], [score])`` where ``box`` is ``[x1, y1, x2, y2]`` in pixels
        and ``score`` is its similarity, or ``([], [])`` if no window scores
        above ``threshold``.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        return [], []

    def window_starts(length):
        # Window origins along one axis. The final origin is pinned to the far
        # border so the right/bottom strip is scanned even when the stride does
        # not divide the extent evenly; an axis shorter than the window gets
        # one window at 0 (clamped to the image when the box is built).
        last = max(length - window_size, 0)
        starts = list(range(0, last + 1, stride))
        if starts[-1] != last:
            starts.append(last)
        return starts

    # Convert to PIL for CLIP preprocessing
    pil_image = Image.fromarray(image)

    # Tokenize text
    text_tokens = clip.tokenize([text_prompt]).to(device)

    best_box = None
    best_score = threshold  # a window must score strictly above the threshold

    with torch.no_grad():
        # The text embedding does not depend on the window: compute it once.
        text_features = clip_model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for y in window_starts(h):
            for x in window_starts(w):
                # Clamp to the image so small images are not padded by crop()
                box = [x, y, min(x + window_size, w), min(y + window_size, h)]
                window = pil_image.crop(tuple(box))

                # Preprocess for CLIP and embed the window
                window_tensor = clip_preprocess(window).unsqueeze(0).to(device)
                image_features = clip_model.encode_image(window_tensor)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                # Cosine similarity (both embeddings are unit-normalized)
                similarity = (image_features @ text_features.T).item()

                # Strict '>' keeps the earliest window on ties
                if similarity > best_score:
                    best_box, best_score = box, similarity

    if best_box is None:
        return [], []
    return [best_box], [best_score]

## Text-to-Segmentation Pipeline

In [None]:
def _load_rgb_image(image_path, use_url):
    """Load an image from a URL or a local path and return it as an RGB PIL image."""
    if use_url:
        # A timeout avoids hanging forever; raise_for_status surfaces HTTP
        # errors instead of handing an error page to PIL.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    with Image.open(image_path) as img:
        return img.convert('RGB')


def _grounding_dino_boxes(image, text_prompt):
    """Run GroundingDINO and return pixel-space [x1, y1, x2, y2] boxes, best first."""
    import groundingdino.datasets.transforms as T

    # `predict` expects the normalized tensor produced by GroundingDINO's own
    # preprocessing (the same transform its `load_image` helper applies), not
    # a raw numpy image.
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor, _ = transform(image, None)

    boxes, logits, phrases = predict(
        model=grounding_model,
        image=image_tensor,
        caption=text_prompt,
        box_threshold=0.3,
        text_threshold=0.25,
        device=device.type,
    )

    # Put the highest-confidence detection first
    order = logits.argsort(descending=True)
    boxes = boxes[order].cpu()

    # GroundingDINO returns normalized (cx, cy, w, h); convert to pixel xyxy
    w, h = image.size
    boxes = boxes * torch.tensor([w, h, w, h], dtype=boxes.dtype)
    cx, cy, bw, bh = boxes.unbind(-1)
    xyxy = torch.stack(
        [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], dim=-1
    )
    return xyxy.numpy()


def text_to_segmentation_pipeline(image_path, text_prompt, use_url=False):
    """Complete pipeline from text prompt to segmentation mask.

    The prompt is first localized (GroundingDINO when loaded, otherwise the
    CLIP sliding window). The center of the best region, or the image center
    when nothing is found, is then used as a positive point prompt for SAM 2,
    with the `transformers` SAM model as a fallback.

    Args:
        image_path: Local file path, or a URL when ``use_url`` is True.
        text_prompt: Text describing the object to segment.
        use_url: Whether ``image_path`` should be fetched over HTTP.

    Returns:
        ``(image_np, best_mask, input_points)``: the RGB image as an array,
        the selected mask, and the ``(1, 2)`` array holding the point prompt.

    Raises:
        requests.RequestException: If the image URL cannot be fetched.
        OSError: If the image cannot be opened or decoded.
    """
    image = _load_rgb_image(image_path, use_url)
    image_np = np.array(image)

    print(f"Image loaded: {image_np.shape}")
    print(f"Text prompt: '{text_prompt}'")

    # Step 1: Text to bounding boxes (pixel xyxy)
    boxes_pixel = None
    if grounding_model is not None:
        try:
            boxes_pixel = _grounding_dino_boxes(image, text_prompt)
        except Exception as e:
            # Best effort: any GroundingDINO failure falls through to CLIP
            print(f"GroundingDINO failed: {e}, using CLIP fallback")

    if boxes_pixel is None:
        clip_boxes, _ = sliding_window_clip(image_np, text_prompt)
        boxes_pixel = np.array(clip_boxes) if clip_boxes else np.array([])

    if len(boxes_pixel) == 0:
        print("No regions found for the text prompt. Using center point.")
        h, w = image_np.shape[:2]
        # Use center as fallback
        input_points = np.array([[w // 2, h // 2]])
    else:
        # Convert boxes to points (center of the best box)
        box = boxes_pixel[0]
        center_x = (box[0] + box[2]) / 2
        center_y = (box[1] + box[3]) / 2
        input_points = np.array([[center_x, center_y]])

        print(f"Found {len(boxes_pixel)} regions, using center point: ({center_x:.1f}, {center_y:.1f})")
    input_labels = np.array([1])  # 1 marks a positive (foreground) point

    # Step 2: SAM segmentation
    try:
        # Try SAM 2 first
        predictor.set_image(image_np)
        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Select best mask
        best_mask = masks[np.argmax(scores)]

    except Exception as e:
        # Also reached when SAM 2 never loaded and `predictor` is unusable
        print(f"SAM 2 failed: {e}, using alternative SAM")

        # Use transformers SAM
        inputs = sam_processor(image, input_points=[input_points.tolist()], return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
        )
        # Pick the candidate with the highest predicted IoU, mirroring the
        # SAM 2 branch, instead of always taking the first candidate.
        best_idx = int(outputs.iou_scores.cpu()[0, 0].argmax())
        best_mask = masks[0][0][best_idx].numpy()

    return image_np, best_mask, input_points

def visualize_segmentation(image, mask, points, text_prompt):
    """Plot the image, its mask, and a red-tinted overlay in one row.

    Args:
        image: Image array shown with ``plt.imshow`` and used as overlay base.
        mask: Mask array; entries greater than zero are treated as foreground
            (expected to match the image's height and width).
        points: Iterable of prompt points; element 0 is x, element 1 is y.
        text_prompt: Prompt text shown in the overlay panel's title.
    """
    plt.figure(figsize=(15, 5))

    # Overlay: write the mask into the red channel, then blend it 40/60 into
    # a copy of the image wherever the mask is set
    foreground = mask > 0
    red_layer = np.zeros_like(image)
    red_layer[:, :, 0] = mask * 255
    blended = image.copy()
    blended[foreground] = blended[foreground] * 0.6 + red_layer[foreground] * 0.4

    # Panels 1 and 2: the untouched image and the raw mask
    simple_panels = (
        (1, image, None, 'Original Image'),
        (2, mask, 'gray', 'Segmentation Mask'),
    )
    for slot, data, colormap, title in simple_panels:
        plt.subplot(1, 3, slot)
        plt.imshow(data, cmap=colormap)
        plt.title(title)
        plt.axis('off')

    # Panel 3: image, then the blended overlay drawn on top of it
    plt.subplot(1, 3, 3)
    plt.imshow(image)
    plt.imshow(blended.astype(np.uint8))

    # Mark every prompt point
    for prompt_point in points:
        plt.plot(prompt_point[0], prompt_point[1], 'go', markersize=10, markeredgewidth=2, markeredgecolor='white')

    plt.title(f'Segmentation: "{text_prompt}"')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

## Example 1: Segment a Dog

In [None]:
# Sample photo (fetched over HTTP) and the object to look for in it
image_url = "https://images.unsplash.com/photo-1552053831-71594a27632d?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1000&q=80"
text_prompt = "dog"

# Text prompt -> segmentation mask
image, mask, points = text_to_segmentation_pipeline(
    image_path=image_url,
    text_prompt=text_prompt,
    use_url=True,
)

# Show the original image, the mask and the overlay
visualize_segmentation(image=image, mask=mask, points=points, text_prompt=text_prompt)

## Example 2: Segment a Car

In [None]:
# Second sample photo: this time the target object is a car
image_url = "https://images.unsplash.com/photo-1549317661-bd32c8ce0db2?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1000&q=80"
text_prompt = "car"

# Text prompt -> segmentation mask
image, mask, points = text_to_segmentation_pipeline(
    image_path=image_url,
    text_prompt=text_prompt,
    use_url=True,
)

# Show the original image, the mask and the overlay
visualize_segmentation(image=image, mask=mask, points=points, text_prompt=text_prompt)

## Example 3: Custom Image Upload

In [None]:
# Interactive example - you can upload your own image
from google.colab import files
import os

print("Upload an image file:")
uploaded = files.upload()

if not uploaded:
    # Nothing to index into when the upload dialog is cancelled
    print("No file uploaded; skipping the custom image example.")
else:
    # Get the uploaded filename (first entry if several files were selected)
    filename = next(iter(uploaded))
    print(f"Uploaded: {filename}")

    try:
        # Text prompt input
        text_prompt = input("Enter text prompt for segmentation (e.g., 'person', 'cat', 'tree'): ")

        # Run the pipeline
        image, mask, points = text_to_segmentation_pipeline(filename, text_prompt, use_url=False)

        # Visualize results
        visualize_segmentation(image, mask, points, text_prompt)
    finally:
        # Clean up the uploaded file even if segmentation or plotting failed
        if os.path.exists(filename):
            os.remove(filename)

## Bonus: Video Segmentation with SAM 2

In [None]:
# Video segmentation implementation
try:
    import os
    import tempfile

    from sam2.build_sam import build_sam2_video_predictor

    # Build video predictor on the selected device (the builder defaults to
    # "cuda", which fails on CPU-only machines)
    video_predictor = build_sam2_video_predictor(
        "sam2_hiera_l.yaml", "./sam2_hiera_large.pt", device=device
    )

    def video_segmentation_pipeline(video_path, text_prompt, frame_interval=5):
        """Segment the object described by a text prompt through a video.

        Every ``frame_interval``-th frame is sampled (at most 30 frames). The
        object is located on the first sampled frame with
        ``sliding_window_clip`` and used as a point prompt; SAM 2 then
        propagates the mask through the sampled frames.

        Args:
            video_path: Path to a video file readable by OpenCV.
            text_prompt: Text describing the object to segment.
            frame_interval: Keep one frame out of this many; must be >= 1.

        Returns:
            ``(frames, video_segments)``, where ``frames`` is the list of
            sampled RGB frames and ``video_segments`` maps a sampled-frame
            index to ``{object_id: boolean (H, W) mask}``. Returns ``None``
            when no frame could be read from the video.

        Raises:
            ValueError: If ``frame_interval`` is smaller than 1.
        """
        if frame_interval < 1:
            raise ValueError("frame_interval must be >= 1")

        max_frames = 30  # keep the demo cheap

        # Extract frames from video
        cap = cv2.VideoCapture(video_path)
        frames = []
        frame_count = 0
        try:
            while len(frames) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_interval == 0:
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frame_count += 1
        finally:
            # Release the capture even if decoding raised
            cap.release()

        if not frames:
            print("No frames extracted from video")
            return

        print(f"Extracted {len(frames)} frames")

        # Use text-to-region on first frame to obtain the point prompt
        first_frame = frames[0]
        boxes_pixel, scores = sliding_window_clip(first_frame, text_prompt)

        if boxes_pixel:
            box = boxes_pixel[0]
            center_x = (box[0] + box[2]) / 2
            center_y = (box[1] + box[3]) / 2
            input_points = np.array([[center_x, center_y]])
        else:
            # Use center as fallback
            h, w = first_frame.shape[:2]
            input_points = np.array([[w // 2, h // 2]])
        input_labels = np.array([1])  # 1 marks a positive (foreground) point

        # init_state() takes a path, not in-memory frames: write the sampled
        # frames as numerically named JPEGs into a temporary directory.
        video_segments = {}
        with tempfile.TemporaryDirectory() as frame_dir:
            for idx, frame in enumerate(frames):
                cv2.imwrite(
                    os.path.join(frame_dir, f"{idx:05d}.jpg"),
                    cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),  # imwrite expects BGR
                )

            # Initialize video predictor
            inference_state = video_predictor.init_state(video_path=frame_dir)

            # Add points to first frame
            video_predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=1,
                points=input_points,
                labels=input_labels,
            )

            # Propagate masks through video
            for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
                per_object = {}
                for i, obj_id in enumerate(out_obj_ids):
                    obj_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                    # Drop the leading singleton channel so the mask is (H, W)
                    # and can index a frame directly
                    per_object[obj_id] = obj_mask.reshape(obj_mask.shape[-2:])
                video_segments[out_frame_idx] = per_object

        return frames, video_segments

    def visualize_video_segmentation(frames, video_segments, text_prompt, max_frames=6):
        """Show the first frames with the mask of object 1 tinted red.

        Args:
            frames: List of RGB frames (H, W, 3).
            video_segments: Mapping ``frame index -> {object_id: mask}``.
            text_prompt: Prompt text shown in each subplot title.
            max_frames: Maximum number of frames to display.
        """
        n_frames = min(len(frames), max_frames)
        if n_frames == 0:
            print("No frames to visualize")
            return

        # Two rows when there is more than one frame; round the column count
        # up so an odd number of frames still fits in the grid
        n_rows = 2 if n_frames > 1 else 1
        n_cols = (n_frames + n_rows - 1) // n_rows

        plt.figure(figsize=(20, 8))

        for i in range(n_frames):
            frame = frames[i]
            mask = video_segments.get(i, {}).get(1)
            if mask is None:
                mask = np.zeros(frame.shape[:2], dtype=bool)
            else:
                # Accept (H, W) as well as (1, H, W) masks
                mask = np.asarray(mask, dtype=bool)
                mask = mask.reshape(mask.shape[-2:])

            # Create overlay (blend in float, convert back to uint8 for display)
            overlay = frame.astype(np.float32)
            if mask.any():
                overlay[mask] = overlay[mask] * 0.6 + np.array([255, 0, 0]) * 0.4

            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(overlay.astype(np.uint8))
            plt.title(f'Frame {i}: "{text_prompt}"')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    print("Video segmentation functions loaded successfully!")

except Exception as e:
    # Deliberately broad: the video part is optional in this notebook
    print(f"Video segmentation not available: {e}")
    print("This requires SAM 2 video predictor which may not be fully available in this environment.")

## Video Segmentation Example

In [None]:
# Example video segmentation (if available)
try:
    # Imported here so this cell does not depend on an earlier cell having run
    import os
    from google.colab import files

    print("Upload a short video file (MP4 format, <30 seconds):")
    uploaded_video = files.upload()

    if not uploaded_video:
        print("No video uploaded; skipping the video example.")
    else:
        video_filename = next(iter(uploaded_video))
        try:
            video_text_prompt = input("Enter text prompt for video segmentation: ")

            # Run video segmentation; the pipeline returns None when no frame
            # could be read, so check before unpacking
            result = video_segmentation_pipeline(video_filename, video_text_prompt)

            if result is not None:
                frames, video_segments = result

                # Visualize results
                visualize_video_segmentation(frames, video_segments, video_text_prompt)
        finally:
            # Clean up the uploaded file even if segmentation failed
            if os.path.exists(video_filename):
                os.remove(video_filename)

except Exception as e:
    # Best effort: the video demo is optional and environment-dependent
    print(f"Video segmentation example skipped: {e}")
    print("This feature requires video upload and may not work in all environments.")

## Pipeline Summary and Analysis

### Pipeline Description:
1. **Text Input**: User provides a text description of the object to segment
2. **Text-to-Region**: Convert text to potential object locations using:
   - Primary: GroundingDINO (if available)
   - Fallback: CLIP with sliding window approach
3. **Region-to-Segmentation**: Prompt SAM 2 with the center point of the detected region (or the image center if nothing is detected) to generate a precise segmentation mask
4. **Visualization**: Display original image, mask, and overlay

### Technical Implementation:
- **GroundingDINO**: State-of-the-art open-vocabulary object detection
- **CLIP Fallback**: Sliding window approach with vision-language similarity
- **SAM 2**: Latest Segment Anything model for high-quality segmentation
- **Robust Error Handling**: Multiple fallback mechanisms for reliability

### Limitations:
1. **Text Ambiguity**: Simple text prompts may not capture complex spatial relationships
2. **Single Object Focus**: Currently optimized for a single primary object per prompt
3. **Computational Requirements**: Requires GPU for optimal performance
4. **Model Dependencies**: Relies on large pre-trained models (>1GB each)
5. **Video Processing**: Video segmentation requires significant computational resources

### Potential Improvements:
1. **Multi-object Support**: Handle multiple objects in a single prompt
2. **Spatial Relationships**: Support prompts that describe spatial relations, e.g. "dog to the left of the tree"
3. **Interactive Refinement**: Allow user to refine segmentation with additional clicks
4. **Real-time Processing**: Optimize for live video segmentation
5. **Custom Training**: Fine-tune on domain-specific data