In [5]:
!pip install -q transformers==4.56.1 accelerate timm opencv-python pycocotools

!pip install -q git+https://github.com/facebookresearch/segment-anything-2.git

!pip install -q hydra-core==1.3.2 omegaconf==2.3.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m6.0 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for SAM-2 (pyproject.toml) ... [?25l[?25hdone
  Building wheel for iopath (setup.py) ... [?25l[?25hdone


In [1]:
import torch
from PIL import Image
import numpy as np
import requests

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [2]:
!wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt -P /content/drive/MyDrive/models

--2025-10-04 17:41:48--  https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.165.75.95, 3.165.75.66, 3.165.75.59, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.165.75.95|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 897952466 (856M) [application/vnd.snesdev-page-table]
Saving to: ‘/content/drive/MyDrive/models/sam2_hiera_large.pt’


2025-10-04 17:41:54 (158 MB/s) - ‘/content/drive/MyDrive/models/sam2_hiera_large.pt’ saved [897952466/897952466]



In [3]:
!wget https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2/sam2_hiera_l.yaml -P /content/models


--2025-10-04 17:41:54--  https://raw.githubusercontent.com/facebookresearch/sam2/main/sam2/configs/sam2/sam2_hiera_l.yaml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3696 (3.6K) [text/plain]
Saving to: ‘/content/models/sam2_hiera_l.yaml’


2025-10-04 17:41:54 (64.8 MB/s) - ‘/content/models/sam2_hiera_l.yaml’ saved [3696/3696]



In [6]:
import os
import torch
from PIL import Image
import numpy as np

from transformers import (
    AutoProcessor, AutoModelForZeroShotObjectDetection,
    OwlViTProcessor, OwlViTForObjectDetection,
    CLIPProcessor, CLIPModel
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Grounding DINO (primary open-vocabulary detector)
DINO_ID = "IDEA-Research/grounding-dino-base"
dino_processor = AutoProcessor.from_pretrained(DINO_ID)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_ID).to(device).eval()

# OWL-ViT (fallback detector)
OWL_ID = "google/owlvit-base-patch32"
owl_processor = OwlViTProcessor.from_pretrained(OWL_ID)
owl_model = OwlViTForObjectDetection.from_pretrained(OWL_ID).to(device).eval()

# CLIP (re-ranks detector crops against the text prompt)
CLIP_ID = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(CLIP_ID)
clip_model = CLIPModel.from_pretrained(CLIP_ID).to(device).eval()

# SAM 2 (box-prompted segmentation)
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Set these to your uploaded files / paths:
SAM2_CKPT = "/content/drive/MyDrive/models/sam2_hiera_large.pt"
# IMPORTANT: build_sam2 takes a config NAME, not a file path.
SAM2_CFG = "sam2_hiera_l.yaml"

# Explicit check rather than `assert`, which is removed when Python runs with -O.
if not os.path.exists(SAM2_CKPT):
    raise FileNotFoundError(f"SAM2 checkpoint not found at {SAM2_CKPT}")

_sam2_model = build_sam2(SAM2_CFG, ckpt_path=SAM2_CKPT, device=device, mode="eval")
sam2_image = SAM2ImagePredictor(_sam2_model)

print("Models ready (DINO, OWL-ViT, CLIP, SAM 2).")


Models ready (DINO, OWL-ViT, CLIP, SAM 2).


In [7]:
from typing import List, Tuple

def _scored_boxes(result: dict, score_thresh: float) -> List[Tuple[list, float]]:
    """Convert one post-processed detection dict into ``(box_xyxy, score)`` pairs with score >= score_thresh."""
    pairs = []
    for box, score in zip(result.get("boxes", []), result.get("scores", [])):
        score = float(score)
        if score >= score_thresh:
            pairs.append(([float(v) for v in box.tolist()], score))
    return pairs


def _detect_with_dino(image_pil, text_prompt: str, target_hw: Tuple[int, int],
                      score_thresh: float) -> List[Tuple[list, float]]:
    """Grounding DINO pass; returns ``(box_xyxy, score)`` pairs in pixel coordinates."""
    # Grounding DINO expects a lower-cased query terminated by a period.
    query = text_prompt.strip().lower()
    if not query.endswith("."):
        query += "."

    inputs = dino_processor(images=image_pil, text=[query], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = dino_model(**inputs)

    # Forward the caller's threshold: otherwise post-processing first applies its own
    # default box threshold (0.25), silently ignoring any dino_score_thresh below it.
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids,   # lets HF derive text labels
        threshold=score_thresh,
        target_sizes=[target_hw],     # (H, W) -> boxes in pixel coordinates
    )
    return _scored_boxes(results[0], score_thresh) if results else []


def _detect_with_owlvit(image_pil, text_prompt: str, target_hw: Tuple[int, int],
                        score_thresh: float) -> List[Tuple[list, float]]:
    """OWL-ViT pass; returns ``(box_xyxy, score)`` pairs in pixel coordinates."""
    inputs = owl_processor(text=[[text_prompt]], images=image_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = owl_model(**inputs)

    target = torch.tensor([target_hw], device=device)
    # OwlViTProcessor.post_process() is deprecated; the image processor's
    # post_process_object_detection() replaces it and applies the score threshold.
    results = owl_processor.image_processor.post_process_object_detection(
        outputs=outputs, threshold=score_thresh, target_sizes=target
    )
    return _scored_boxes(results[0], score_thresh) if results else []


def detect_objects(image_pil: Image.Image,
                   text_prompt: str,
                   dino_score_thresh: float = 0.30,
                   owl_score_thresh: float = 0.30) -> List[Tuple[list, float]]:
    """
    Detect regions matching ``text_prompt``; boxes are in pixel coordinates.

    Grounding DINO runs first. OWL-ViT is used as a fallback only when DINO returns
    no box scoring at least ``dino_score_thresh``.

    Args:
        image_pil: Input image.
        text_prompt: Free-text description of the object(s) to find.
        dino_score_thresh: Minimum Grounding DINO score for a box to be kept.
        owl_score_thresh: Minimum OWL-ViT score for a fallback box to be kept.

    Returns:
        List of ``([x0, y0, x1, y1], score)`` tuples. The list is empty if neither
        detector finds anything above its threshold.
    """
    W, H = image_pil.size
    dets = _detect_with_dino(image_pil, text_prompt, (H, W), dino_score_thresh)
    if not dets:
        dets = _detect_with_owlvit(image_pil, text_prompt, (H, W), owl_score_thresh)
    return dets


In [8]:
def _clamp_crop_box(box, width: int, height: int) -> Tuple[int, int, int, int]:
    """Truncate an xyxy box to ints, clip it to the image, and keep it at least 1x1 px."""
    x0, y0, x1, y1 = map(int, box)
    x0 = min(max(0, x0), width - 1)
    y0 = min(max(0, y0), height - 1)
    x1 = min(max(x1, x0 + 1), width)
    y1 = min(max(y1, y0 + 1), height)
    return x0, y0, x1, y1


def filter_detections_by_clip(image_pil: Image.Image,
                              detections: list,
                              text_prompt: str,
                              keep_ratio: float = 0.90):
    """
    Re-rank detections by CLIP image/text similarity and keep those close to the best one.

    Args:
        image_pil: Source image the boxes refer to.
        detections: List of ``(box_xyxy, score)`` in pixel coordinates. The detector
            score is ignored here.
        text_prompt: Text the crops are compared against.
        keep_ratio: A crop is kept when its cosine similarity is at least
            ``keep_ratio * best`` (for a non-negative best). The best crop is always kept.

    Returns:
        List of kept boxes (xyxy), in the same order as ``detections``.
    """
    if not detections:
        return []

    # Clamping plus a 1-px minimum size prevents empty or inverted crops. Such crops come from
    # tiny boxes collapsing under int truncation or boxes outside the image, and break PIL/CLIP.
    crops = [
        image_pil.crop(_clamp_crop_box(box, image_pil.width, image_pil.height))
        for box, _score in detections
    ]

    with torch.no_grad():
        # truncation=True: CLIP's text encoder has a fixed context length, so long prompts would error.
        t = clip_processor(text=[text_prompt], return_tensors="pt", padding=True, truncation=True).to(device)
        i = clip_processor(images=crops, return_tensors="pt").to(device)
        tfeat = torch.nn.functional.normalize(clip_model.get_text_features(**t), dim=-1)   # (1, d)
        ifeat = torch.nn.functional.normalize(clip_model.get_image_features(**i), dim=-1)  # (N, d)
        sims = (ifeat @ tfeat.T).squeeze(-1).detach().cpu().numpy()  # (N,)

    top = float(np.max(sims))
    # For a non-positive best score, keep_ratio * top would exceed top and discard every crop.
    # Widening the margin on the negative side keeps the best crop and is unchanged for top >= 0.
    cutoff = keep_ratio * top if top >= 0 else (2.0 - keep_ratio) * top
    return [detections[k][0] for k, s in enumerate(sims) if s >= cutoff]  # just boxes


In [9]:
def segment_masks(image_pil: Image.Image, boxes_xyxy: list):
    """
    Segment each box with SAM 2.

    Args:
        image_pil: Source image. It is converted to RGB before embedding.
        boxes_xyxy: Boxes in pixel coordinates ``[x0, y0, x1, y1]``.

    Returns:
        List of boolean masks of shape (H, W), one per box and in the same order.
        The list is empty when ``boxes_xyxy`` is empty.
    """
    # Nothing to prompt: skip the image embedding computed by set_image().
    if not boxes_xyxy:
        return []

    # SAM 2 expects a 3-channel RGB array; convert so RGBA/greyscale inputs also work.
    img = np.array(image_pil.convert("RGB"))
    sam2_image.set_image(img)

    H, W = img.shape[:2]
    masks = []
    for box in boxes_xyxy:
        x0, y0, x1, y1 = (float(v) for v in box)
        # Clip the prompt box to the image bounds.
        box_np = np.array(
            [max(0.0, x0), max(0.0, y0), min(W - 1, x1), min(H - 1, y1)],
            dtype=np.float32,
        )
        # With multimask_output=False, predict() yields one mask per box (plus scores/logits).
        mask_batch, _scores, _logits = sam2_image.predict(box=box_np, multimask_output=False)
        masks.append(mask_batch[0].astype(bool))
    return masks


In [10]:
def segment_object_in_image(image_pil, text_prompt):
    """
    Run the full pipeline: detect, CLIP re-rank, SAM 2 segment, then black out the background.

    Args:
        image_pil: Source image.
        text_prompt: Text describing the object(s) to keep.

    Returns:
        ``(cutout, mask)``. ``cutout`` is a PIL image whose pixels outside the union of
        all masks are zero; ``mask`` is the boolean (H, W) union mask. Returns
        ``(None, None)`` when no detection survives detection or CLIP filtering.
    """
    detections = detect_objects(image_pil, text_prompt)
    if not detections:
        print(f"No '{text_prompt}' found in the image.")
        return None, None

    kept_boxes = filter_detections_by_clip(image_pil, detections, text_prompt)
    if not kept_boxes:
        print(f"CLIP filtering removed all detections for '{text_prompt}'.")
        return None, None

    # Union of all per-box masks.
    union = np.zeros((image_pil.height, image_pil.width), dtype=bool)
    for seg in segment_masks(image_pil, kept_boxes):
        np.logical_or(union, seg, out=union)

    # Copy only the masked pixels onto a black canvas.
    pixels = np.array(image_pil)
    cutout = np.zeros_like(pixels)
    cutout[union] = pixels[union]
    return Image.fromarray(cutout), union


In [11]:
import io

import requests
from PIL import Image

image_url = "https://i.ytimg.com/vi/xH5DFu_eLUY/maxresdefault.jpg"
prompt = "tree"

# A timeout prevents hanging forever, and raise_for_status() surfaces HTTP errors instead of
# an opaque PIL decode failure. .content is transparently decompressed, whereas .raw is not.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image_pil = Image.open(io.BytesIO(response.content)).convert("RGB")

result_image, mask = segment_object_in_image(image_pil, prompt)

if result_image is not None:
    result_image.save("output.png")
    print("Segmentation completed. Saved: output.png")
else:
    print("No result.")


Segmentation completed. Saved: output.png
