<a href="https://colab.research.google.com/github/Aleptonic/IISC-VIT/blob/main/q2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text-Driven Image Segmentation with SAM 2

This notebook implements a complete pipeline to perform segmentation on any object in an image using a free-form text prompt. The project's goal is to demonstrate how two different state-of-the-art models, **GroundingDINO** and the **Segment Anything Model (SAM) 2**, can be composed to achieve a powerful new capability. Note that the code below loads the original SAM ViT-H checkpoint (`sam_vit_h_4b8939.pth`) through the `segment_anything` package.

# Dependency installation
This cell installs all necessary libraries; the pre-trained SAM weights are downloaded in the setup cell that follows.
- **`transformers`**: From Hugging Face, used for a reliable and easy-to-use implementation of GroundingDINO.
- **`segment-anything-py`**: The official library for the Segment Anything Model.
- **`supervision`**: A helpful library for visualization and handling annotations (used in some helper functions).

In [6]:
!pip install -q transformers supervision segment-anything-py Pillow

# Imports and model weights

In [7]:
import os
import requests
import torch
import numpy as np
import supervision as sv
import PIL
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataclasses import dataclass
# --- Download SAM weights ---
SAM_WEIGHTS_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
WEIGHTS_DIR = os.path.join(os.getcwd(), "weights")
os.makedirs(WEIGHTS_DIR, exist_ok=True)
SAM_WEIGHTS_PATH = os.path.join(WEIGHTS_DIR, "sam_vit_h_4b8939.pth")

if not os.path.exists(SAM_WEIGHTS_PATH):
    # Stream into a temporary ".part" file and rename it only after the whole
    # transfer succeeded. An interrupted download therefore never leaves a
    # truncated file at SAM_WEIGHTS_PATH that the existence check above would
    # accept on the next run.
    partial_weights_path = SAM_WEIGHTS_PATH + ".part"
    with requests.get(SAM_WEIGHTS_URL, stream=True, timeout=60) as response:
        response.raise_for_status()  # fail loudly on HTTP errors instead of saving an error page
        with open(partial_weights_path, "wb") as partial_file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                partial_file.write(chunk)
    os.replace(partial_weights_path, SAM_WEIGHTS_PATH)

from segment_anything import sam_model_registry, SamPredictor
print("✅ Initial Setup complete.")

from google.colab import files
import ipywidgets as widgets
from IPython.display import display, clear_output
import PIL.Image

✅ Initial Setup complete.


In [8]:
@dataclass
class ModelParameters:
  """Model selection and device settings read by the model wrappers.

  Every attribute is an annotated dataclass field, so each one can be
  overridden per instance, e.g. ``ModelParameters(device="cpu")``.
  """
  # Hugging Face model id of the GroundingDINO checkpoint.
  dino_model_id: str = "IDEA-Research/grounding-dino-tiny"
  # Key used to look up the SAM architecture in `sam_model_registry`.
  sam_model_type: str = "vit_h"
  # Default is resolved once, when the class body is executed.
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
@dataclass
class DataParameters:
  """Bundle of input/output locations and the text prompt for one run.

  All four fields are required strings; none has a default.
  """
  # Path of the input image.
  img_path: str
  # Path of the input video.
  vid_path: str
  # Location where results are written.
  save_path: str
  # Free-form text describing the object(s) to segment.
  input_text_prompt: str



# Model
The solution is built on a "Finder-Painter" principle, where each model has a specialized role:

1.  **The Finder (GroundingDINO):** An open-set object detector that excels at understanding language. It takes our text prompt and finds *where* the object is in the image, returning bounding boxes.
2.  **The Painter (SAM):** A promptable segmentation model that excels at understanding object boundaries. It takes the bounding boxes from the Finder as a guide and "paints" a pixel-perfect mask.

Our code encapsulates these roles in distinct classes (`DINO`, `SAM`) and orchestrates the flow of data between them.

