In [None]:
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# DINOv3 Zero-Shot Classification Example

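This notebook demonstrates how to use DINOv3 for zero-shot classification on an image dataset stored as one sub-directory per class (Caltech-101 by default; an ImageNet subset laid out the same way also works).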

## Overview
- Load an image dataset whose class labels are given by the names of its sub-directories
- Initialize DinoTxt pipeline for zero-shot classification
- Learn from text prompts (class names)
- Perform inference on target images
- Calculate classification accuracy


## Setup and Imports

In [None]:
import timeit
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from getiprompt.pipelines.dinotxt import DinoTxtZeroShotClassification
from getiprompt.types import Priors
from getiprompt.utils.constants import DINOv3BackboneSize

## Configuration

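Set your parameters here. `data_root` must point to a directory containing one sub-directory of `.jpg` images per class (the sub-directory name is used as the class name), and `precision` selects the numeric precision passed to the DinoTxt pipeline: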


In [None]:
# Configuration parameters
data_root = "caltech101"  # Update this path: root directory with one sub-directory of .jpg images per class
precision = "bf16"  # Options: "bf16", "fp16", "fp32"

# Fail fast on a typo here instead of after the (slow) dataset load below.
_valid_precisions = ("bf16", "fp16", "fp32")
if precision not in _valid_precisions:
    raise ValueError(f"precision must be one of {_valid_precisions}, got {precision!r}")

## Load Dataset


In [None]:
# Import dataset: every `<data_root>/.../<class_name>/<image>.jpg` becomes one target image
# whose ground-truth class is the name of its parent directory.
data_root_path = Path(data_root)
if not data_root_path.is_dir():
    raise FileNotFoundError(f"Dataset root not found: {data_root_path.resolve()} (update `data_root` above)")

# Sort so class indices and image order are reproducible; rglob's order depends on the filesystem.
image_paths = sorted(data_root_path.rglob("*/*.jpg"))
if not image_paths:
    raise FileNotFoundError(f"No '*/*.jpg' images found under {data_root_path.resolve()}")

# Class index = position in the alphabetically sorted list of class (parent directory) names.
label_names = sorted({path.parent.name for path in image_paths})
label_to_index = {name: idx for idx, name in enumerate(label_names)}  # O(1) lookup per image

gt_labels = []
target_images = []
skipped_paths = []
for path in image_paths:
    img = cv2.imread(str(path))  # BGR image, or None if the file cannot be decoded
    if img is None:
        skipped_paths.append(path)
        continue
    target_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    gt_labels.append(label_to_index[path.parent.name])

if not target_images:
    raise RuntimeError(f"None of the {len(image_paths)} images under {data_root_path} could be read")

print(f"Dataset loaded with {len(label_names)} classes")
print(f"Total items: {len(target_images)}")
if skipped_paths:
    print(f"Skipped {len(skipped_paths)} unreadable image(s), e.g. {skipped_paths[0]}")
print(f"Example classes: {label_names[:10]}")

In [None]:
# Display a few random example images with matplotlib.
# A fixed seed makes the same examples appear on every "Restart & Run All".
num_examples = 3
rng = np.random.default_rng(seed=0)
random_indices = rng.permutation(len(target_images))[:num_examples]
for example_number, idx in enumerate(random_indices, start=1):
    label = label_names[gt_labels[idx]]
    fig, ax = plt.subplots()
    ax.imshow(target_images[idx])
    ax.set_title(f"Example {example_number}: {label}")
    ax.axis("off")
    plt.show()

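## Prepare Target Images and Ground Truth Labels

The target images and ground-truth labels were already assembled while loading the dataset above: `target_images[i]` is an RGB image array and `gt_labels[i]` is the index into `label_names` of that image's class.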


## Initialize DinoTxt Pipeline


In [None]:
# Build the zero-shot classifier: a large DINOv3 backbone run at the configured precision.
print("Initializing DinoTxt pipeline...")
dinotxt = DinoTxtZeroShotClassification(
    backbone_size=DINOv3BackboneSize.LARGE,
    precision=precision,
)
print("Pipeline initialized successfully!")

## Learn from Text Prompts

In [None]:
# Learn from text prompts (class names).
# Zero-shot: no reference images are passed; each class is described only by its name.
print("Learning from text prompts...")
start_time = timeit.default_timer()

# Folder names join words with underscores (e.g. a folder named "dollar_bill"); spaces make
# them read as natural language. The dict keys are unchanged: they are the indices into
# `label_names`, i.e. the same class indices stored in `gt_labels`.
class_prompts = {idx: name.replace("_", " ") for idx, name in enumerate(label_names)}

dinotxt.learn(
    reference_images=[],
    reference_priors=[Priors(text=class_prompts)],
)

learn_time = timeit.default_timer() - start_time
print(f"Learning completed in {learn_time:.2f} seconds")

## Perform Inference


In [None]:
# Perform inference on target images
print("Starting inference...")
inference_start_time = timeit.default_timer()
results = dinotxt.infer(target_images=target_images)
# Stop the clock right after inference so tensor conversion/transfers are not counted.
inference_time = timeit.default_timer() - inference_start_time
print(f"Inference completed in {inference_time:.2f} seconds")

# The first class id of each mask entry is taken as that image's predicted label.
# as_tensor is a no-op for values that are already tensors.
pred_labels = torch.stack([torch.as_tensor(mask.class_ids()[0]) for mask in results.masks])

# Put ground truth on the predictions' device instead of hard-coding .cuda(), so this also
# runs on CPU-only machines. as_tensor accepts an existing tensor, so re-running is safe.
gt_labels = torch.as_tensor(gt_labels, device=pred_labels.device)

# NOTE(review): assumes the pipeline returns one mask entry per target image, in input order.
assert len(pred_labels) == len(gt_labels), (
    f"Got {len(pred_labels)} predictions for {len(gt_labels)} images"
)

## Calculate Results


In [None]:
# Calculate zero-shot classification accuracy.
# Vectorized comparison on the tensors' device; `.item()` turns the 0-d mean into a Python float
# (the builtin `sum` would iterate element by element, syncing the device each time).
num_images = len(gt_labels)
if num_images == 0:
    raise ValueError("No images were evaluated; check `data_root`")
accuracy = (pred_labels == gt_labels).float().mean().item()

# Sum the measured stages instead of subtracting `start_time`, so time spent between
# running cells interactively is not counted.
total_time = learn_time + inference_time

print("=" * 50)
print("RESULTS")
print("=" * 50)
print(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
print(f"Total time: {total_time:.2f} seconds")
print(f"Time per image: {total_time / len(target_images):.4f} seconds")
print(f"Images processed: {len(target_images)}")
print(f"Precision: {precision}")
print("=" * 50)