
<a href="https://colab.research.google.com/github/facebookresearch/dinov3/blob/main/dinov3_zero_shot_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


This notebook demonstrates how to perform zero-shot segmentation using DINOv3 features based on patch similarity. You can:

1. **Upload an image** to segment
2. **Click a seed point** on the image to define the region of interest
3. **Extract DINOv3 patch features** and calculate cosine similarity
4. **Adjust similarity threshold** interactively to refine the segmentation
5. **Visualize results** with similarity heatmaps and segmentation overlays

The approach leverages the rich semantic representations learned by DINOv3 to segment objects based on visual similarity to a user-selected seed region.
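The core idea fits in a few lines. The sketch below uses NumPy and a small synthetic `(D, H, W)` feature array as a stand-in for DINOv3 patch features. The function name `seed_similarity_mask` is hypothetical and not part of DINOv3. It L2-normalises every patch vector, takes the dot product with the seed patch's vector (which gives cosine similarity), and keeps the patches at or above a threshold.

```python
import numpy as np

def seed_similarity_mask(features, seed_xy, threshold):
    """Cosine similarity of every patch to one seed patch, then a threshold.

    features:  (D, H, W) array of patch features (a stand-in for DINOv3 output)
    seed_xy:   (patch_x, patch_y) = (column, row) of the seed patch
    threshold: patches with similarity >= threshold are kept
    """
    d, h, w = features.shape
    flat = features.reshape(d, h * w)
    # L2-normalise each patch vector so that dot products become cosines
    flat = flat / np.maximum(np.linalg.norm(flat, axis=0, keepdims=True), 1e-12)
    seed_x, seed_y = seed_xy
    seed = flat[:, seed_y * w + seed_x]          # row-major flat index
    similarity = (seed @ flat).reshape(h, w)     # values in [-1, 1]
    return similarity, similarity >= threshold

# Toy features: two "objects" pointing in orthogonal directions, plus noise
rng = np.random.default_rng(0)
features = np.zeros((8, 4, 6))
features[0, :, :3] = 1.0   # left three columns
features[1, :, 3:] = 1.0   # right three columns
features += 0.05 * rng.standard_normal(features.shape)

similarity, mask = seed_similarity_mask(features, seed_xy=(1, 1), threshold=0.5)
print(mask.astype(int))
```

With a seed in the left region, only the left three columns pass the threshold, because the right-hand vectors are nearly orthogonal to the seed.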


## 1. Environment Setup and Installation

In [None]:

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install Pillow numpy matplotlib ipywidgets
!pip install scikit-image


In [None]:

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.widgets import Button
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
import io
import base64
from skimage.transform import resize
from skimage.filters import gaussian
import subprocess
import os
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 2. Load DINOv3 Model

In [None]:

DINOV3_REPO_DIR = "/content/dinov3"
MODEL_NAME = "dinov3_vitl16"
PATCH_SIZE = 16
FEATURE_DIM = 1024  # ViT-L feature dimension
N_LAYERS = 24       # ViT-L has 24 layers

print("Cloning DINOv3 repository...")
if not os.path.exists(DINOV3_REPO_DIR):
    try:
        subprocess.run(["git", "clone", "https://github.com/facebookresearch/dinov3.git", DINOV3_REPO_DIR], 
                      check=True, capture_output=True, text=True)
        print("Repository cloned successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Failed to clone repository: {e}")
        print("Falling back to GitHub loading (may encounter rate limits)...")
        DINOV3_REPO_DIR = "facebookresearch/dinov3"

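print("Loading DINOv3 ViT-L/16 model...")
# Depending on the repository version, pretrained checkpoints may have to be
# requested first and then passed explicitly via a `weights=` argument
# (see the DINOv3 README).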
model = torch.hub.load(
    repo_or_dir=DINOV3_REPO_DIR,
    model=MODEL_NAME,
    source="local" if os.path.exists(DINOV3_REPO_DIR) else "github"
)
model = model.to(device)
model.eval()
print("Model loaded successfully!")
print(f"Model architecture: {MODEL_NAME}")
print(f"Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"Feature dimension: {FEATURE_DIM}")


## 3. Define Constants and Helper Functions

In [None]:

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

IMAGE_SIZE = 768  # 768 = 16 * 48 patches

def resize_to_patch_aligned(image, target_size=IMAGE_SIZE, patch_size=PATCH_SIZE):
    """
    Resize image to be aligned with patch grid while maintaining aspect ratio.
    """
    if isinstance(image, Image.Image):
        w, h = image.size
        image = TF.to_tensor(image)
    else:
        _, h, w = image.shape
    
    aspect_ratio = w / h
    if aspect_ratio > 1:  # Width > Height
        target_w = target_size
        target_h = int(target_size / aspect_ratio)
    else:  # Height >= Width
        target_h = target_size
        target_w = int(target_size * aspect_ratio)
    
    target_h = (target_h // patch_size) * patch_size
    target_w = (target_w // patch_size) * patch_size
    
    target_h = max(target_h, patch_size)
    target_w = max(target_w, patch_size)
    
    return TF.resize(image, (target_h, target_w))

def normalize_image(image):
    """
    Normalize image with ImageNet statistics.
    """
    return TF.normalize(image, mean=IMAGENET_MEAN, std=IMAGENET_STD)

def pixel_to_patch_coords(pixel_x, pixel_y, patch_size=PATCH_SIZE):
    """
    Convert pixel coordinates to patch coordinates.
    """
    patch_x = pixel_x // patch_size
    patch_y = pixel_y // patch_size
    return int(patch_x), int(patch_y)

def extract_features(model, image_tensor):
    """
    Extract dense patch features from the last DINOv3 block.

    Returns a (D, H/16, W/16) tensor for a normalized (3, H, W) image.
    """
    with torch.no_grad():
        # n=1 keeps only the last block's output (only that layer is used below);
        # reshape=True returns a (B, D, h, w) patch grid and norm=True applies
        # the model's final normalization layer. Computation stays in float32.
        features = model.get_intermediate_layers(
            image_tensor.unsqueeze(0).to(device),
            n=1,
            reshape=True,
            norm=True
        )
    return features[-1].squeeze(0)  # Remove batch dimension

def calculate_cosine_similarity(features, seed_feature):
    """
    Calculate cosine similarity between seed feature and all patch features.
    
    Args:
        features: (D, H, W) - patch features
        seed_feature: (D,) - seed patch feature vector
    
    Returns:
        similarity_map: (H, W) - cosine similarity scores
    """
    D, H, W = features.shape
    
    features_flat = features.reshape(D, -1)  # reshape also handles non-contiguous inputs
    
    features_norm = F.normalize(features_flat, dim=0)  # Normalize each spatial location
    seed_norm = F.normalize(seed_feature.unsqueeze(0), dim=1)  # (1, D)
    
    similarity_flat = torch.mm(seed_norm, features_norm)  # (1, H*W)
    
    similarity_map = similarity_flat.view(H, W)
    
    return similarity_map

print("Helper functions defined successfully!")
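The resize rule and the pixel-to-patch mapping above are plain integer arithmetic, so they can be checked without a model. The dependency-free sketch below mirrors them; `patch_aligned_size` and `pixel_to_patch` are hypothetical names used only for illustration. The longest side is scaled to `IMAGE_SIZE`, both sides are floored to a multiple of 16, and a pixel maps to the patch containing it, clamped to the grid.

```python
def patch_aligned_size(width, height, target_size=768, patch_size=16):
    """Mirror of resize_to_patch_aligned: longest side -> target_size, then
    both sides floored to a multiple of patch_size (at least one patch)."""
    aspect_ratio = width / height
    if aspect_ratio > 1:
        target_w, target_h = target_size, int(target_size / aspect_ratio)
    else:
        target_h, target_w = target_size, int(target_size * aspect_ratio)
    target_h = max((target_h // patch_size) * patch_size, patch_size)
    target_w = max((target_w // patch_size) * patch_size, patch_size)
    return target_w, target_h

def pixel_to_patch(x, y, grid_w, grid_h, patch_size=16):
    """Integer-divide pixel coordinates by the patch size, clamped to the grid."""
    px = min(max(x // patch_size, 0), grid_w - 1)
    py = min(max(y // patch_size, 0), grid_h - 1)
    return px, py

w, h = patch_aligned_size(1024, 768)          # a 4:3 photo
grid_w, grid_h = w // 16, h // 16
print(w, h, grid_w, grid_h)
print(pixel_to_patch(500, 200, grid_w, grid_h))
```

A 1024x768 input becomes 768x576 pixels, i.e. a 48x36 grid of 1,728 patch tokens, each carrying a 1024-dimensional ViT-L feature.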


## 4. Image Upload and Preprocessing

In [None]:

print("Please upload an image file:")
uploaded = files.upload()

if uploaded:
    filename = list(uploaded.keys())[0]
    image_data = uploaded[filename]
    
    original_image = Image.open(io.BytesIO(image_data)).convert('RGB')
    print(f"Original image size: {original_image.size}")
    
    processed_image = resize_to_patch_aligned(original_image)
    normalized_image = normalize_image(processed_image)
    
    _, height, width = processed_image.shape
    patch_h = height // PATCH_SIZE
    patch_w = width // PATCH_SIZE
    
    print(f"Processed image size: {width}x{height}")
    print(f"Patch grid: {patch_w}x{patch_h} patches")
    
    plt.figure(figsize=(10, 8))
    plt.imshow(processed_image.permute(1, 2, 0))
    plt.title(f"Uploaded Image (resized to {width}x{height})")
    plt.axis('off')
    plt.show()
    
else:
    print("No image uploaded. Please run this cell again and upload an image.")
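The cell above keeps two versions of the image: `processed_image` (values in [0, 1], used for display) and `normalized_image` (fed to the model). The normalization is a per-channel `(x - mean) / std` with the ImageNet statistics. A NumPy sketch of that arithmetic and its inverse follows; `normalize_chw` and `denormalize_chw` are hypothetical helpers, not torchvision functions.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize_chw(image):
    """(x - mean) / std per channel for a (3, H, W) array in [0, 1]."""
    return (image - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]

def denormalize_chw(image):
    """Inverse transform, handy for displaying a normalized array."""
    return image * IMAGENET_STD[:, None, None] + IMAGENET_MEAN[:, None, None]

white = np.ones((3, 2, 2))        # a tiny all-white image
normed = normalize_chw(white)
print(normed[:, 0, 0])            # (1 - mean) / std for each channel
```

Because normalized values fall well outside [0, 1], plotting `normalized_image` directly would give distorted colors, which is why the notebook displays `processed_image` instead.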


## 5. Extract DINOv3 Features

In [None]:

if 'normalized_image' in locals():
    print("Extracting DINOv3 features...")
    
    features = extract_features(model, normalized_image)
    
    print(f"Feature shape: {features.shape}")
    print(f"Feature dimension: {features.shape[0]}")
    print(f"Spatial resolution: {features.shape[1]}x{features.shape[2]}")
    
    features_cpu = features.cpu()
    
    print("Feature extraction completed!")
else:
    print("Please upload an image first in the previous cell.")
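The feature tensor is laid out as `(D, rows, cols)`, so a seed vector is read as `features[:, patch_y, patch_x]` (row first). The NumPy sketch below, on random toy features, checks that the vectorised approach used by `calculate_cosine_similarity` (normalise, then one matrix product) matches an explicit per-patch cosine loop. Sizes are illustrative assumptions; ViT-L/16 gives `D = 1024`.

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, W = 16, 3, 5                        # toy sizes
features = rng.standard_normal((D, H, W))

seed_x, seed_y = 2, 1                     # patch column, patch row
seed = features[:, seed_y, seed_x]        # index order is (row, col) = (y, x)

# Vectorised: L2-normalise every patch vector, then one matrix product
flat = features.reshape(D, -1)
flat_n = flat / np.linalg.norm(flat, axis=0, keepdims=True)
seed_n = seed / np.linalg.norm(seed)
sim_vec = (seed_n @ flat_n).reshape(H, W)

# Explicit loop over patches for comparison
sim_loop = np.empty((H, W))
for y in range(H):
    for x in range(W):
        v = features[:, y, x]
        sim_loop[y, x] = v @ seed / (np.linalg.norm(v) * np.linalg.norm(seed))
```

The seed patch always scores 1.0 against itself, which makes it a quick sanity check on any similarity map.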


## 6. Select a Seed Point

Click on the image below to select a seed point. The red circle will show your selected location.


In [None]:

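# Note: mouse clicks only reach `on_click` with an interactive Matplotlib backend.
# With Colab's default static inline backend the figure renders but clicks are ignored.
# One option is the ipympl backend, enabled once before running this cell:
#   !pip install ipympl
#   from google.colab import output
#   output.enable_custom_widget_manager()
#   %matplotlib widget
if 'processed_image' in locals() and 'features_cpu' in locals():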
    seed_coords = None
    seed_feature = None
    
    def on_click(event):
        global seed_coords, seed_feature
        
        if event.inaxes is None:
            return
        
        x, y = int(event.xdata), int(event.ydata)
        
        patch_x, patch_y = pixel_to_patch_coords(x, y)
        
        patch_x = max(0, min(patch_x, patch_w - 1))
        patch_y = max(0, min(patch_y, patch_h - 1))
        
        seed_coords = (patch_x, patch_y)
        
        seed_feature = features_cpu[:, patch_y, patch_x]  # (D,)
        
        ax.clear()
        ax.imshow(processed_image.permute(1, 2, 0))
        
        circle = patches.Circle((x, y), radius=8, color='red', fill=True, alpha=0.8)
        ax.add_patch(circle)
        
        patch_rect = patches.Rectangle(
            (patch_x * PATCH_SIZE, patch_y * PATCH_SIZE),
            PATCH_SIZE, PATCH_SIZE,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax.add_patch(patch_rect)
        
        ax.set_title(f"Seed Point: Pixel ({x}, {y}) → Patch ({patch_x}, {patch_y})")
        ax.axis('off')
        
        plt.draw()
        
        print(f"Selected seed point: Pixel ({x}, {y}) → Patch ({patch_x}, {patch_y})")
        print(f"Seed feature shape: {seed_feature.shape}")
    
    fig, ax = plt.subplots(figsize=(12, 10))
    ax.imshow(processed_image.permute(1, 2, 0))
    ax.set_title("Click on the image to select a seed point")
    ax.axis('off')
    
    cid = fig.canvas.mpl_connect('button_press_event', on_click)
    
    plt.tight_layout()
    plt.show()
    
    print("\nInstructions:")
    print("1. Click anywhere on the image to select a seed point")
    print("2. The red circle shows your click location")
    print("3. The red rectangle shows the corresponding 16x16 patch")
    print("4. After selecting, run the next cell to calculate similarity")
    
else:
    print("Please upload an image and extract features first.")
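If clicks never register (for example, because the static inline backend is active), the seed can be set directly from pixel coordinates on the resized image. The sketch below uses a hypothetical helper, `seed_from_pixel`, that applies the same integer division and clamping as `on_click`; it is demonstrated on a NumPy stand-in for `features_cpu`, but the indexing works the same on a PyTorch tensor.

```python
import numpy as np

PATCH_SIZE = 16

def seed_from_pixel(features, x, y, patch_size=PATCH_SIZE):
    """Hypothetical helper: return ((patch_x, patch_y), seed_vector) for pixel
    (x, y) on the resized image, clamped to the (D, H_patches, W_patches) grid."""
    _, grid_h, grid_w = features.shape
    patch_x = min(max(int(x) // patch_size, 0), grid_w - 1)
    patch_y = min(max(int(y) // patch_size, 0), grid_h - 1)
    return (patch_x, patch_y), features[:, patch_y, patch_x]

# Stand-in for features_cpu from a 768x576 image: 48 x 36 patches, D = 4
features = np.arange(4 * 36 * 48, dtype=float).reshape(4, 36, 48)
coords, vec = seed_from_pixel(features, x=400, y=300)
```

In the notebook this would be used as `seed_coords, seed_feature = seed_from_pixel(features_cpu, 400, 300)` before running the similarity cell.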


## 7. Calculate Cosine Similarity

In [None]:

if 'seed_feature' in locals() and seed_feature is not None:
    print("Calculating cosine similarity between seed patch and all other patches...")
    
    similarity_map = calculate_cosine_similarity(features_cpu, seed_feature)
    
    print(f"Similarity map shape: {similarity_map.shape}")
    print(f"Similarity range: [{similarity_map.min():.3f}, {similarity_map.max():.3f}]")
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    axes[0].imshow(processed_image.permute(1, 2, 0))
    axes[0].set_title("Original Image")
    axes[0].axis('off')
    
    if seed_coords:
        seed_x, seed_y = seed_coords
        circle = patches.Circle(
            (seed_x * PATCH_SIZE + PATCH_SIZE//2, seed_y * PATCH_SIZE + PATCH_SIZE//2),
            radius=8, color='red', fill=True, alpha=0.8
        )
        axes[0].add_patch(circle)
    
    im1 = axes[1].imshow(similarity_map, cmap='viridis', vmin=-1, vmax=1)
    axes[1].set_title(f"Similarity Map ({patch_w}x{patch_h} patches)")
    axes[1].axis('off')
    plt.colorbar(im1, ax=axes[1], shrink=0.8)
    
    similarity_upsampled = torch.nn.functional.interpolate(
        similarity_map.unsqueeze(0).unsqueeze(0),
        size=(height, width),
        mode='bilinear',
        align_corners=False
    ).squeeze()
    
    im2 = axes[2].imshow(similarity_upsampled, cmap='viridis', vmin=-1, vmax=1)
    axes[2].set_title(f"Upsampled Similarity ({width}x{height})")
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], shrink=0.8)
    
    plt.tight_layout()
    plt.show()
    
    print("\nSimilarity calculation completed!")
    print("Proceed to the next cell for interactive threshold adjustment.")
    
else:
    print("Please select a seed point first by clicking on the image in the previous cell.")


## 8. Interactive Threshold Adjustment

Use the slider below to adjust the similarity threshold and see the resulting segmentation mask in real-time.
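The slider trades coverage for selectivity: the fraction of patches with similarity at or above a threshold can only shrink as the threshold rises. A small NumPy sketch on a hand-written similarity map makes this concrete; `coverage_curve` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def coverage_curve(similarity_map, thresholds):
    """Fraction of patches with similarity >= t, for each threshold t."""
    sim = np.asarray(similarity_map).ravel()
    return np.array([(sim >= t).mean() for t in thresholds])

sim = np.array([[0.9, 0.7, 0.2],
                [0.6, 0.1, -0.3]])
thresholds = np.array([-1.0, 0.0, 0.5, 0.8, 1.0])
cov = coverage_curve(sim, thresholds)
print(cov)
```

Since the mask is upsampled with nearest-neighbour interpolation, every patch covers the same number of pixels, so the pixel coverage reported by the slider cell equals this patch-level fraction.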


In [None]:

if 'similarity_map' in locals():
    output = widgets.Output()
    
    def update_segmentation(threshold):
        with output:
            clear_output(wait=True)
            
            binary_mask = (similarity_map >= threshold).float()
            
            mask_upsampled = torch.nn.functional.interpolate(
                binary_mask.unsqueeze(0).unsqueeze(0),
                size=(height, width),
                mode='nearest'
            ).squeeze()
            
            mask_smooth = gaussian(mask_upsampled.numpy(), sigma=1.0)
            
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            
            axes[0, 0].imshow(processed_image.permute(1, 2, 0))
            axes[0, 0].set_title("Original Image")
            axes[0, 0].axis('off')
            
            if seed_coords:
                seed_x, seed_y = seed_coords
                circle = patches.Circle(
                    (seed_x * PATCH_SIZE + PATCH_SIZE//2, seed_y * PATCH_SIZE + PATCH_SIZE//2),
                    radius=8, color='red', fill=True, alpha=0.8
                )
                axes[0, 0].add_patch(circle)
            
            im1 = axes[0, 1].imshow(similarity_upsampled, cmap='viridis', vmin=-1, vmax=1)
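            if seed_coords:
                # Dashed crosshair through the centre of the seed patch
                seed_cx = seed_coords[0] * PATCH_SIZE + PATCH_SIZE // 2
                seed_cy = seed_coords[1] * PATCH_SIZE + PATCH_SIZE // 2
                axes[0, 1].axhline(y=seed_cy, color='white', linestyle='--', alpha=0.5)
                axes[0, 1].axvline(x=seed_cx, color='white', linestyle='--', alpha=0.5)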
            axes[0, 1].set_title(f"Similarity Map (Threshold: {threshold:.2f})")
            axes[0, 1].axis('off')
            
            cbar1 = plt.colorbar(im1, ax=axes[0, 1], shrink=0.8)
            cbar1.ax.axhline(y=threshold, color='red', linewidth=2)
            
            axes[0, 2].imshow(mask_smooth, cmap='gray', vmin=0, vmax=1)
            axes[0, 2].set_title("Segmentation Mask")
            axes[0, 2].axis('off')
            
            axes[1, 0].imshow(processed_image.permute(1, 2, 0))
            axes[1, 0].imshow(similarity_upsampled, cmap='viridis', alpha=0.4, vmin=-1, vmax=1)
            axes[1, 0].set_title("Image + Similarity Overlay")
            axes[1, 0].axis('off')
            
            axes[1, 1].imshow(processed_image.permute(1, 2, 0))
            axes[1, 1].imshow(mask_smooth, cmap='Reds', alpha=0.5, vmin=0, vmax=1)
            axes[1, 1].set_title("Image + Mask Overlay")
            axes[1, 1].axis('off')
            
            image_np = processed_image.permute(1, 2, 0).numpy()
            segmented = image_np * mask_smooth[:, :, np.newaxis]
            axes[1, 2].imshow(segmented)
            axes[1, 2].set_title("Segmented Region")
            axes[1, 2].axis('off')
            
            plt.tight_layout()
            plt.show()
            
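            mask_area = mask_upsampled.sum().item()
            total_area = mask_upsampled.numel()
            coverage = mask_area / total_area * 100
            
            print(f"Threshold: {threshold:.3f}")
            print(f"Mask coverage: {coverage:.1f}% ({int(mask_area)} / {total_area} pixels)")
            # Report the patch-level similarities that were actually thresholded,
            # and guard against an empty mask (.min() on an empty tensor raises)
            selected = similarity_map[binary_mask > 0]
            if selected.numel() > 0:
                print(f"Similarity range in mask: [{selected.min():.3f}, {selected.max():.3f}]")
            else:
                print("No patches reach this threshold; try lowering it.")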
    
    threshold_slider = widgets.FloatSlider(
        value=0.5,
        min=-1.0,
        max=1.0,
        step=0.01,
        description='Threshold:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='500px')
    )
    
    interactive_widget = widgets.interactive(update_segmentation, threshold=threshold_slider)
    
    display(interactive_widget)
    display(output)
    
    print("\nInstructions:")
    print("• Move the threshold slider to adjust the segmentation")
    print("• Higher threshold = more selective (smaller regions)")
    print("• Lower threshold = less selective (larger regions)")
    print("• The red line on the colorbar shows the current threshold")
    
else:
    print("Please complete the similarity calculation first.")
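The mask is computed on the patch grid and then upsampled with `mode='nearest'`, so each patch decision simply fills its 16x16 block of pixels. For an integer scale factor this is equivalent to repeating every value along both axes, as in the NumPy sketch below (`upsample_nearest` is a hypothetical name; a patch size of 4 keeps the example small). The similarity display, by contrast, uses bilinear interpolation, which blends neighbouring patches into a smoother map.

```python
import numpy as np

def upsample_nearest(patch_map, patch_size=16):
    """Repeat each patch value over a patch_size x patch_size pixel block."""
    return np.repeat(np.repeat(patch_map, patch_size, axis=0), patch_size, axis=1)

patch_mask = np.array([[1, 0, 0],
                       [1, 1, 0]])
pixel_mask = upsample_nearest(patch_mask, patch_size=4)
print(pixel_mask.shape)
```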


## 9. Export Results

Save the segmentation results to files.


In [None]:

if 'similarity_map' in locals():
    current_threshold = threshold_slider.value if 'threshold_slider' in locals() else 0.5
    
    binary_mask = (similarity_map >= current_threshold).float()
    mask_upsampled = torch.nn.functional.interpolate(
        binary_mask.unsqueeze(0).unsqueeze(0),
        size=(height, width),
        mode='nearest'
    ).squeeze()
    
    original_pil = TF.to_pil_image(processed_image)
    mask_pil = TF.to_pil_image(mask_upsampled)
    similarity_pil = TF.to_pil_image((similarity_upsampled + 1) / 2)  # Normalize to [0,1]
    
    # Tint only the masked pixels red, matching the "Image + Mask Overlay" panel
    tinted = Image.blend(original_pil,
                         Image.new('RGB', original_pil.size, (255, 0, 0)),
                         0.3)
    overlay = original_pil.copy()
    overlay.paste(tinted, mask=mask_pil.convert('L'))
    
    original_pil.save('dinov3_original.png')
    mask_pil.save('dinov3_mask.png')
    similarity_pil.save('dinov3_similarity.png')
    overlay.save('dinov3_overlay.png')
    
    print(f"Results saved with threshold {current_threshold:.3f}:")
    print("• dinov3_original.png - Original processed image")
    print("• dinov3_mask.png - Binary segmentation mask")
    print("• dinov3_similarity.png - Similarity heatmap")
    print("• dinov3_overlay.png - Image with mask overlay")
    
    files.download('dinov3_original.png')
    files.download('dinov3_mask.png')
    files.download('dinov3_similarity.png')
    files.download('dinov3_overlay.png')
    
else:
    print("No segmentation results to export. Please complete the workflow first.")
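A mask overlay is a per-pixel alpha blend applied only where the mask is set: inside the mask the output is `(1 - alpha) * image + alpha * color`, elsewhere the original pixel is kept. A minimal NumPy sketch of that rule, with `red_overlay` as a hypothetical helper and `alpha = 0.4` chosen arbitrarily:

```python
import numpy as np

def red_overlay(image, mask, alpha=0.4, color=(1.0, 0.0, 0.0)):
    """Blend `color` into the masked pixels only.

    image: (H, W, 3) floats in [0, 1]; mask: (H, W) bool.
    """
    color = np.asarray(color, dtype=image.dtype)
    tinted = (1.0 - alpha) * image + alpha * color
    return np.where(mask[..., None], tinted, image)

img = np.full((2, 2, 3), 0.5)
mask = np.array([[True, False],
                 [False, False]])
out = red_overlay(img, mask)
print(out[0, 0], out[0, 1])
```

To save the result, convert it back to 8-bit first, e.g. `Image.fromarray((out * 255).astype(np.uint8)).save(...)`.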


## Summary

This notebook demonstrated DINOv3-based zero-shot segmentation using patch feature similarity. Key points:

1. **Loaded DINOv3 ViT-L/16** model with 1024-dimensional features
2. **Interactive seed point selection** by clicking on images
3. **Cosine similarity calculation** between patch features
4. **Real-time threshold adjustment** with interactive slider
5. **Multi-view visualization** of results

Advantages of this approach:

- **Zero-shot**: no task-specific training or labels; a single seed click defines the target
- **Interactive**: Real-time feedback and adjustment
- **Semantic**: Leverages DINOv3's rich visual representations
- **Flexible**: Adjustable similarity thresholds

Tips for better results:

- **Click on distinctive regions** with clear visual features
- **Adjust threshold gradually** to find optimal segmentation
- **Try different seed points** for the same object
- **Higher thresholds** for more precise, smaller regions
- **Lower thresholds** for larger, more inclusive regions

Possible extensions:

- Multi-scale feature aggregation
- Multiple seed points
- Post-processing with CRF or GrabCut
- Integration with object detection
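As an illustration of the multiple-seed extension, one simple option (an assumption here, not something the notebook implements) is to compute the cosine similarity to each seed and take the per-patch maximum, so a patch is kept if it resembles any seed. The NumPy sketch below uses a hypothetical `multi_seed_similarity` function and toy one-hot features; averaging instead of taking the maximum would be another reasonable choice.

```python
import numpy as np

def multi_seed_similarity(features, seeds):
    """Per-patch maximum cosine similarity over several seed patches.

    features: (D, H, W) patch features; seeds: list of (patch_x, patch_y).
    """
    d, h, w = features.shape
    flat = features.reshape(d, h * w)
    flat = flat / np.maximum(np.linalg.norm(flat, axis=0, keepdims=True), 1e-12)
    idx = [py * w + px for px, py in seeds]
    seed_vecs = flat[:, idx]                 # (D, K)
    sims = seed_vecs.T @ flat                # (K, H*W): one row per seed
    return sims.max(axis=0).reshape(h, w)

# Three regions with orthogonal features; seeds fall in the first two
features = np.zeros((3, 2, 6))
features[0, :, 0:2] = 1.0
features[1, :, 2:4] = 1.0
features[2, :, 4:6] = 1.0
combined = multi_seed_similarity(features, seeds=[(0, 0), (2, 1)])
print(combined)
```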

Feel free to experiment with different images and thresholds to explore the capabilities of DINOv3 features!
