In [1]:
# Install required packages
!pip install -q torch transformers supervision
!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
# CELL 1 (MODIFIED)
print("Installing packages...")
!pip install -q torch transformers supervision
!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git
print("Installation complete. Forcing kernel restart...")

# This line will crash and restart the kernel automatically
import os
os.kill(os.getpid(), 9)

Installing packages...
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
# This line will crash and restart the kernel automatically
import os
os.kill(os.getpid(), 9)

In [11]:
import torch
import requests
from PIL import Image
import supervision as sv
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Corrected imports to use the actual package name
from segment_anything_2.build_sam import build_sam2
from segment_anything_2.predictor import Sam2Predictor

# Configure device (use GPU if available)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Configure annotators for visualization
box_annotator = sv.BoundingBoxAnnotator()
mask_annotator = sv.MaskAnnotator()

ModuleNotFoundError: No module named 'segment_anything_2'

In [None]:
# --- Load GroundingDINO Model ---
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(DEVICE)

# --- Load SAM 2 Model ---
# Note: This might take a moment as it downloads the model checkpoint
sam2_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/032424/sam2_hiera_base_plus.pth"
sam2_model = build_sam2(
    model_id="sam2_hiera_b+",
    image_size=1024, # The image size the model was trained on
    checkpoint_url=sam2_checkpoint_url
).to(DEVICE)

sam2_predictor = Sam2Predictor(sam2_model)

print("Models loaded successfully! ✅")

In [None]:
# 1. Define Inputs
IMAGE_URL = "/kaggle/input/babyyy/ChatGPT Image Aug 22 2025 09_44_07 PM.png"
TEXT_PROMPT = "a baby"
BOX_THRESHOLD = 0.35 # Confidence threshold for detected boxes

# Load the image from the URL
response = requests.get(IMAGE_URL, stream=True)
image_pil = Image.open(response.raw).convert("RGB")

# 2. Convert Text to Region Seeds (via GroundingDINO)
# Pre-process the image and text
inputs = grounding_dino_processor(images=image_pil, text=TEXT_PROMPT, return_tensors="pt").to(DEVICE)

# Run inference
with torch.no_grad():
    outputs = grounding_dino_model(**inputs)

# Post-process the results to get bounding boxes
results = grounding_dino_processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=BOX_THRESHOLD,
    text_threshold=BOX_THRESHOLD,
    target_sizes=[image_pil.size[::-1]]
)

# Extract detected boxes
# The output is a list of tuples, one for each image in the batch
detections = sv.Detections.from_transformers(results[0])
print(f"Found {len(detections)} boxes for the prompt '{TEXT_PROMPT}'")

# 3. Feed Seeds to SAM 2
# Set the image for the SAM 2 predictor
sam2_predictor.set_image(image_pil)

# Convert bounding boxes to the format required by SAM 2
input_boxes = detections.xyxy

# Get segmentation masks from SAM 2
# The model returns masks, quality scores, and low-res logits
masks, scores, logits = sam2_predictor.predict(
    box=input_boxes,
    multimask_output=False # We want one high-quality mask per box
)

# Add masks to our supervision Detections object
detections.mask = masks.cpu().numpy()

# 4. Display the Final Mask Overlay
# Annotate the image with both boxes and masks
annotated_image = box_annotator.annotate(scene=image_pil.copy(), detections=detections)
annotated_image = mask_annotator.annotate(scene=annotated_image, detections=detections)

print("\nDisplaying final result... 🎨")
sv.plot_image(annotated_image, size=(8, 8))