In [9]:
class DINO:
  """Text-prompted object detector built on a Hugging Face GroundingDINO checkpoint."""

  def __init__(self, mparams):
    """Load the processor and the detection model.

    Args:
      mparams: object exposing ``device`` and ``dino_model_id`` attributes.
    """
    self.device = mparams.device
    self.processor = AutoProcessor.from_pretrained(mparams.dino_model_id)
    self.model = AutoModelForZeroShotObjectDetection.from_pretrained(mparams.dino_model_id)
    # The inputs are moved to self.device in _get_bounding_box; the weights
    # must live on the same device or the forward pass fails on CUDA.
    self.model.to(self.device)
    print("GroundingDINO model loaded.")

  def _enhance_input_text(self, text_prompt:str):
    """Split a comma-separated prompt into a list of candidate labels.

    Each label is stripped of surrounding whitespace; empty entries (for
    example from a trailing comma) are dropped.

    Returns:
      list[str]: the candidate labels, in prompt order.
    """
    # NOTE(review): the list is handed to the processor unchanged; this assumes
    # the installed transformers version accepts a list of candidate labels for
    # a single image - confirm when changing versions.
    labels = (piece.strip() for piece in text_prompt.split(','))
    return [label for label in labels if label]

  def _get_bounding_box(self, image:"PIL.Image.Image", text_prompt, box_threshold=0.4, text_threshold=0.3):
    """Detect the objects described by ``text_prompt`` in ``image``.

    Args:
      image: PIL image to search.
      text_prompt: comma-separated description of the object(s) to find.
      box_threshold: value passed as ``threshold`` to the post-processing step.
      text_threshold: value passed as ``text_threshold`` to the post-processing step.

    Returns:
      tuple: ``(boxes, labels)`` taken from the first post-processed result
      (the ``"boxes"`` and ``"text_labels"`` entries).
    """
    inputs = self.processor(
      images=image,
      text=self._enhance_input_text(text_prompt),
      return_tensors="pt",
    ).to(self.device)
    with torch.no_grad():
      outputs = self.model(**inputs)

    results = self.processor.post_process_grounded_object_detection(
      outputs,
      inputs.input_ids,
      threshold=box_threshold,
      text_threshold=text_threshold,
      target_sizes=[image.size[::-1]] # PIL gives (w, h); reverse to (h, w)
    )

    # Only one image was submitted, so only the first result is relevant.
    dino_boxes = results[0]["boxes"]
    dino_labels = results[0]["text_labels"]
    print(f"GroundingDINO found labels: {dino_labels}")
    print(f"GroundingDINO found {len(dino_boxes)} box(es).")
    return dino_boxes, dino_labels

class SAM:
  """Box-prompted mask generator wrapping a ``SamPredictor``."""

  def __init__ (self, mparams):
    """Build the SAM network from the checkpoint at ``SAM_WEIGHTS_PATH``.

    Args:
      mparams: object exposing ``device`` and ``sam_model_type`` attributes.
    """
    self.device = mparams.device
    build_sam = sam_model_registry[mparams.sam_model_type]
    self.sam = build_sam(checkpoint=SAM_WEIGHTS_PATH)
    self.sam.to(device=mparams.device)
    self.sam_predictor = SamPredictor(self.sam)
    print("SAM model loaded.")

  def _segment(self, image:PIL.Image, bounding_boxes):
    """Predict one mask per bounding box for ``image``.

    Args:
      image: PIL image; it is converted to a NumPy array before use.
      bounding_boxes: boxes in the coordinate frame of ``image``.

    Returns:
      The masks returned by ``SamPredictor.predict_torch``.
    """
    predictor = self.sam_predictor
    pixels = np.array(image)
    predictor.set_image(pixels)

    # Map the detector's boxes into the predictor's own input frame.
    original_hw = pixels.shape[:2]
    boxes_for_sam = predictor.transform.apply_boxes_torch(bounding_boxes, original_hw)

    # Box prompts only, no point prompts; a single mask is requested per box.
    masks, _, _ = predictor.predict_torch(
      point_coords=None,
      point_labels=None,
      boxes=boxes_for_sam.to(self.device),
      multimask_output=False,
    )
    print("SAM generated masks.")
    return masks

