# Grounding DINO Zero-Shot Redaction Detection

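This notebook evaluates [IDEA-Research/grounding-dino-base](https://huggingface.co/IDEA-Research/grounding-dino-base) for zero-shot detection of black boxes / redactions in document images, comparing several text prompts side by side.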

**Usage:**
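1. Run Cell 1 to install dependencies.
2. Run Cell 2 to load the model (the weights are downloaded from the Hugging Face Hub on first use).
3. Update `IMAGE_PATH` in Cell 3 to point to your image.
4. Run Cells 4–5 to detect and visualize redactions for each prompt.
5. (Optional) Use Cell 6 to experiment with a single prompt and custom thresholds.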

In [None]:
# Cell 1: Install Dependencies
!pip install transformers torch torchvision

In [None]:
# Cell 2: Imports & Model Loading
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

# Checkpoint id on the Hugging Face Hub.
model_id = "IDEA-Research/grounding-dino-base"

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model on {device}...")
# The processor handles both image preprocessing and prompt tokenization
# (run_detection calls it with images= and text=).
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)  # move the weights onto the selected device
print("Model loaded successfully!")

In [None]:
# Cell 3: Image Loading
# Update this path to point to your image
IMAGE_PATH = "/path/to/your/image.png"

# Open inside a context manager so the file handle is released deterministically
# (multi-frame formats such as TIFF otherwise keep it open). .convert() returns an
# independent, fully loaded copy, so `image` stays valid after the file is closed.
# NOTE: for multi-page files only the first frame/page is used.
with Image.open(IMAGE_PATH) as source_image:
    image = source_image.convert("RGB")
print(f"Image size: {image.size}")  # PIL reports (width, height)

# Display the original image
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image)
ax.set_title("Original Image")
ax.axis("off")
plt.show()

In [None]:
# Cell 4: Inference with Multiple Prompts

# Detection thresholds shared by every prompt below.
# Lower values keep more (and weaker) detections; higher values keep fewer.
box_threshold = text_threshold = 0.25
# box_threshold:  minimum confidence for a predicted box to be kept
# text_threshold: minimum score for matching prompt text to a kept box

# Candidate phrasings for "a redaction"; each one is run as its own detection pass.
prompts = ["black box", "redaction", "redacted text", "black rectangle"]

def format_grounding_prompt(text_prompt):
    """Normalize a free-text query into the form Grounding DINO expects.

    Grounding DINO text queries should be lowercased and terminated with a
    period (e.g. "black box."); multiple phrases are separated by ". ".
    Surrounding whitespace is stripped and a trailing "." is added if missing.

    Args:
        text_prompt: raw query string, e.g. "Black Box".

    Returns:
        The normalized query string, e.g. "black box.".
    """
    normalized = text_prompt.strip().lower()
    if not normalized.endswith("."):
        normalized += "."
    return normalized


def run_detection(image, text_prompt, box_thresh, text_thresh):
    """Run Grounding DINO detection with given prompt and thresholds.

    Args:
        image: PIL image to search.
        text_prompt: free-text query; normalized via format_grounding_prompt.
        box_thresh: minimum box confidence for a detection to be kept.
        text_thresh: minimum score for matching prompt text to a box.

    Returns:
        The post-processed result dict for this single image, containing
        "boxes" (rescaled to the image's pixel size) and "scores" among others.
    """
    inputs = processor(
        images=image,
        text=format_grounding_prompt(text_prompt),
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # The thresholds are passed positionally on purpose. The third parameter is
    # named `box_threshold` in older transformers releases and `threshold` in newer
    # ones; the positional slot is the same in both.
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_thresh,
        text_thresh,
        target_sizes=[image.size[::-1]],  # PIL gives (width, height); expects (height, width)
    )

    # A single image was processed, so return its (only) result.
    return results[0]

# Run detection for each prompt and report what each one found.
all_results = {}
for prompt in prompts:
    result = run_detection(image, prompt, box_threshold, text_threshold)
    all_results[prompt] = result
    print(f"Prompt: '{prompt}' -> {len(result['boxes'])} detection(s)")
    # Zipping empty tensors yields nothing, so no explicit emptiness check is needed.
    # Only boxes and scores are read; the unused "labels" entry is left alone.
    for i, (box, score) in enumerate(zip(result["boxes"], result["scores"])):
        print(f"  [{i}] score={score.item():.3f}, box={box.tolist()}")

In [None]:
# Cell 5: Visualization

def _draw_result(ax, result, cmap, linewidth=2, fontsize=10):
    """Overlay one prompt's detections on `ax`.

    Each box is outlined, and its score is written just above its top-left corner.
    Both use a color sampled from `cmap` at the detection's confidence score.
    """
    for box, score in zip(result["boxes"], result["scores"]):
        x1, y1, x2, y2 = box.tolist()
        confidence = score.item()
        # Scores are presumably in [0, 1], the colormap's input range.
        color = cmap(confidence)
        ax.add_patch(
            patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=linewidth,
                edgecolor=color,
                facecolor="none",
            )
        )
        ax.text(
            x1, y1 - 5,  # imshow puts y=0 at the top, so -5 places the label above the box
            f"{confidence:.2f}",
            color=color,
            fontsize=fontsize,
            fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.7),
        )


def visualize_detections(image, results_dict, figsize=(16, 12), cols=2):
    """Visualize detections from all prompts in a grid, one panel per prompt.

    Args:
        image: PIL image the detections were computed on.
        results_dict: mapping of prompt -> result dict with "boxes"
            (x1, y1, x2, y2 in image pixels) and "scores".
        figsize: overall figure size passed to plt.subplots.
        cols: number of panels per grid row (default 2).

    Boxes and score labels use the RdYlGn colormap: red for low confidence,
    green for high. Nothing is drawn if results_dict is empty.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    num_prompts = len(results_dict)
    if num_prompts == 0:
        print("No results to visualize.")
        return

    rows = (num_prompts + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() works for every
    # grid shape. Without it, a single prompt yields a 1-D array, which the old
    # `[axes]` wrapping turned into a nested array and broke ax.imshow.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    cmap = plt.cm.RdYlGn  # Red (low) to Green (high)

    for ax, (prompt, result) in zip(axes, results_dict.items()):
        ax.imshow(image)
        ax.set_title(f"Prompt: '{prompt}' ({len(result['boxes'])} detections)")
        ax.axis("off")
        _draw_result(ax, result, cmap)

    # Hide the unused trailing panels of the last row.
    for ax in axes[num_prompts:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Visualize all results
visualize_detections(image, all_results)

In [None]:
# Cell 6 (Optional): Single Prompt with Custom Threshold
# Use this cell to experiment with a specific prompt and threshold

CUSTOM_PROMPT = "black box"  # Change this
CUSTOM_BOX_THRESHOLD = 0.20  # Lower = more detections, higher = fewer but more confident
CUSTOM_TEXT_THRESHOLD = 0.20

custom_result = run_detection(
    image, CUSTOM_PROMPT, CUSTOM_BOX_THRESHOLD, CUSTOM_TEXT_THRESHOLD
)
custom_boxes = custom_result["boxes"]
custom_scores = custom_result["scores"]

print(f"Detected {len(custom_boxes)} boxes with prompt '{CUSTOM_PROMPT}'")

# Draw the single result on one large panel, titled with the settings used.
fig, ax = plt.subplots(figsize=(14, 10))
ax.imshow(image)
ax.set_title(f"'{CUSTOM_PROMPT}' (box_thresh={CUSTOM_BOX_THRESHOLD}, text_thresh={CUSTOM_TEXT_THRESHOLD})")
ax.axis("off")

# Outline and label color follow confidence: red (low) -> green (high).
cmap = plt.cm.RdYlGn
for box, score in zip(custom_boxes, custom_scores):
    left, top, right, bottom = box.tolist()
    edge_color = cmap(score.item())
    ax.add_patch(
        patches.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=3,
            edgecolor=edge_color,
            facecolor="none",
        )
    )
    label_style = dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.7)
    ax.text(
        left, top - 5, f"{score:.2f}",
        color=edge_color, fontsize=12, fontweight="bold", bbox=label_style,
    )

plt.tight_layout()
plt.show()