<a href="https://colab.research.google.com/github/IoT-gamer/sam2-dinov3-onnx/blob/main/notebooks/foreground_segmentation_onnx_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a Foreground Segmentation Tool with DINOv3
## Export Model in ONNX Format

- Train a linear foreground segmentation model on DINOv3 patch features.
- Export the DINOv3 feature extractor to ONNX.
- Export the logistic regression classifier to ONNX.

## Acknowledgements/References
- [DINOv3 GitHub repo](https://github.com/facebookresearch/dinov3)
- [foreground_segmentation.ipynb](https://github.com/facebookresearch/dinov3/blob/main/notebooks/foreground_segmentation.ipynb)
  - Model training is taken from this notebook
  - The main addition is the ONNX export section

### Setup

In [None]:
from google.colab import userdata

import io
import os
import pickle
import tarfile
import urllib.request

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.linear_model import LogisticRegression
import torch
import torchvision.transforms.functional as TF
from tqdm import tqdm

### Model

- Visit https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/ to get the weights URLs.
- Add a Colab secret named `dinov3_vits16` and paste the `dinov3_vits16` URL as its value.
- Note: the URLs may expire after a few days.
  - Either save the weights locally or re-register to get new links.

In [None]:
MODEL_DINOV3_VITS = "dinov3_vits16"
MODEL_NAME = MODEL_DINOV3_VITS

WEIGHTS_URL = userdata.get('dinov3_vits16') # URL is stored in colab secrets

model = torch.hub.load(
    repo_or_dir="facebookresearch/dinov3",
    model=MODEL_NAME,
    source="github",
    weights=WEIGHTS_URL
)
model.cuda()

### Data
Now that we have the model set up, let's load the training data. It consists of:

- images in `jpg` format:
```
https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/foreground_segmentation_images.tar.gz
```

- and segmentation masks stored as alpha channels in `png` files:
```
https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/foreground_segmentation_labels.tar.gz
```

In total, there are 9 training image / mask pairs.
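Because each mask lives in the alpha channel of an RGBA `png`, the notebook later reads it with `labels[i].split()[-1]`. A minimal sketch of that idiom on a synthetic 4x4 RGBA image; the array and the 127 binarization threshold are illustrative assumptions, not properties of the dataset:

```python
import numpy as np
from PIL import Image

# Synthetic 4x4 RGBA "label": RGB stays zero, the alpha band holds the mask.
toy_rgba = np.zeros((4, 4, 4), dtype=np.uint8)
toy_rgba[1:3, 1:3, 3] = 255  # 2x2 foreground square in the alpha channel
toy_label = Image.fromarray(toy_rgba)  # uint8 array of shape (H, W, 4) -> mode "RGBA"

# For an RGBA image, split() returns the (R, G, B, A) bands; the last one is alpha.
alpha = toy_label.split()[-1]
toy_mask = np.array(alpha) > 127  # assumed threshold for a boolean foreground mask

print(toy_label.mode, alpha.mode, int(toy_mask.sum()))  # RGBA L 4
```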


In [None]:
IMAGES_URI = "https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/foreground_segmentation_images.tar.gz"
LABELS_URI = "https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/foreground_segmentation_labels.tar.gz"

def load_images_from_remote_tar(tar_uri: str) -> list[Image.Image]:
    images = []
    with urllib.request.urlopen(tar_uri) as f:
        tar = tarfile.open(fileobj=io.BytesIO(f.read()))
        for member in tar.getmembers():
            image_data = tar.extractfile(member)
            image = Image.open(image_data)
            images.append(image)
    return images

images = load_images_from_remote_tar(IMAGES_URI)
labels = load_images_from_remote_tar(LABELS_URI)
n_images = len(images)
assert n_images == len(labels), f"{len(images)=}, {len(labels)=}"

print(f"Loaded {n_images} images and labels")

As an example, let's visualize the first image / mask pair:

In [None]:
data_index = 0

print(f"Showing image / mask at index {data_index}:")

image = images[data_index]
mask = labels[data_index]
foreground = Image.composite(image, mask, mask)
mask_bg_np = np.copy(np.array(mask))
mask_bg_np[:, :, 3] = 255 - mask_bg_np[:, :, 3]
mask_bg = Image.fromarray(mask_bg_np)
background = Image.composite(image, mask_bg, mask_bg)

data_to_show = [image, mask, foreground, background]
data_labels = ["Image", "Mask", "Foreground", "Background"]

plt.figure(figsize=(16, 4), dpi=300)
for i in range(len(data_to_show)):
    plt.subplot(1, len(data_to_show), i + 1)
    plt.imshow(data_to_show[i])
    plt.axis('off')
    plt.title(data_labels[i], fontsize=12)
plt.show()

### Building Per-Patch Label Map

Since our models run with a patch size of 16, we have to quantize the ground truth to a 16x16-pixel grid. To achieve this, we define:
- the resize transform to resize an image such that it aligns well with the 16x16 grid;
- a uniform 16x16 conv layer as a [box blur filter](https://en.wikipedia.org/wiki/Box_blur) with stride equal to the patch size.
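A stride-16 convolution with a uniform 16x16 kernel of weight 1/256 simply averages each non-overlapping patch, so each quantized value is the mean mask value of that patch (for a binary mask, the fraction of foreground pixels). A NumPy sketch of the same computation, using a smaller patch size of 4 and a random toy mask for readability (both are assumptions for illustration):

```python
import numpy as np

toy_patch = 4  # small patch size for illustration; the notebook uses PATCH_SIZE = 16
rng = np.random.default_rng(0)
# Hypothetical binary mask whose height and width are multiples of the patch size.
toy_mask = (rng.random((8, 12)) > 0.5).astype(np.float32)

def patch_means(mask: np.ndarray, patch: int) -> np.ndarray:
    """Mean of each non-overlapping patch x patch block (box blur with stride = patch)."""
    h, w = mask.shape
    blocks = mask.reshape(h // patch, patch, w // patch, patch)
    return blocks.mean(axis=(1, 3))

def patch_means_loop(mask: np.ndarray, patch: int) -> np.ndarray:
    """Reference: slide a uniform kernel of weight 1/patch**2 with stride `patch`."""
    h, w = mask.shape
    kernel = np.full((patch, patch), 1.0 / (patch * patch), dtype=np.float32)
    out = np.zeros((h // patch, w // patch), dtype=np.float32)
    for i in range(h // patch):
        for j in range(w // patch):
            window = mask[i * patch:(i + 1) * patch, j * patch:(j + 1) * patch]
            out[i, j] = (window * kernel).sum()
    return out

quantized = patch_means(toy_mask, toy_patch)
print(quantized.shape)  # (2, 3): one value per patch
assert np.allclose(quantized, patch_means_loop(toy_mask, toy_patch))
```

The reshape-and-mean form needs no convolution layer, which can be handy when the same quantization is needed outside PyTorch.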

In [None]:
PATCH_SIZE = 16
IMAGE_SIZE = 768

# quantization filter for the given patch size
patch_quant_filter = torch.nn.Conv2d(1, 1, PATCH_SIZE, stride=PATCH_SIZE, bias=False)
patch_quant_filter.weight.data.fill_(1.0 / (PATCH_SIZE * PATCH_SIZE))

# image resize transform to dimensions divisible by patch size
def resize_transform(
    mask_image: Image.Image,
    image_size: int = IMAGE_SIZE,
    patch_size: int = PATCH_SIZE,
) -> torch.Tensor:
    w, h = mask_image.size
    h_patches = int(image_size / patch_size)
    w_patches = int((w * image_size) / (h * patch_size))
    return TF.to_tensor(TF.resize(mask_image, (h_patches * patch_size, w_patches * patch_size)))

As an example, let's visualize the first mask before and after quantization:

In [None]:
mask_0 = labels[0].split()[-1]
mask_0_resized = resize_transform(mask_0)
with torch.no_grad():
    mask_0_quantized = patch_quant_filter(mask_0_resized).squeeze().detach().cpu()

plt.figure(figsize=(4, 2), dpi=300)
plt.subplot(1, 2, 1)
plt.imshow(mask_0)
plt.axis('off')
plt.title(f"Original Mask, Size {mask_0.size}", fontsize=5)
plt.subplot(1, 2, 2)
plt.imshow(mask_0_quantized)
plt.axis('off')
plt.title(f"Quantized Mask, Size {tuple(mask_0_quantized.shape)}", fontsize=5)
plt.show()

### Extracting Features and Labels for All the Images
Now we will loop over the 9 training images and extract, for each image, the patch labels as well as the patch features. This involves running the dense feature extraction of our model with:

```
with torch.no_grad():        
    feats = model.get_intermediate_layers(img, n=range(n_layers), reshape=True, norm=True)
    dim = feats[-1].shape[1]
    xs.append(feats[-1].squeeze().view(dim, -1).permute(1,0).detach().cpu())
```
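With `reshape=True`, each entry returned by `get_intermediate_layers` is a `(1, C, H, W)` map over the patch grid, and the `view(dim, -1).permute(1, 0)` step turns it into one `C`-dimensional row per patch, in the same row-major order as the flattened per-patch labels. A NumPy sketch of that bookkeeping with small made-up sizes:

```python
import numpy as np

# Made-up sizes: a (1, C, H, W) last-layer map, as returned with reshape=True.
n_ch, n_h, n_w = 3, 2, 4
feat_map = np.arange(n_ch * n_h * n_w, dtype=np.float32).reshape(1, n_ch, n_h, n_w)
label_map = np.arange(n_h * n_w).reshape(n_h, n_w)  # stand-in for the quantized mask

# NumPy equivalent of feats[-1].squeeze().view(dim, -1).permute(1, 0).
# Indexing the batch axis (instead of squeeze()) avoids dropping a grid axis of size 1.
patch_feats = feat_map[0].reshape(n_ch, -1).T  # (H*W, C): one row per patch
patch_labels = label_map.reshape(-1)           # (H*W,): same row-major patch order

# Row k of the design matrix and label k both refer to patch (k // W, k % W).
k = 5
row, col = divmod(k, n_w)
assert np.array_equal(patch_feats[k], feat_map[0, :, row, col])
assert patch_labels[k] == label_map[row, col]
print(patch_feats.shape)  # (8, 3)
```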

In [None]:
xs = []
ys = []
image_index = []

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

MODEL_TO_NUM_LAYERS = {
    MODEL_DINOV3_VITS: 12,
    # MODEL_DINOV3_VITSP: 12,
    # MODEL_DINOV3_VITB: 12,
    # MODEL_DINOV3_VITL: 24,
    # MODEL_DINOV3_VITHP: 32,
    # MODEL_DINOV3_VIT7B: 40,
}

n_layers = MODEL_TO_NUM_LAYERS[MODEL_NAME]

with torch.inference_mode():
    with torch.autocast(device_type='cuda', dtype=torch.float32):
        for i in tqdm(range(n_images), desc="Processing images"):
            # Loading the ground truth
            mask_i = labels[i].split()[-1]
            mask_i_resized = resize_transform(mask_i)
            mask_i_quantized = patch_quant_filter(mask_i_resized).squeeze().view(-1).detach().cpu()
            ys.append(mask_i_quantized)
            # Loading the image data
            image_i = images[i].convert('RGB')
            image_i_resized = resize_transform(image_i)
            image_i_resized = TF.normalize(image_i_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
            image_i_resized = image_i_resized.unsqueeze(0).cuda()

            feats = model.get_intermediate_layers(image_i_resized, n=range(n_layers), reshape=True, norm=True)
            dim = feats[-1].shape[1]
            xs.append(feats[-1].squeeze().view(dim, -1).permute(1,0).detach().cpu())

            image_index.append(i * torch.ones(ys[-1].shape))


# Concatenate all lists into torch tensors
xs = torch.cat(xs)
ys = torch.cat(ys)
image_index = torch.cat(image_index)

# keeping only the patches that have clear positive or negative label
idx = (ys < 0.01) | (ys > 0.99)
xs = xs[idx]
ys = ys[idx]
image_index = image_index[idx]

print("Design matrix of size : ", xs.shape)
print("Label matrix of size : ", ys.shape)

### Training a Classifier and Model Selection
We have computed the features; let's now train a classifier! Our data is strongly correlated within each image, so to do proper model selection we can't simply split the patches in an IID way. Instead, we use leave-one-image-out cross-validation, holding out each image in turn as the validation set.
We'll try 8 values of C ranging from 1e-7 to 1e0.

For each value of C and each held-out image, we plot the precision-recall curve of the classifier and report the average precision (AP, the area under the PR curve).
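The leave-one-image-out loop can also be expressed with scikit-learn's `LeaveOneGroupOut` splitter, using the image index of each patch as its group. A sketch on synthetic data; the random features, the 3 groups and the 4-value C grid are illustrative assumptions, not the notebook's data:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import LeaveOneGroupOut

# Synthetic stand-ins for xs, ys and image_index: 3 "images" with 20 patches each.
rng = np.random.default_rng(0)
n_groups, per_group, dim = 3, 20, 5
toy_x = rng.normal(size=(n_groups * per_group, dim))
toy_y = (toy_x[:, 0] + 0.5 * rng.normal(size=len(toy_x)) > 0).astype(int)
toy_groups = np.repeat(np.arange(n_groups), per_group)  # image id of every patch

toy_cs = np.logspace(-3, 0, 4)
toy_scores = np.zeros((n_groups, len(toy_cs)))
logo = LeaveOneGroupOut()
for fold, (train_idx, val_idx) in enumerate(logo.split(toy_x, toy_y, groups=toy_groups)):
    # Each validation fold holds out every patch of exactly one image.
    assert len(np.unique(toy_groups[val_idx])) == 1
    for j, c in enumerate(toy_cs):
        fold_clf = LogisticRegression(C=c, max_iter=1000).fit(toy_x[train_idx], toy_y[train_idx])
        val_prob = fold_clf.predict_proba(toy_x[val_idx])[:, 1]
        toy_scores[fold, j] = average_precision_score(toy_y[val_idx], val_prob)

print(toy_scores.mean(axis=0))  # mean AP per value of C across held-out images
```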

In [None]:
cs = np.logspace(-7, 0, 8)
scores = np.zeros((n_images, len(cs)))

for i in range(n_images):
    # We use leave-one-out so train will be all but image i, val will be image i
    print('validation using image_{:02d}.jpg'.format(i+1))
    train_selection = image_index != float(i)
    fold_x = xs[train_selection].numpy()
    # Binarize at 0.5: after filtering, every label is < 0.01 (background) or > 0.99 (foreground)
    fold_y = (ys[train_selection] > 0.5).long().numpy()
    val_x = xs[~train_selection].numpy()
    val_y = (ys[~train_selection] > 0.5).long().numpy()

    plt.figure()
    for j, c in enumerate(cs):
        print("training logistic regression with C={:.2e}".format(c))
        clf = LogisticRegression(random_state=0, C=c, max_iter=10000).fit(fold_x, fold_y)
        output = clf.predict_proba(val_x)
        precision, recall, thresholds = precision_recall_curve(val_y, output[:, 1])
        s = average_precision_score(val_y, output[:, 1])
        scores[i, j] = s
        plt.plot(recall, precision, label='C={:.1e} AP={:.1f}'.format(c, s*100))

    plt.grid()
    plt.xlabel('recall')
    plt.title('image_{:02d}.jpg'.format(i+1))
    plt.ylabel('precision')
    plt.axis([0, 1, 0, 1])
    plt.legend()
    plt.show()


### Choosing the Best C
Now let's look at which value of C works best on average. To this end, we plot the mean AP across all held-out images.
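Rather than reading the best C off the plot, it can also be picked programmatically as the argmax of the mean AP over held-out images. A sketch on a made-up 2x4 score matrix (the numbers and the 4-value grid are purely illustrative; in the notebook the inputs would be `scores` and `cs`):

```python
import numpy as np

toy_cs = np.logspace(-3, 0, 4)  # [1e-3, 1e-2, 1e-1, 1e0]
# Made-up AP values: rows are held-out images, columns are values of C.
toy_scores = np.array([
    [0.50, 0.80, 0.70, 0.60],
    [0.40, 0.90, 0.60, 0.50],
])

mean_ap = toy_scores.mean(axis=0)   # average AP per value of C
best_idx = int(np.argmax(mean_ap))  # column with the highest average AP
best_c = toy_cs[best_idx]
print(f"best C = {best_c:.0e} (mean AP {mean_ap[best_idx]:.2f})")
```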

In [None]:
plt.figure(figsize=(3, 2), dpi=300)
plt.rcParams.update({
    "xtick.labelsize": 5,
    "ytick.labelsize": 5,
    "axes.labelsize": 5,
})
plt.plot(scores.mean(axis=0))
plt.xticks(np.arange(len(cs)), ["{:.0e}".format(c) for c in cs])
plt.xlabel('data fit C')
plt.ylabel('average AP')
plt.grid()
plt.show()

### Retraining with the optimal regularization
Given the above, we seem to have a winner: C=0.1.
Let's now train a model using this optimal data-fit value.

In [None]:
# Binarize the filtered patch labels at 0.5 (they are all < 0.01 or > 0.99)
clf = LogisticRegression(random_state=0, C=0.1, max_iter=100000, verbose=2).fit(xs.numpy(), (ys > 0.5).long().numpy())

### Test Images and Inference

We have a classifier; now it is time to test it! For each patch of a test image we predict the probability of it being foreground, and then smooth the resulting score map with a 3x3 median filter.
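A 3x3 median filter replaces each patch score with the median of its 3x3 neighbourhood, so isolated high-scoring patches are suppressed while the interior of larger foreground regions survives. A small sketch with `scipy.signal.medfilt2d` on a made-up 5x5 score map:

```python
import numpy as np
from scipy import signal

# Made-up 5x5 patch-score map: a 3x3 foreground block plus one isolated false positive.
toy_map = np.zeros((5, 5))
toy_map[0:3, 0:3] = 0.9  # foreground block in the top-left corner
toy_map[4, 4] = 0.95     # isolated spike

smoothed = signal.medfilt2d(toy_map, kernel_size=3)
print(smoothed)
```

`medfilt2d` zero-pads the borders, so regions touching the image edge can also be eroded slightly: here the spike at (4, 4) is removed, but so is the corner patch at (0, 0).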

In [None]:
test_image_fpath = "https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/test_image.jpg"

def load_image_from_url(url: str) -> Image.Image:
    with urllib.request.urlopen(url) as f:
        return Image.open(f).convert("RGB")


test_image = load_image_from_url(test_image_fpath)
test_image_resized = resize_transform(test_image)
test_image_normalized = TF.normalize(test_image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)

with torch.inference_mode():
    with torch.autocast(device_type='cuda', dtype=torch.float32):
        feats = model.get_intermediate_layers(test_image_normalized.unsqueeze(0).cuda(), n=range(n_layers), reshape=True, norm=True)
        x = feats[-1].squeeze().detach().cpu()
        dim = x.shape[0]
        x = x.view(dim, -1).permute(1, 0)

h_patches, w_patches = [int(d / PATCH_SIZE) for d in test_image_resized.shape[1:]]

fg_score = clf.predict_proba(x)[:, 1].reshape(h_patches, w_patches)
fg_score_mf = torch.from_numpy(signal.medfilt2d(fg_score, kernel_size=3))

plt.figure(figsize=(9, 3), dpi=300)
plt.subplot(1, 3, 1)
plt.axis('off')
plt.imshow(test_image_resized.permute(1, 2, 0))
plt.title('input image')
plt.subplot(1, 3, 2)
plt.axis('off')
plt.imshow(fg_score)
plt.title('foreground score')
plt.subplot(1, 3, 3)
plt.axis('off')
plt.imshow(fg_score_mf)
plt.title('+ median filter')
plt.show()

### Saving the Classifier for Future Use
We are nearly done; let's just save the scikit-learn classifier as a pickle.
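To reuse the classifier in a later session, the file is read back with `pickle.load`. A minimal round-trip sketch, done in memory with a toy classifier rather than the real `clf` (an assumption to keep it self-contained):

```python
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy classifier standing in for `clf` (random features, assumed for illustration).
rng = np.random.default_rng(0)
toy_x = rng.normal(size=(50, 4))
toy_y = (toy_x[:, 0] > 0).astype(int)
toy_clf = LogisticRegression(C=0.1).fit(toy_x, toy_y)

# In-memory round trip: pickle.dumps/loads mirror pickle.dump/load on fg_classifier.pkl.
blob = pickle.dumps(toy_clf)
restored_clf = pickle.loads(blob)

# The restored model should give identical probabilities.
same = np.allclose(toy_clf.predict_proba(toy_x), restored_clf.predict_proba(toy_x))
print(same)  # True
```

Unpickling should be done with a compatible scikit-learn version, and only on trusted files, since loading a pickle can execute arbitrary code.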


In [None]:
save_root = '.'
model_path = os.path.join(save_root, "fg_classifier.pkl")
with open(model_path, "wb") as f:
    pickle.dump(clf, f)

# ONNX Model Export
- Export the DINOv3 feature extractor to ONNX
- Export the classifier to ONNX
- The export runs on the CPU (`device = "cpu"`)
- Inference with the exported models uses only ONNX Runtime, NumPy, and Pillow (no PyTorch)
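The classifier export relies on the fact that, for a binary `LogisticRegression`, `predict_proba(X)[:, 1]` equals `sigmoid(X @ coef_.T + intercept_)`, which is the computation `LogisticRegressionONNX` reimplements in PyTorch. A quick NumPy check on synthetic data (the data and the C value are assumptions):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic stand-ins for patch features and binary labels.
rng = np.random.default_rng(0)
toy_x = rng.normal(size=(100, 8)).astype(np.float32)
toy_y = (toy_x[:, :2].sum(axis=1) > 0).astype(int)
toy_clf = LogisticRegression(C=0.1, max_iter=1000).fit(toy_x, toy_y)

# Same math as LogisticRegressionONNX.forward: sigmoid(x @ W.T + b), one value per row.
logits = toy_x @ toy_clf.coef_.T + toy_clf.intercept_  # shape (100, 1)
p_manual = 1.0 / (1.0 + np.exp(-logits[:, 0]))

# scikit-learn's probability for the positive class (label 1).
p_sklearn = toy_clf.predict_proba(toy_x)[:, 1]
print(np.abs(p_manual - p_sklearn).max())
```

A near-zero maximum difference indicates that a single matmul plus sigmoid is enough to reproduce the classifier, so no scikit-learn-specific ONNX converter is needed.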

In [None]:
# Install ONNX library
!pip install onnx
# Install ONNX Runtime
!pip install onnxruntime

In [None]:
import onnx
import onnxruntime as ort

In [None]:
print("\n--- Starting ONNX Model Export ---")

# --- Define PyTorch wrapper for DINOv3 Feature Extractor ---
class DinoV3FeatureExtractor(torch.nn.Module):
    def __init__(self, model, n_layers):
        super().__init__()
        self.model = model
        self.n_layers = n_layers

    def forward(self, x):
        # Extract features from the last layer, normalized, as done during training
        features_list = self.model.get_intermediate_layers(x, n=range(self.n_layers), reshape=True, norm=True)
        last_layer_features = features_list[-1]  # Shape: (B, C, H_patches, W_patches)

        # Reshape for classifier: (B, C, H*W) -> (B, H*W, C)
        B, C, H, W = last_layer_features.shape
        features_reshaped = last_layer_features.view(B, C, -1)
        features_permuted = features_reshaped.permute(0, 2, 1)
        return features_permuted

# --- Define PyTorch wrapper for Logistic Regression Classifier ---
class LogisticRegressionONNX(torch.nn.Module):
    def __init__(self, sklearn_classifier):
        super().__init__()
        # Extract weights and bias from the trained sklearn model
        self.coeffs = torch.nn.Parameter(torch.from_numpy(sklearn_classifier.coef_).float())
        self.intercept = torch.nn.Parameter(torch.from_numpy(sklearn_classifier.intercept_).float())

    def forward(self, x):
        # Input 'x' has shape (B, Num_Patches, Channels)
        # Apply linear transformation: x @ W.T + b
        linear_output = torch.matmul(x, self.coeffs.T) + self.intercept
        # Apply sigmoid to get foreground probability
        probabilities = torch.sigmoid(linear_output)
        # Squeeze the last dimension to get (B, Num_Patches)
        return probabilities.squeeze(-1)

# --- Export DINOv3 Feature Extractor to ONNX ---
onnx_feature_extractor_path = "dinov3_feature_extractor.onnx"
device = "cpu"
onnx_exportable_dino = DinoV3FeatureExtractor(model, n_layers).to(device).eval()
# Dummy input used for tracing; height and width are made dynamic via dynamic_axes below
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)

print(f"Exporting DINOv3 feature extractor to {onnx_feature_extractor_path}...")
torch.onnx.export(
    onnx_exportable_dino,
    dummy_input,
    onnx_feature_extractor_path,
    input_names=['input_image'],
    output_names=['patch_features'],
    dynamic_axes={
        'input_image': {2: 'height', 3: 'width'},
        'patch_features': {1: 'num_patches'}
    },
    opset_version=17
)
print("DINOv3 ONNX export complete!")

# --- Export Classifier to ONNX ---
onnx_classifier_path = "fg_classifier.onnx"
onnx_exportable_clf = LogisticRegressionONNX(clf).eval()
# Dummy input matching the output of the feature extractor
num_channels = xs.shape[1]
dummy_features = torch.randn(1, (IMAGE_SIZE // PATCH_SIZE)**2, num_channels)

print(f"\nExporting classifier to {onnx_classifier_path}...")
torch.onnx.export(
    onnx_exportable_clf,
    dummy_features,
    onnx_classifier_path,
    input_names=['patch_features'],
    output_names=['probabilities'],
    dynamic_axes={
        'patch_features': {1: 'num_patches'},
        'probabilities': {1: 'num_patches'}
    },
    opset_version=17
)
print("Classifier ONNX export complete!")

print("\n--- Starting PyTorch-Independent ONNX Inference ---")



# --- Define Pre-processing Functions (NumPy/Pillow only) ---
def preprocess_image_numpy(img_pil: Image.Image, image_size: int, patch_size: int):
    """Resizes and normalizes an image using NumPy and Pillow."""
    # Resize image to be divisible by patch size
    w, h = img_pil.size
    h_patches = image_size // patch_size
    w_patches = int((w * image_size) / (h * patch_size))

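    # Round the patch-grid width down to an even number. Note that the training-time
    # resize_transform has no such step, so the grid can be one column narrower here.
    if w_patches % 2 != 0:
        w_patches -= 1

    new_h = h_patches * patch_size
    new_w = w_patches * patch_size
    # Bilinear resampling matches the default interpolation of TF.resize used in training
    resized_img = img_pil.resize((new_w, new_h), Image.Resampling.BILINEAR)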

    # Convert to NumPy array and normalize
    img_np = np.array(resized_img, dtype=np.float32) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=np.float32)
    std = np.array(IMAGENET_STD, dtype=np.float32)
    normalized_img = (img_np - mean) / std

    # Transpose from HWC to CHW and add batch dimension (BCHW)
    input_tensor = normalized_img.transpose(2, 0, 1)[np.newaxis, :, :]

    return input_tensor, (h_patches, w_patches), resized_img

# --- Load Test Image ---
test_image_fpath = "https://dl.fbaipublicfiles.com/dinov3/notebooks/foreground_segmentation/test_image.jpg"
with urllib.request.urlopen(test_image_fpath) as f:
    test_image = Image.open(f).convert("RGB")

# --- Preprocess Image ---
input_tensor, patch_dims, resized_image_pil = preprocess_image_numpy(test_image, IMAGE_SIZE, PATCH_SIZE)
h_patches, w_patches = patch_dims
print(f"Test image preprocessed. Patch grid size: {patch_dims}")

# --- Run Inference with ONNX Runtime ---
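# Create inference sessions on the CPU. Since onnxruntime 1.9, builds with several
# execution providers require the `providers` argument to be set explicitly.
print("Creating ONNX Runtime inference sessions...")
feature_extractor_session = ort.InferenceSession(onnx_feature_extractor_path, providers=["CPUExecutionProvider"])
classifier_session = ort.InferenceSession(onnx_classifier_path, providers=["CPUExecutionProvider"])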

# Get patch features from DINOv3
print("Running feature extraction...")
feature_inputs = {feature_extractor_session.get_inputs()[0].name: input_tensor}
patch_features = feature_extractor_session.run(None, feature_inputs)[0]

# Get foreground probabilities from classifier
print("Running classification...")
classifier_inputs = {classifier_session.get_inputs()[0].name: patch_features}
fg_probabilities = classifier_session.run(None, classifier_inputs)[0]

# --- Post-process and Visualize Results ---
# Reshape probabilities to match patch grid
fg_score = fg_probabilities.reshape(h_patches, w_patches)

# Apply median filter for smoothing
fg_score_mf = signal.medfilt2d(fg_score, kernel_size=3)

# Display the results
print("Displaying final segmentation results...")
plt.figure(figsize=(12, 4), dpi=150)
plt.subplot(1, 3, 1)
plt.imshow(resized_image_pil)
plt.title('Input Image')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(fg_score, cmap='viridis')
plt.title('Foreground Score')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(fg_score_mf, cmap='viridis')
plt.title('+ Median Filter')
plt.axis('off')

plt.tight_layout()
plt.show()

print("\n--- Script finished successfully! ---")