class Visualize:
  """Matplotlib helpers that overlay masks and labelled boxes on an image."""

  @staticmethod
  def _show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
      mask: array whose last two dimensions are (height, width).
      ax: object with an ``imshow`` method (a matplotlib Axes).
      random_color: pick a random RGB colour instead of the fixed blue.
    """
    if random_color:
      color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
      color = np.array([30/255, 144/255, 255/255, 0.6]) # Dodger blue
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4) -> (h, w, 4): pixels outside the mask
    # end up fully transparent black.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

  @staticmethod
  def _show_box(box, ax, label):
    """Draw a red rectangle for ``box`` with ``label`` above its top-left corner.

    Args:
      box: four values unpacked as (x_min, y_min, x_max, y_max).
      ax: matplotlib Axes to draw on.
      label: text shown next to the box.
    """
    x_min, y_min, x_max, y_max = box
    rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2,
                             edgecolor='r', facecolor='none')
    ax.add_patch(rect)
    ax.text(x_min, y_min - 10, label, color='white', fontsize=12,
            bbox=dict(facecolor='red', alpha=0.5))

  @staticmethod
  def _show_segmentation(image:"PIL.Image.Image", input_prompt:str, seg_masks, bounding_boxes, box_labels,
                         save_path='segmentation_mask.png'):
    """Render the image with masks and boxes, save the figure, then show it.

    Args:
      image: image to display as the background.
      input_prompt: prompt text used in the figure title.
      seg_masks: iterable of mask tensors, or ``None`` to draw no masks.
      bounding_boxes: box tensors, paired element-wise with ``box_labels``.
      box_labels: one label per bounding box.
      save_path: file the figure is written to; a falsy value skips saving.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(image)
    ax.set_title(f"Segmentation for: '{input_prompt}'")

    # Plotting masks
    if seg_masks is not None:
      for mask in seg_masks:
        Visualize._show_mask(mask.cpu().numpy(), ax, random_color=True)
    if len(bounding_boxes) > 0:
      for box, label in zip(bounding_boxes, box_labels):
        Visualize._show_box(box.cpu().numpy(), ax, label)

    # Hide the axes BEFORE saving so the file on disk matches what is shown.
    ax.axis('off')
    if save_path:
      fig.savefig(save_path)
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

# Interactive Demonstration


- **Upload an image** from your local machine.
- **Enter a text prompt** describing the object you want to segment (you can list multiple objects separated by commas, e.g., "a person, a surfboard").
- **Click "Run Segmentation"** to see the final result.

