<a href="https://colab.research.google.com/github/IoT-gamer/segment-anything-dinov3-onnx/blob/main/notebooks/foreground_segmentation_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a Foreground Segmentation Tool with DINOv3 & Export Model in ONNX Format

- train a foreground segmentation model using DINOv3 features.
- export DINOv3 Feature Extractor to ONNX
- export Logistic Regression Classifier to ONNX

## Acknowledgements/References
- [DINOv3 github repo](https://github.com/facebookresearch/dinov3)
- [foreground_segmentation.ipynb](https://github.com/facebookresearch/dinov3/blob/main/notebooks/foreground_segmentation.ipynb)

## Setup

In [None]:
# Uncomment if storing model and/or images in google drive and using google colab
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
!pip install onnx onnxruntime

In [None]:
from google.colab import userdata

import io
import os
import pickle
import tarfile
import urllib
import glob
from pathlib import Path

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.linear_model import LogisticRegression
import torch
import torchvision.transforms.functional as TF
from tqdm import tqdm

import onnx
import onnxruntime as ort

## Model Loading
- visit https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/ to get weights URLs
  - note: URLs may expire after a few days
- two options:
  1. download the weights and store them locally (recommended)
  2. add a Colab secret named `dinov3_vits16` and copy and paste the dinov3_vits16 URL into its value.

In [None]:
MODEL_DINOV3_VITS = "dinov3_vits16"
MODEL_NAME = MODEL_DINOV3_VITS

DEVICE = "cpu"  # Requested device: "cuda" or "cpu"

# Set the path to your local weights file. If it is `None` (or the file does
# not exist), the weights URL is read from the Colab secret `dinov3_vits16`.
# For example: WEIGHTS_LOCAL_PATH = "/path/to/dinov3_vits16.pt"

# WEIGHTS_LOCAL_PATH = None
WEIGHTS_LOCAL_PATH = "/path/to/dinov3_vits16_pretrain_lvd1689m-08c60483.pth"

if WEIGHTS_LOCAL_PATH and os.path.exists(WEIGHTS_LOCAL_PATH):
    print(f"Loading weights from local path: {WEIGHTS_LOCAL_PATH}")
    weights_source = WEIGHTS_LOCAL_PATH
else:
    print("Loading weights from remote URL.")
    # Add a secret named `dinov3_vits16` in Colab with the URL as the value
    try:
        WEIGHTS_URL = userdata.get('dinov3_vits16')
        weights_source = WEIGHTS_URL
    except Exception as e:
        print(f"Could not retrieve weights URL from Colab secrets. Please set WEIGHTS_LOCAL_PATH. Error: {e}")
        weights_source = None # Will cause an error if not set

# Load the model
model = torch.hub.load(
    repo_or_dir="facebookresearch/dinov3",
    model=MODEL_NAME,
    source="github",
    weights=weights_source
)

# Always bind `device`, whichever value DEVICE has: later cells read
# `device.type`, so it must be a torch.device even when running on CPU.
use_cuda = DEVICE == "cuda" and torch.cuda.is_available()
if DEVICE == "cuda" and not use_cuda:
    print("CUDA was requested but is not available; falling back to CPU.")
device = torch.device("cuda" if use_cuda else "cpu")

model.to(device)
model.eval()


## Data Loading
- RGBA .png images with masks stored in alpha channels

In [None]:
def load_rgba_data(data_dir: str) -> tuple[list[Image.Image], list[Image.Image]]:
    """
    Loads all RGBA .png images from a local directory.

    Files are processed in sorted path order so that the index of each image
    in the returned lists is the same on every run. Images that are not in
    RGBA mode are reported and skipped, so the returned lists may be shorter
    than the number of .png files found (and may be empty).

    Args:
        data_dir: Path to the directory containing .png files.

    Returns:
        A tuple containing two lists of equal length:
        - A list of RGB images (PIL.Image.Image).
        - A list of Alpha channel masks (PIL.Image.Image).

    Raises:
        FileNotFoundError: If the directory contains no .png files.
    """
    # glob returns matches in arbitrary order; sort for reproducible indices
    image_paths = sorted(glob.glob(os.path.join(data_dir, "*.png")))
    if not image_paths:
        raise FileNotFoundError(f"No .png files found in the directory: {data_dir}")

    print(f"Found {len(image_paths)} .png files in {data_dir}")

    images = []
    masks = []
    for image_path in tqdm(image_paths, desc="Loading local images"):
        with Image.open(image_path) as img:
            if img.mode != 'RGBA':
                print(f"Warning: Image {Path(image_path).name} is not in RGBA mode. Skipping.")
                continue

            # Ensure data is loaded into memory before closing the file
            img.load()

            # Both calls return new images that stay usable after the file
            # is closed: the colour channels and the alpha channel.
            images.append(img.convert("RGB"))
            masks.append(img.getchannel("A"))

    return images, masks

## Data Source

In [None]:
# Folder that holds the RGBA training images (mask stored in the alpha channel).
LOCAL_DATA_PATH = "/path/to/your/rgba_png_folder"

# `images` are the RGB pictures, `labels` the matching alpha-channel masks.
images, labels = load_rgba_data(LOCAL_DATA_PATH)
n_images = len(images)

print(f"Loaded {n_images} images and labels.")

## Visualize an Example

In [None]:
if n_images > 0:
    data_index = 0
    print(f"Showing image / mask at index {data_index}:")
    image = images[data_index]
    mask = labels[data_index]

    # Composite the RGB image over a black RGB canvas using the alpha mask.
    # Using the single-channel mask itself as the second image would produce
    # a single-channel result, which matplotlib renders in false colour.
    background = Image.new("RGB", image.size)
    foreground = Image.composite(image, background, mask)

    # (picture, title, colormap) for each of the three side-by-side panels
    panels = [
        (image, "Image", None),
        (mask, "Mask (Alpha Channel)", 'gray'),
        (foreground, "Foreground", None),
    ]

    plt.figure(figsize=(12, 4), dpi=150)
    for position, (picture, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.show()

## Building Per-Patch Label Map
Since our models run with a patch size of 16, we have to quantize the ground truth to a 16x16 pixels grid. To achieve this, we define:
- the resize transform to resize an image such that it aligns well with the 16x16 grid;
- a uniform 16x16 conv layer as a [box blur filter](https://en.wikipedia.org/wiki/Box_blur) with stride equal to the patch size.

In [None]:
# Side length (pixels) of one model patch, and the target image height.
PATCH_SIZE = 16
IMAGE_SIZE = 768

# Box-blur filter: a bias-free single-channel convolution whose kernel and
# stride both equal the patch size, with every weight set to 1 / (patch area).
# Each output value is therefore the mean of one non-overlapping patch.
patch_quant_filter = torch.nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=PATCH_SIZE,
    stride=PATCH_SIZE,
    bias=False,
)
torch.nn.init.constant_(patch_quant_filter.weight, 1.0 / (PATCH_SIZE * PATCH_SIZE))

def resize_transform(
    img_or_mask: Image.Image,
    image_size: int = IMAGE_SIZE,
    patch_size: int = PATCH_SIZE,
) -> torch.Tensor:
    """
    Resizes an image or mask so both sides are multiples of the patch size.

    The target height is `image_size` rounded down to a multiple of
    `patch_size`; the target width follows the input aspect ratio and is
    likewise rounded down to a whole number of patches.

    Args:
        img_or_mask: PIL image (or mask) to resize.
        image_size: Desired height in pixels before rounding to the grid.
        patch_size: Side length in pixels of one patch.

    Returns:
        The resized input converted with `TF.to_tensor`.
    """
    w, h = img_or_mask.size
    # Whole number of patches along each side (integer floor division)
    h_patches = image_size // patch_size
    w_patches = (w * image_size) // (h * patch_size)
    # TF.resize expects the size as (height, width)
    target_size = (h_patches * patch_size, w_patches * patch_size)
    return TF.to_tensor(TF.resize(img_or_mask, target_size))

## Extracting Features and Labels for All Images

In [None]:
# Per-patch features, per-patch labels, and the source-image index of each patch
xs = []
ys = []
image_index = []

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

MODEL_TO_NUM_LAYERS = {MODEL_DINOV3_VITS: 12}
n_layers = MODEL_TO_NUM_LAYERS[MODEL_NAME]

with torch.inference_mode():
    with torch.autocast(device_type=device.type, dtype=torch.float32):
        for i, (image_i, mask_i) in enumerate(
            tqdm(zip(images, labels), total=n_images, desc="Processing images")
        ):
            # Masks are expected to be single-channel already; fall back to
            # the alpha channel if a full RGBA image was supplied instead.
            if mask_i.mode == 'RGBA':
                mask_i = mask_i.split()[-1]

            # Average the mask over each patch -> one soft label per patch
            mask_i_resized = resize_transform(mask_i)
            mask_i_quantized = patch_quant_filter(mask_i_resized).squeeze().view(-1).detach().cpu()
            ys.append(mask_i_quantized)

            # Loading the image data (already in RGB)
            image_i_resized = resize_transform(image_i)
            image_i_normalized = TF.normalize(image_i_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
            # Add the batch axis and move the input to the model's device
            image_i_tensor = image_i_normalized.unsqueeze(0).to(device)

            feats = model.get_intermediate_layers(image_i_tensor, n=range(n_layers), reshape=True, norm=True)
            # Only the last returned feature map is used; flatten its spatial
            # axes so that each row is the feature vector of one patch.
            dim = feats[-1].shape[1]
            xs.append(feats[-1].squeeze().view(dim, -1).permute(1,0).detach().cpu())

            # Remember which image every patch of this batch came from
            image_index.append(torch.full_like(ys[-1], float(i)))


# Concatenate all lists into torch tensors
xs = torch.cat(xs)
ys = torch.cat(ys)
image_index = torch.cat(image_index)

# Keep only patches with clear positive or negative labels
idx = (ys < 0.01) | (ys > 0.99)
xs = xs[idx]
# Snap the surviving soft labels to exactly 0.0 / 1.0. Patches with a label in
# (0, 0.01) are kept above as clear negatives, but a downstream `ys > 0` test
# would otherwise classify them as foreground.
ys = (ys[idx] > 0.5).to(ys.dtype)
image_index = image_index[idx]

print("Design matrix size: ", xs.shape)
print("Label matrix size: ", ys.shape)

## Training a Classifier and Model Selection
- We will do leave-one-out, and consecutively exclude each image as a validation set. We'll try 8 values of C ranging from 1e-7 to 1e-0.

In [None]:
# Candidate inverse-regularization strengths: 1e-7, 1e-6, ..., 1e0
cs = np.logspace(-7, 0, 8)
# scores[i, j]: average precision on held-out image i using cs[j]
scores = np.zeros((n_images, len(cs)))

for i in range(n_images):
    print(f'Validation using image {i+1}/{n_images}')

    # Patches of image i form the validation fold; all others are for training
    held_out = image_index == float(i)
    in_training = ~held_out

    train_x = xs[in_training].numpy()
    train_y = (ys[in_training] > 0).long().numpy()
    held_out_x = xs[held_out].numpy()
    held_out_y = (ys[held_out] > 0).long().numpy()

    for j, c in enumerate(cs):
        clf = LogisticRegression(random_state=0, C=c, max_iter=10000, n_jobs=-1)
        clf.fit(train_x, train_y)
        # Column 1 of predict_proba is the probability of the positive class
        foreground_proba = clf.predict_proba(held_out_x)[:, 1]
        scores[i, j] = average_precision_score(held_out_y, foreground_proba)

### Choosing the Best C

In [None]:
# Average each C's score over all held-out images and keep the best one
mean_score_per_c = scores.mean(axis=0)
best_c_index = mean_score_per_c.argmax()
best_c = cs[best_c_index]
print(f"\nBest C value found: {best_c:.2e}")

### Retraining with the optimal regularization

In [None]:
print("Retraining classifier with the optimal C on all data...")
# Fit on every retained patch, using the C selected by cross-validation
all_features = xs.numpy()
all_targets = (ys > 0).long().numpy()
clf = LogisticRegression(
    random_state=0,
    C=best_c,
    max_iter=100000,
    verbose=0,
    n_jobs=-1,
)
clf.fit(all_features, all_targets)
print("Final classifier trained.")

### Saving the Trained Classifier (scikit-learn, pickled)

In [None]:
# Directory that receives the pickled classifier
save_root = '.'
model_path = os.path.join(save_root, "fg_classifier.pkl")

# Serialize the fitted classifier to disk with pickle
with open(model_path, "wb") as f:
    pickle.dump(clf, f)

print(f"Classifier saved to {model_path}")

# ONNX Model Export

## Class Wrappers

In [None]:
class DinoV3FeatureExtractor(torch.nn.Module):
    """Wraps a backbone so that `forward` returns per-patch feature vectors.

    The wrapped model must provide `get_intermediate_layers(x, n=...,
    reshape=..., norm=...)`; the last feature map it returns is flattened
    over its two spatial axes.
    """

    def __init__(self, model, n_layers):
        """Store the backbone and the number of layers to request from it."""
        super().__init__()
        self.model = model
        self.n_layers = n_layers

    def forward(self, x):
        """Return the last feature map as a (B, H*W, C) tensor."""
        feature_maps = self.model.get_intermediate_layers(
            x,
            n=range(self.n_layers),
            reshape=True,
            norm=True,
        )
        final_map = feature_maps[-1]
        # The final map is unpacked as (batch, channels, height, width)
        batch, channels, _, _ = final_map.shape
        # Merge height and width into one axis, then put channels last
        flattened = final_map.view(batch, channels, -1)
        return flattened.transpose(1, 2)

class LogisticRegressionONNX(torch.nn.Module):
    """Torch re-implementation of a fitted *binary* logistic regression.

    The weights are copied from the `coef_` and `intercept_` attributes of
    the given classifier so that the model can be exported with
    `torch.onnx.export`.
    """

    def __init__(self, sklearn_classifier):
        """
        Args:
            sklearn_classifier: Fitted classifier exposing `coef_` with shape
                (1, n_features) and `intercept_` with a single element.

        Raises:
            ValueError: If the classifier is not binary. `forward` applies an
                independent sigmoid, which is only valid for a single logit.
        """
        super().__init__()
        coef = torch.from_numpy(sklearn_classifier.coef_).float()
        intercept = torch.from_numpy(sklearn_classifier.intercept_).float()
        if coef.dim() != 2 or coef.shape[0] != 1 or intercept.numel() != 1:
            raise ValueError(
                "LogisticRegressionONNX only supports binary classifiers: expected "
                f"coef_ of shape (1, n_features) and one intercept, got coef_ "
                f"{tuple(coef.shape)} and intercept_ {tuple(intercept.shape)}."
            )
        self.coeffs = torch.nn.Parameter(coef)
        self.intercept = torch.nn.Parameter(intercept)

    def forward(self, x):
        """Return positive-class probabilities for features `x` (..., n_features).

        The trailing singleton logit axis is squeezed away, so the output has
        the shape of `x` without its last axis.
        """
        linear_output = torch.matmul(x, self.coeffs.T) + self.intercept
        probabilities = torch.sigmoid(linear_output)
        return probabilities.squeeze(-1)


## Export DINOv3 Feature Extractor

In [None]:
onnx_feature_extractor_path = "dinov3_feature_extractor.onnx"
# Export on CPU. `device` stays a torch.device (not a plain string) because
# the feature-extraction cell reads `device.type` and would fail on re-run.
device = torch.device("cpu")
# NOTE: `.to(device)` also moves the wrapped `model` itself to the CPU.
onnx_exportable_dino = DinoV3FeatureExtractor(model, n_layers).to(device).eval()
# Example input used for the export; the spatial axes are declared dynamic below
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)

print(f"Exporting DINOv3 feature extractor to {onnx_feature_extractor_path}...")
torch.onnx.export(
    onnx_exportable_dino, dummy_input, onnx_feature_extractor_path,
    input_names=['input_image'], output_names=['patch_features'],
    # Image height/width and the resulting number of patches may vary at runtime
    dynamic_axes={'input_image': {2: 'height', 3: 'width'}, 'patch_features': {1: 'num_patches'}},
    opset_version=17
)
print("DINOv3 ONNX export complete!")

## Export Classifier

In [None]:
onnx_classifier_path = "fg_classifier.onnx"
onnx_exportable_clf = LogisticRegressionONNX(clf).eval()

# Example input: one image worth of patch features for a square
# IMAGE_SIZE x IMAGE_SIZE input, with as many channels as the design matrix.
num_channels = xs.shape[1]
patches_per_side = IMAGE_SIZE // PATCH_SIZE
dummy_features = torch.randn(1, patches_per_side**2, num_channels)

# The patch axis is dynamic on both the input and the output
classifier_dynamic_axes = {
    'patch_features': {1: 'num_patches'},
    'probabilities': {1: 'num_patches'},
}

print(f"\nExporting classifier to {onnx_classifier_path}...")
torch.onnx.export(
    onnx_exportable_clf,
    dummy_features,
    onnx_classifier_path,
    input_names=['patch_features'],
    output_names=['probabilities'],
    dynamic_axes=classifier_dynamic_axes,
    opset_version=17,
)
print("Classifier ONNX export complete!")

## PyTorch-Independent ONNX Inference

In [None]:
def preprocess_image_numpy(img_pil: Image.Image, image_size: int, patch_size: int):
    """
    Resizes and normalizes an image using NumPy and Pillow.

    The image is resized (bicubic) to a height of `image_size` rounded down
    to a multiple of `patch_size`, and a width that follows the aspect ratio,
    rounded down to an even number of patches. Pixel values are scaled to
    [0, 1] and normalized with IMAGENET_MEAN / IMAGENET_STD.

    Args:
        img_pil: Input image. Converted to RGB first if it has another mode.
        image_size: Desired height in pixels before rounding to the grid.
        patch_size: Side length in pixels of one patch.

    Returns:
        A tuple of:
        - float32 array of shape (1, 3, H, W) holding the normalized image,
        - (h_patches, w_patches), the patch grid size,
        - the resized PIL image (before normalization).
    """
    # The per-channel normalization below requires exactly three channels
    if img_pil.mode != "RGB":
        img_pil = img_pil.convert("RGB")

    w, h = img_pil.size
    h_patches = image_size // patch_size
    w_patches = int((w * image_size) / (h * patch_size))
    # Round the number of patch columns down to an even count
    if w_patches % 2 != 0:
        w_patches -= 1

    new_h, new_w = h_patches * patch_size, w_patches * patch_size
    # PIL's resize takes the size as (width, height)
    resized_img = img_pil.resize((new_w, new_h), Image.Resampling.BICUBIC)

    img_np = np.array(resized_img, dtype=np.float32) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=np.float32)
    std = np.array(IMAGENET_STD, dtype=np.float32)
    # mean/std broadcast over the trailing channel axis of the (H, W, 3) array
    normalized_img = (img_np - mean) / std

    # (H, W, C) -> (C, H, W), then add a leading batch axis
    input_tensor = normalized_img.transpose(2, 0, 1)[np.newaxis, :, :]
    return input_tensor, (h_patches, w_patches), resized_img

## Load Test Image

In [None]:
test_image_fpath = "/path/to/your/test_rgba.png"
# Open inside a context manager so the file is closed promptly; convert()
# returns a new RGB image that remains usable after the file is closed.
with Image.open(test_image_fpath) as test_image_file:
    test_image = test_image_file.convert('RGB')

## Preprocess and Run Inference

In [None]:
# Preprocess the test image; also keep the patch grid size and resized picture
input_tensor, (h_patches, w_patches), resized_image_pil = preprocess_image_numpy(
    test_image, IMAGE_SIZE, PATCH_SIZE
)

# One ONNX Runtime session per exported model
feature_extractor_session = ort.InferenceSession(onnx_feature_extractor_path)
classifier_session = ort.InferenceSession(onnx_classifier_path)

# Stage 1: image -> patch features (first output of the feature extractor)
feature_input_name = feature_extractor_session.get_inputs()[0].name
feature_outputs = feature_extractor_session.run(None, {feature_input_name: input_tensor})
patch_features = feature_outputs[0]

# Stage 2: patch features -> foreground probabilities (first classifier output)
classifier_input_name = classifier_session.get_inputs()[0].name
classifier_outputs = classifier_session.run(None, {classifier_input_name: patch_features})
fg_probabilities = classifier_outputs[0]

## Post-process and Visualize

In [None]:
# Arrange the per-patch probabilities on the patch grid, then smooth the map
# with a 3x3 median filter.
fg_score = fg_probabilities.reshape(h_patches, w_patches)
fg_score_mf = signal.medfilt2d(fg_score, kernel_size=3)

print("Displaying final segmentation results...")

# (picture, title, extra imshow keyword arguments) for each panel
result_panels = (
    (resized_image_pil, 'Input Image', {}),
    (fg_score, 'Foreground Score', {'cmap': 'viridis'}),
    (fg_score_mf, '+ Median Filter', {'cmap': 'viridis'}),
)

plt.figure(figsize=(12, 4), dpi=150)
for column, (picture, title, imshow_kwargs) in enumerate(result_panels, start=1):
    plt.subplot(1, 3, column)
    plt.imshow(picture, **imshow_kwargs)
    plt.title(title)
    plt.axis('off')
plt.tight_layout()
plt.show()