<a href="https://colab.research.google.com/github/IoT-gamer/segment-anything-dinov3-onnx/blob/main/notebooks/dinov3_one_shot_segmentation_onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# One-shot Segmentation by Feature Matching using DINOv3

- DINOv3 features are semantically rich enough that similar regions in different images can often be matched directly, without any additional training.

## Core Concept
1. **Create a "Fingerprint":** Extract all the DINOv3 patch feature vectors from your single reference image. Using its alpha mask, keep only the features of patches that belong to the foreground object. The average of these features becomes the "fingerprint" (prototype vector) for your target object.

2. **Scan for Matches:** Extract all the patch feature vectors from the test image.

3. **Compare:** Calculate the cosine similarity between the object's prototype vector and every patch feature vector in the test image.

4. **Visualize:** This comparison yields a similarity map in which high values indicate a strong match to your reference object. Thresholding this map gives a coarse, patch-level segmentation of the object.
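The four steps can be sketched end to end in plain NumPy on precomputed features. This is a minimal illustration, not the notebook's code: the 8-dimensional toy vectors, the 2x3 patch grid, and the helper name `one_shot_similarity` are made up for the example, and step 2 (running the feature extractor) is assumed to have already happened.

```python
import numpy as np

def one_shot_similarity(ref_feats, ref_mask, test_feats, grid_hw):
    """Steps 1, 3, 4 on precomputed (N, D) features: masked-mean prototype, cosine, reshape."""
    # Step 1: the prototype is the mean of the reference features inside the mask
    prototype = ref_feats[ref_mask].mean(axis=0)
    # Step 3: cosine similarity = dot product of L2-normalized vectors
    test_unit = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    proto_unit = prototype / np.linalg.norm(prototype)
    scores = test_unit @ proto_unit
    # Step 4: put the per-patch scores back on the (rows, cols) patch grid
    return scores.reshape(grid_hw)

# Toy data: 8-dim stand-ins for DINOv3 features (real ones are e.g. 384-dim for ViT-S/16)
rng = np.random.default_rng(0)
obj_dir, bg_dir = np.eye(8)[0], np.eye(8)[1]
ref_feats = np.vstack([obj_dir + 0.1 * rng.normal(size=(4, 8)),   # 4 object patches
                       bg_dir + 0.1 * rng.normal(size=(2, 8))])   # 2 background patches
ref_mask = np.array([True, True, True, True, False, False])
test_feats = np.vstack([bg_dir, obj_dir, bg_dir, bg_dir, obj_dir, bg_dir])  # 2x3 grid
sim_map = one_shot_similarity(ref_feats, ref_mask, test_feats, (2, 3))
print(np.round(sim_map, 2))  # high scores in the middle column, low elsewhere
```

The notebook itself uses `sklearn.metrics.pairwise.cosine_similarity` for step 3; the explicit normalize-then-dot form above computes the same quantity.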

## Prerequisites
- This notebook uses a DINOv3 feature extractor model exported to ONNX with [dinov3_onnx_export.ipynb](https://github.com/IoT-gamer/segment-anything-dinov3-onnx/blob/main/notebooks/dinov3_onnx_export.ipynb)
- The reference image must be an RGBA .png file with the object's segmentation mask stored in the alpha channel
  - for example, one created with [flutter_segment_anything_app](https://github.com/IoT-gamer/flutter_segment_anything_app)
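If you already have a binary mask from some other tool, a minimal Pillow sketch for packing it into the alpha channel could look like this. The helper name `make_reference_rgba` is hypothetical, and the mask is assumed to be a 2-D array with the same height and width as the photo.

```python
import numpy as np
from PIL import Image

def make_reference_rgba(rgb_img, mask):
    """Return an RGBA copy of rgb_img whose alpha channel is the binary object mask.

    rgb_img: PIL image of the reference photo.
    mask:    2-D array with the same height/width as rgb_img; nonzero means "object".
    """
    alpha = Image.fromarray(np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8))
    if alpha.size != rgb_img.size:
        raise ValueError(f"mask size {alpha.size} does not match image size {rgb_img.size}")
    rgba = rgb_img.convert("RGBA")   # convert() returns a new image; rgb_img is untouched
    rgba.putalpha(alpha)             # replace the alpha band with the mask (255 = object)
    return rgba
```

Usage, with placeholder file names: `make_reference_rgba(Image.open("photo.jpg"), my_mask).save("reference_rgba.png")`. Save as PNG, since JPEG cannot store an alpha channel.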

## Load Model and Reference Image
- Load the ONNX feature extractor session and your single reference RGBA image.

### Dependencies

In [None]:
!pip install onnxruntime

In [None]:
# Uncomment if storing the model and/or images in Google Drive and running in Google Colab
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import onnxruntime as ort
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

### Constants

In [None]:
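IMAGE_SIZE = 768   # target height (in pixels) of the resized input image
PATCH_SIZE = 16    # ViT patch size in pixels (DINOv3 */16 models)
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # per-channel RGB normalization mean
IMAGENET_STD = (0.229, 0.224, 0.225)   # per-channel RGB normalization std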
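Both preprocessing functions below use `IMAGE_SIZE` as a fixed target height and derive the width from the aspect ratio, then force the number of patch columns to be even. This standalone sketch (hypothetical helper `patch_grid`) reproduces only that arithmetic, so you can predict the patch grid, and hence the number of feature vectors, for a given image size.

```python
def patch_grid(width, height, image_size=768, patch_size=16):
    """Reproduce the notebook's resize arithmetic: fixed height, aspect-preserving width."""
    h_patches = image_size // patch_size                      # 768 // 16 = 48 rows
    w_patches = int((width * image_size) / (height * patch_size))
    if w_patches % 2 != 0:                                    # column count is kept even
        w_patches -= 1
    return h_patches, w_patches

# A 1000x667 photo: 1000 * 768 / (667 * 16) = 71.96 -> 71 -> rounded down to even 70
rows, cols = patch_grid(1000, 667)
print(rows, cols, (cols * 16, rows * 16))  # 48 70 (1120, 768)
```

So a 1000x667 image is resized to 1120x768 pixels and yields 48 x 70 = 3360 patch feature vectors.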

### Load the ONNX Feature Extractor

- Export the DINOv3 feature extractor to ONNX using this notebook:
  - [dinov3_onnx_export.ipynb](https://github.com/IoT-gamer/segment-anything-dinov3-onnx/blob/main/notebooks/dinov3_onnx_export.ipynb)

In [None]:
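onnx_feature_extractor_path = "dinov3_feature_extractor.onnx"
# Use the execution providers this onnxruntime build offers (CPU, or CUDA with onnxruntime-gpu)
feature_extractor_session = ort.InferenceSession(
    onnx_feature_extractor_path, providers=ort.get_available_providers()
)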

### Load Reference Image
- The reference image should be an RGBA .png file with the mask stored in the alpha channel

In [None]:
ref_image_path = "/path/to/your/single_reference_rgba.png"
ref_image_rgba = Image.open(ref_image_path)

# Separate image and mask (getchannel raises ValueError if there is no alpha channel)
ref_image_rgb = ref_image_rgba.convert("RGB")
ref_mask = ref_image_rgba.getchannel("A")
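An empty or missing mask only shows up later as a NaN prototype, so it can help to check the alpha channel right after loading. A minimal sketch; the helper name `check_reference_mask` and the `min_fraction` cut-off are illustrative choices, not part of the original notebook. The `> 127` test matches the notebook's `mask / 255 > 0.5` threshold.

```python
import warnings
import numpy as np
from PIL import Image

def check_reference_mask(rgba_img, min_fraction=0.001):
    """Return the fraction of opaque (foreground) pixels, failing early on an unusable mask."""
    if "A" not in rgba_img.getbands():
        raise ValueError(f"expected an image with an alpha channel, got mode {rgba_img.mode!r}")
    alpha = np.asarray(rgba_img.getchannel("A"))
    fg_fraction = float((alpha > 127).mean())
    if fg_fraction < min_fraction:
        raise ValueError("alpha mask is (almost) empty: no foreground to build a prototype from")
    if fg_fraction == 1.0:
        warnings.warn("alpha channel is fully opaque; the whole image will be treated as the object")
    return fg_fraction

# Example with a synthetic 10x10 image whose top half is opaque
demo = Image.new("RGBA", (10, 10), (0, 0, 0, 0))
demo.paste((255, 0, 0, 255), (0, 0, 10, 5))
print(check_reference_mask(demo))  # 0.5
```

In the notebook, call `check_reference_mask(ref_image_rgba)` before building the prototype.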

## Create the Object Prototype
- Process the reference image to create the average feature vector for the foreground object

### Image Preprocessing and Get Features

In [None]:
def preprocess_for_prototyping(img_pil, mask_pil, image_size, patch_size):
    """Resize/normalize the reference image and downsample its mask to the patch grid.

    Returns the (1, 3, H, W) float32 input tensor, a flat boolean foreground mask
    with one entry per patch (row-major), and the (h_patches, w_patches) grid size.
    """
    # Fixed height of image_size pixels; width chosen to preserve the aspect ratio
    w, h = img_pil.size
    h_patches = image_size // patch_size
    w_patches = int((w * image_size) / (h * patch_size))
    if w_patches % 2 != 0:  # keep an even number of patch columns
        w_patches -= 1

    new_h, new_w = h_patches * patch_size, w_patches * patch_size
    resized_img = img_pil.resize((new_w, new_h), Image.Resampling.BICUBIC)

    # Scale to [0, 1], apply ImageNet normalization, and convert HWC -> NCHW
    img_np = np.array(resized_img, dtype=np.float32) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=np.float32)
    std = np.array(IMAGENET_STD, dtype=np.float32)
    normalized_img = (img_np - mean) / std
    input_tensor = normalized_img.transpose(2, 0, 1)[np.newaxis, ...]

    # Downsample the mask to match the patch grid: one label per patch
    resized_mask = mask_pil.resize((w_patches, h_patches), Image.Resampling.NEAREST)
    mask_np = np.array(resized_mask, dtype=np.float32) / 255.0
    patch_mask = (mask_np > 0.5).flatten()  # flatten to a 1-D boolean array

    return input_tensor, patch_mask, (h_patches, w_patches)

# Preprocess and get features
ref_input_tensor, ref_patch_mask, _ = preprocess_for_prototyping(ref_image_rgb, ref_mask, IMAGE_SIZE, PATCH_SIZE)
ref_inputs = {feature_extractor_session.get_inputs()[0].name: ref_input_tensor}
ref_features = feature_extractor_session.run(None, ref_inputs)[0].squeeze()

### Create the prototype

In [None]:
# Select only the features corresponding to the foreground mask
foreground_features = ref_features[ref_patch_mask]
if foreground_features.shape[0] == 0:
    raise ValueError("No foreground patches found; check the alpha mask of the reference image")
# Calculate the mean feature vector (the "fingerprint")
object_prototype = np.mean(foreground_features, axis=0, keepdims=True)

print(f"Created object prototype with shape: {object_prototype.shape}")
# Expected output for dinov3_vits16 (384-dim features): Created object prototype with shape: (1, 384)
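Cosine similarity only compares directions, but `np.mean` over raw features lets patches with larger feature norms pull the prototype toward themselves. An optional variant, not what the notebook does and not claimed to be better, is to L2-normalize each patch feature before averaging so every patch contributes equally. The helper name `normalized_prototype` is hypothetical.

```python
import numpy as np

def normalized_prototype(foreground_features, eps=1e-8):
    """Average L2-normalized patch features so each patch has equal weight in the direction."""
    unit = foreground_features / (np.linalg.norm(foreground_features, axis=1, keepdims=True) + eps)
    proto = unit.mean(axis=0, keepdims=True)        # shape (1, D), like object_prototype
    return proto / (np.linalg.norm(proto) + eps)    # unit length; cosine similarity ignores scale

# Raw mean of [[10, 0], [0, 1]] is [5, 0.5] (dominated by the first row);
# the normalized version points halfway between the two directions.
print(np.round(normalized_prototype(np.array([[10.0, 0.0], [0.0, 1.0]])), 4))
```

It is a drop-in replacement: `object_prototype = normalized_prototype(foreground_features)`.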

## Process a Test Image and Compute Similarity
- Load a new test image, extract its features, and compute the cosine similarity against the prototype

### Load and Preprocess Test Image

In [None]:
test_image_path = "/path/to/your/test_image.jpg"
test_image = Image.open(test_image_path).convert('RGB')

def preprocess_image_numpy(img_pil, image_size, patch_size):
    """Resizes and normalizes an image using NumPy and Pillow.

    Uses the same resize arithmetic as preprocess_for_prototyping, so the test
    image is processed exactly like the reference image.
    """
    w, h = img_pil.size
    h_patches = image_size // patch_size
    w_patches = int((w * image_size) / (h * patch_size))
    if w_patches % 2 != 0:  # keep an even number of patch columns
        w_patches -= 1

    new_h, new_w = h_patches * patch_size, w_patches * patch_size
    resized_img = img_pil.resize((new_w, new_h), Image.Resampling.BICUBIC)

    # Scale to [0, 1], apply ImageNet normalization, and convert HWC -> NCHW
    img_np = np.array(resized_img, dtype=np.float32) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=np.float32)
    std = np.array(IMAGENET_STD, dtype=np.float32)
    normalized_img = (img_np - mean) / std

    input_tensor = normalized_img.transpose(2, 0, 1)[np.newaxis, ...]
    return input_tensor, (h_patches, w_patches), resized_img

test_input_tensor, patch_dims, resized_test_image = preprocess_image_numpy(test_image, IMAGE_SIZE, PATCH_SIZE)
h_patches, w_patches = patch_dims



### Get Test Image Features

In [None]:
test_inputs = {feature_extractor_session.get_inputs()[0].name: test_input_tensor}
test_features = feature_extractor_session.run(None, test_inputs)[0].squeeze()
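The similarity step assumes that `.squeeze()` leaves an `(h_patches * w_patches, D)` matrix of patch tokens. The actual output layout depends on how the model was exported, so here is a defensive sketch (hypothetical helper `to_patch_matrix`) that accepts either a token layout `(1, N, D)` or a spatial layout `(1, D, h, w)`. It also assumes that, if the export happens to include extra tokens such as CLS or register tokens, they come before the patch tokens and can be dropped.

```python
import numpy as np

def to_patch_matrix(features, grid_hw):
    """Return features as an (h * w, D) matrix in row-major patch order.

    Assumed layouts (not guaranteed by the export):
      * token layout   (1, N, D) or (N, D), with any extra tokens first
      * spatial layout (1, D, h, w) or (D, h, w)
    """
    h, w = grid_hw
    feats = np.asarray(features)
    if feats.ndim == 4:                                   # (1, D, h, w) -> (D, h, w)
        feats = feats[0]
    if feats.ndim == 3 and feats.shape[1:] == (h, w):     # spatial (D, h, w) -> (h*w, D)
        return feats.reshape(feats.shape[0], h * w).T
    if feats.ndim == 3:                                   # (1, N, D) -> (N, D)
        feats = feats[0]
    if feats.ndim != 2 or feats.shape[0] < h * w:
        raise ValueError(f"unexpected feature shape {np.shape(features)} for grid {grid_hw}")
    return feats[-h * w:]                                 # keep only the last h*w (patch) tokens
```

Usage: `test_features = to_patch_matrix(feature_extractor_session.run(None, test_inputs)[0], (h_patches, w_patches))`, and likewise for the reference features.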

### Compute Similarity Map

In [None]:
# Compare the object prototype to every patch feature in the test image
similarity_scores = cosine_similarity(test_features, object_prototype)
similarity_map = similarity_scores.reshape(h_patches, w_patches)
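For the Segment Anything idea listed under Further Improvements, the best-matching patches must be turned into pixel coordinates. A minimal sketch, assuming each point is taken at the centre of a patch and rescaled from the resized image back to the original; the helper name `top_patch_points` and the choice of `k` are illustrative. How the points are then passed to the EdgeTAM ONNX model depends on that export and is not shown here.

```python
import numpy as np

def top_patch_points(similarity_map, patch_size, resized_size, original_size, k=1):
    """Convert the k highest-scoring patches to (x, y) pixel coordinates in the original image.

    resized_size and original_size are PIL-style (width, height) tuples.
    """
    rows, cols = similarity_map.shape
    flat_idx = np.argsort(similarity_map, axis=None)[::-1][:k]   # best patches first
    r, c = np.unravel_index(flat_idx, (rows, cols))
    x = (c + 0.5) * patch_size * original_size[0] / resized_size[0]
    y = (r + 0.5) * patch_size * original_size[1] / resized_size[1]
    return np.stack([x, y], axis=1)                              # shape (k, 2)
```

In the notebook: `points = top_patch_points(similarity_map, PATCH_SIZE, resized_test_image.size, test_image.size, k=3)`.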

### Visualize the Result


In [None]:
plt.figure(figsize=(10, 5), dpi=150)
plt.subplot(1, 2, 1)
plt.imshow(resized_test_image)
plt.title('Test Image'); plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(similarity_map, cmap='viridis')
plt.title('Object Similarity Map'); plt.axis('off')
plt.tight_layout()
plt.show()

## Further Improvements
- **Segment Anything:** use the highest-scoring locations in the similarity map, converted to x-y pixel coordinates, as point prompts for a Segment Anything model such as EdgeTAM, exported with [edgetam_onnx_export.ipynb](https://github.com/IoT-gamer/segment-anything-dinov3-onnx/blob/main/notebooks/edgetam_onnx_export.ipynb)
- **Multi-Prototype Matching:** If your object has very distinct parts (e.g., a person with a red hat and a blue shirt), averaging all features into one prototype may dilute the signal. Instead, run an algorithm such as K-Means on `foreground_features` to find 2 or 3 cluster centers (prototypes), then score each test-image patch by its similarity to the closest of these prototypes. This can yield sharper similarity maps.
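The multi-prototype idea can be sketched with scikit-learn's `KMeans`. This is one possible implementation, not the author's: the helper name `multi_prototype_similarity`, `n_prototypes=3`, and `n_init=10` are assumptions. Note that k-means clusters by Euclidean distance on the raw features, while scoring uses cosine similarity; normalizing the features before clustering is another option.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity

def multi_prototype_similarity(foreground_features, test_features, n_prototypes=3, seed=0):
    """Cluster the foreground features and score each test patch against its best prototype."""
    kmeans = KMeans(n_clusters=n_prototypes, n_init=10, random_state=seed)
    kmeans.fit(foreground_features)
    prototypes = kmeans.cluster_centers_                    # (n_prototypes, D)
    scores = cosine_similarity(test_features, prototypes)   # (N_test, n_prototypes)
    return scores.max(axis=1)                               # best-matching prototype per patch
```

It replaces the single-prototype comparison: `similarity_map = multi_prototype_similarity(foreground_features, test_features).reshape(h_patches, w_patches)`.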