In [10]:
class UI_Handler:
    """Widget front end: upload an image, enter a prompt, run DINO then SAM."""

    def __init__(self, dino_model, sam_model, upload_dir=None):
        """
        Initializes the handler with pre-loaded DINO and SAM models.

        Args:
            dino_model: object providing ``_get_bounding_box(image, prompt)``.
            sam_model: object providing ``_segment(image, boxes)``.
            upload_dir: directory uploaded images are saved to. Defaults to
                ``/content`` when that directory exists, otherwise the current
                working directory.
        """
        self.dino = dino_model
        self.sam = sam_model
        self.image_path = None
        if upload_dir is None:
            upload_dir = "/content" if os.path.isdir("/content") else os.getcwd()
        self.upload_dir = upload_dir

        # --- Create UI Widgets ---
        self.uploader = widgets.FileUpload(
            accept='image/*',
            description='Upload Image'
        )
        self.text_prompt = widgets.Text(
            value='a red cap',
            placeholder='Enter object(s) to find, separated by commas',
            description='Prompt:',
            disabled=True # Disabled until an image is uploaded
        )
        self.run_button = widgets.Button(
            description='Run Segmentation',
            disabled=True,
            button_style='success',
            icon='check'
        )
        self.output = widgets.Output()

        # --- Define Widget Actions ---
        self.uploader.observe(self._handle_upload, names='value')
        self.run_button.on_click(self._run_pipeline)

    @staticmethod
    def _extract_upload(uploaded_files):
        """Return ``(filename, content_bytes)`` for the first uploaded file.

        Accepts both ``FileUpload.value`` layouts: a dict of entries carrying
        ``['metadata']['name']`` (ipywidgets 7) and a sequence of entries
        carrying ``['name']`` (ipywidgets 8).
        """
        if isinstance(uploaded_files, dict):
            entry = next(iter(uploaded_files.values()))
            raw_name = entry['metadata']['name']
        else:
            entry = uploaded_files[0]
            raw_name = entry['name']
        # The name comes from the browser: keep only its final component so it
        # cannot point outside the upload directory.
        safe_name = os.path.basename(raw_name) or "uploaded_image"
        return safe_name, bytes(entry['content'])

    def _display_saved_image(self):
        """Show the image stored at ``self.image_path`` and close the file."""
        with PIL.Image.open(self.image_path) as img:
            display(img)

    def _handle_upload(self, change):
        """Called when a file is uploaded."""
        with self.output:
            clear_output() # Clear previous image/results
            uploaded_files = change['new']

            if not uploaded_files:
                print("Upload cancelled.")
                return

            filename, content = self._extract_upload(uploaded_files)

            # Save the upload to disk so the pipeline can re-open it later
            self.image_path = os.path.join(self.upload_dir, filename)
            with open(self.image_path, 'wb') as f:
                f.write(content)

            print(f"Image '{filename}' uploaded successfully.")

            # Display the uploaded image
            self._display_saved_image()

            # Enable the other widgets
            self.text_prompt.disabled = False
            self.run_button.disabled = False

    def _run_pipeline(self, b):
        """Called when the 'Run Segmentation' button is clicked."""
        with self.output:
            # Clear previous results but keep the uploaded image
            clear_output(wait=True)
            if self.image_path:
                self._display_saved_image()

            # Treat a whitespace-only prompt the same as an empty one.
            prompt_text = self.text_prompt.value.strip()
            if not prompt_text:
                print("Error: Please enter a text prompt.")
                return

            print("Processing... This may take a moment.")

            # --- DINO-SAM Pipeline ---
            try:
                # convert() returns an independent image, so the file handle
                # can be closed right away.
                with PIL.Image.open(self.image_path) as source_image:
                    image_pil = source_image.convert("RGB")

                # 1. Get bounding boxes from DINO
                bounding_boxes, box_labels = self.dino._get_bounding_box(image_pil, prompt_text)

                # 2. Get masks from SAM (skipped when nothing was detected)
                if len(bounding_boxes) > 0:
                    seg_masks = self.sam._segment(image_pil, bounding_boxes)
                else:
                    seg_masks = None

                # 3. Visualize
                Visualize._show_segmentation(image_pil, prompt_text, seg_masks, bounding_boxes, box_labels)
            except Exception as e:
                # UI boundary: report the failure in the output area instead of
                # letting the callback die silently.
                print(f"An error occurred: {e}")

    def display_ui(self):
        """Displays the full user interface."""
        print("Please upload an image to begin.")
        display(self.uploader, self.text_prompt, self.run_button, self.output)



# --- Usage ---
# Load the detector first, then the segmenter (each with default parameters),
# and hand both to the widget front end.
dino_instance, sam_instance = (
    DINO(mparams=ModelParameters()),
    SAM(mparams=ModelParameters()),
)
ui = UI_Handler(sam_model=sam_instance, dino_model=dino_instance)
ui.display_ui()

GroundingDINO model loaded.
SAM model loaded.
Please upload an image to begin.


FileUpload(value={}, accept='image/*', description='Upload Image')

Text(value='a red cap', description='Prompt:', disabled=True, placeholder='Enter object(s) to find, separated …

Button(button_style='success', description='Run Segmentation', disabled=True, icon='check', style=ButtonStyle(…

Output()