In [None]:
# Full Text-Driven Segmentation: CLIPSeg + SAM2
# ==========================

# ⿡ Install dependencies (run in terminal)
!pip install torch torchvision transformers matplotlib yacs timm opencv-python Pillow


In [4]:
# Step 2: Imports
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation



In [None]:
# SAM2 imports
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor



In [None]:
# Step 3: Settings
# ==========================
device = "cuda" if torch.cuda.is_available() else "cpu"

# Paths
# `sam2_cfg` is the Hydra config name that `build_sam2` resolves inside the
# installed `sam2` package (not a filesystem path). It must match the checkpoint:
# "b+" (base-plus) config <-> sam2_hiera_base_plus.pt. Other sizes: t / s / l.
sam2_cfg = "configs/sam2/sam2_hiera_b+.yaml"
sam2_ckpt = "sam2_hiera_base_plus.pt"          # SAM2 base-plus checkpoint downloaded locally
image_path = "my_image.jpg"                    # Replace with your local image
text_prompt = "dog"                            # Replace with desired object


In [None]:
# Step 4: Load image
# ==========================
# Open inside a context manager so PIL's lazily-held file handle is released.
# convert("RGB") returns an independent in-memory image and normalises
# grayscale / RGBA / palette inputs to 3-channel RGB.
with Image.open(image_path) as img_file:
    image = img_file.convert("RGB")
image_np = np.array(image)  # (H, W, 3) uint8 copy, passed to SAM2 later

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image)
ax.axis("off")
ax.set_title("Original Image")
plt.show()


In [None]:
# Step 5: CLIPSeg: text -> coarse mask
# ==========================
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

inputs = processor(text=[text_prompt], images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = clipseg_model(**inputs)

# CLIPSeg predicts at a fixed, lower resolution (352x352 by default), not at the
# input image's size. Upsample the logits to (H, W) so the overlay lines up with
# the image and any coordinates derived from the mask are in image pixels.
# For a single prompt the logits may be (h, w) or (1, h, w) depending on the
# transformers version, so normalise to (N=1, C=1, h, w) for interpolate.
logits = outputs.logits.reshape(1, 1, *outputs.logits.shape[-2:])
logits_full = torch.nn.functional.interpolate(
    logits, size=(image.height, image.width), mode="bilinear", align_corners=False
)
mask_coarse = torch.sigmoid(logits_full)[0, 0].cpu().numpy()  # (H, W) probabilities

# Visualize CLIPSeg mask
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image)
ax.imshow(mask_coarse, alpha=0.5, cmap="Reds")
ax.axis("off")
ax.set_title(f"CLIPSeg Mask: {text_prompt}")
plt.show()



In [None]:
# Step 6: Convert CLIPSeg mask -> bounding box
# ==========================
def mask_to_box(mask, image_hw, threshold=0.5):
    """Return the XYXY box, in image pixels, enclosing ``mask > threshold``.

    Args:
        mask: 2-D probability map (singleton dims are squeezed). It may be at a
            different resolution than the image (e.g. raw CLIPSeg output), in
            which case coordinates are rescaled from mask space to image space.
        image_hw: (height, width) of the image the box will be used on.
        threshold: probability above which a pixel counts as foreground.

    Returns:
        float32 array [x_min, y_min, x_max, y_max] in inclusive pixel coords.
        Falls back to the whole image when no pixel exceeds ``threshold``.
    """
    mask = np.squeeze(np.asarray(mask))
    img_h, img_w = image_hw
    mask_h, mask_w = mask.shape
    ys, xs = np.nonzero(mask > threshold)
    if xs.size == 0:
        return np.array([0, 0, img_w - 1, img_h - 1], dtype=np.float32)
    scale_x, scale_y = img_w / mask_w, img_h / mask_h
    # Mask pixel i covers image pixels [i * scale, (i + 1) * scale), so the
    # inclusive far edge is (max + 1) * scale - 1. With scale == 1 this reduces
    # to the plain min/max of the foreground coordinates.
    return np.array(
        [
            xs.min() * scale_x,
            ys.min() * scale_y,
            (xs.max() + 1) * scale_x - 1,
            (ys.max() + 1) * scale_y - 1,
        ],
        dtype=np.float32,
    )


input_box = mask_to_box(mask_coarse, image_np.shape[:2])
print("Bounding box from CLIPSeg mask:", input_box)



In [None]:
# Step 7: SAM2: refined segmentation
# ==========================
sam2_model = build_sam2(sam2_cfg, sam2_ckpt, device=device)
predictor = SAM2ImagePredictor(sam2_model)

with torch.inference_mode():
    predictor.set_image(image_np)
    # A single box is an unambiguous prompt, so request one mask rather than
    # three candidates. For one prompt, predict() returns masks as (C, H, W)
    # and scores as (C,); indexing [0][0] would pick a single *row* of the
    # first mask instead of the mask itself.
    masks, scores, _ = predictor.predict(box=input_box[None, :], multimask_output=False)

# Collapse any leading prompt/candidate dims to (K, H, W) / (K,) and keep the
# highest-scoring mask as an (H, W) boolean array.
masks = masks.reshape(-1, *masks.shape[-2:])
scores = np.asarray(scores).reshape(-1)
mask_refined = masks[int(np.argmax(scores))] > 0

# Visualize SAM2 refined mask
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image_np)
ax.imshow(mask_refined, alpha=0.5, cmap="Reds")
ax.axis("off")
ax.set_title(f"SAM2 Refined Mask: {text_prompt}")
plt.show()