In [1]:
import dataclasses
import math
import warnings
from typing import Callable
import os

import lovely_tensors
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms as TVT
import torchvision.transforms.functional as TVTF
import tqdm
from omegaconf import OmegaConf
from torch import Tensor, nn
from torchmetrics.classification import MulticlassJaccardIndex

DINOv3_REPO_DIR = "" # Please add here the path to your DINOv3 repository

# Prepare datasets

In [None]:
# Please change the dataset `self.ds` to yours in the __init__ functions of both datasets.

class ZeroShotSegmentationDataset(torch.utils.data.Dataset):
    """Base dataset yielding ``(transformed image, label mask)`` pairs.

    Subclasses must define ``CLASS_NAMES`` and ``IGNORE_ZERO_LABEL`` and assign
    ``self.ds`` (an indexable, sized collection of ``(image, mask)`` pairs) in
    their ``__init__``; this base class never sets ``self.ds`` itself.
    """

    CLASS_NAMES: tuple[str, ...]
    IGNORE_ZERO_LABEL: bool  # If True, map label 0 to 255 so it's ignored, and shift all other labels by -1
    transform: Callable[[PIL.Image.Image], Tensor]

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        """Store the callable applied to every image returned by ``__getitem__``."""
        self.transform = transform

    def _mask_to_tensor(self, mask_pil: PIL.Image.Image) -> Tensor:
        """Convert a label mask to an int64 tensor, remapping labels if required.

        When ``IGNORE_ZERO_LABEL`` is set, labels 0 and 255 both become 255 and
        every other label is decremented by one; otherwise labels are untouched.
        """
        labels = torch.from_numpy(np.array(mask_pil)).long()
        if not self.IGNORE_ZERO_LABEL:
            return labels
        ignored = (labels == 0) | (labels == 255)
        return torch.where(ignored, 255, labels - 1)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the transformed image and the label tensor for sample ``idx``."""
        image, mask = self.ds[idx]
        return self.transform(image), self._mask_to_tensor(mask)

    def __len__(self) -> int:
        """Number of samples in the wrapped dataset."""
        return len(self.ds)


class Cityscapes(ZeroShotSegmentationDataset):
    """Cityscapes dataset wrapper with its 19 class names.

    Labels are used unchanged (``IGNORE_ZERO_LABEL = False``), so label 0 is the
    first entry of ``CLASS_NAMES``.
    """

    # The position of a name in this tuple is the class index used for the text
    # features and for the predicted / target labels.
    CLASS_NAMES = (
        "road", "sidewalk", "building", "wall", "fence",
        "pole", "traffic light", "traffic sign", "vegetation", "terrain",
        "sky", "person", "rider", "car", "truck",
        "bus", "train", "motorcycle", "bicycle",
    )
    IGNORE_ZERO_LABEL = False

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        """Register ``transform``; ``self.ds`` is a placeholder the user must fill in."""
        super().__init__(transform)
        self.ds = None # Put here "Cityscapes:split=VAL" dataset


class Ade20k(ZeroShotSegmentationDataset):
    """ADE20K dataset wrapper with its 150 class names.

    ``IGNORE_ZERO_LABEL = True``: raw label 0 is mapped to the ignore value 255
    and all other raw labels are shifted down by one, so raw label 1 corresponds
    to the first entry of ``CLASS_NAMES``.
    """

    # The position of a name in this tuple is the class index used for the text
    # features and for the predicted / remapped target labels.
    # NOTE: "bed" used to be spelled "bed " (trailing space). Class names are
    # interpolated verbatim into the prompt templates, so the stray space produced
    # prompts such as "a photo of a bed ."; it has been removed.
    CLASS_NAMES = (
        "wall", "building", "sky", "floor", "tree",
        "ceiling", "road", "bed", "windowpane", "grass",
        "cabinet", "sidewalk", "person", "earth", "door",
        "table", "mountain", "plant", "curtain", "chair",
        "car", "water", "painting", "sofa", "shelf",
        "house", "sea", "mirror", "rug", "field",
        "armchair", "seat", "fence", "desk", "rock",
        "wardrobe", "lamp", "bathtub", "railing", "cushion",
        "base", "box", "column", "signboard", "chest of drawers",
        "counter", "sand", "sink", "skyscraper", "fireplace",
        "refrigerator", "grandstand", "path", "stairs", "runway",
        "case", "pool table", "pillow", "screen door", "stairway",
        "river", "bridge", "bookcase", "blind", "coffee table",
        "toilet", "flower", "book", "hill", "bench",
        "countertop", "stove", "palm", "kitchen island", "computer",
        "swivel chair", "boat", "bar", "arcade machine", "hovel",
        "bus", "towel", "light", "truck", "tower",
        "chandelier", "awning", "streetlight", "booth", "television receiver",
        "airplane", "dirt track", "apparel", "pole", "land",
        "bannister", "escalator", "ottoman", "bottle", "buffet",
        "poster", "stage", "van", "ship", "fountain",
        "conveyer belt", "canopy", "washer", "plaything", "swimming pool",
        "stool", "barrel", "basket", "waterfall", "tent",
        "bag", "minibike", "cradle", "oven", "ball",
        "food", "step", "tank", "trade name", "microwave",
        "pot", "animal", "bicycle", "lake", "dishwasher",
        "screen", "blanket", "sculpture", "hood", "sconce",
        "vase", "traffic light", "tray", "ashcan", "fan",
        "pier", "crt screen", "plate", "monitor", "bulletin board",
        "shower", "radiator", "glass", "clock", "flag",
    )
    IGNORE_ZERO_LABEL = True

    def __init__(self, transform: Callable[[PIL.Image.Image], Tensor]) -> None:
        """Register ``transform``; ``self.ds`` is a placeholder the user must fill in."""
        super().__init__(transform)
        self.ds = None # Put here "ADE20KChallengeData2016:split=VAL" dataset


# Registry mapping a dataset key (the value of `Configuration.dataset`) to the
# dataset class that is instantiated for evaluation.
DATASETS: dict[str, type[ZeroShotSegmentationDataset]] = {
    "cityscapes": Cityscapes,
    "ade20k": Ade20k,
}
# Per-channel (RGB) normalization using the ImageNet mean / standard deviation;
# it is applied to the image tensor as the last step of the input transform.
NORMALIZE_IMAGENET = TVT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Encode image function

In [11]:
def encode_image(model, img: Tensor) -> tuple[Tensor, Tensor]:
    """Run the visual model on a batch of images and return its patch features as a grid.

    Args:
        model: Object exposing ``visual_model.backbone.patch_size`` and
            ``visual_model.get_class_and_patch_tokens(img)``.
        img: Image batch of shape [B, 3, H, W].

    Returns:
        A pair ``(backbone_patches, blocks_patches)``. ``backbone_patches`` is
        always ``None`` in this implementation; ``blocks_patches`` has shape
        [B, h, w, D] where ``h`` and ``w`` are the (possibly resized) image
        height and width divided by the patch size.
    """
    patch = model.visual_model.backbone.patch_size  # In the case of our DINOv3
    _, _, height, width = img.shape

    # Stretch (not pad) the image so both sides become a multiple of the patch size
    snapped_hw = (math.ceil(height / patch) * patch, math.ceil(width / patch) * patch)
    if snapped_hw != (height, width):
        img = F.interpolate(img, size=snapped_hw, mode="bicubic", align_corners=False)  # [B, 3, H', W']

    batch, _, snapped_h, snapped_w = img.shape

    # Only the patch tokens (third output) are used; the first two outputs are discarded
    _, _, patch_tokens = model.visual_model.get_class_and_patch_tokens(img)

    # Lay the flat token sequence out on the patch grid
    grid = patch_tokens.reshape(batch, snapped_h // patch, snapped_w // patch, -1)
    blocks_patches = grid.contiguous()  # [B, h, w, D]

    return None, blocks_patches


class ShortSideResize(nn.Module):
    """Resize an image so that its shorter side equals ``size``, keeping the aspect ratio.

    The longer side is scaled by the same factor and truncated to an integer.
    Images whose shorter side already equals ``size`` are returned untouched.
    """

    def __init__(self, size: int, interpolation: TVT.InterpolationMode) -> None:
        """Store the target short-side length and the interpolation mode used for resizing."""
        super().__init__()
        self.size = size
        self.interpolation = interpolation

    def forward(self, img: Tensor) -> Tensor:
        """Return ``img`` resized so that ``min(height, width) == self.size``."""
        _, height, width = TVTF.get_dimensions(img)

        # Nothing to do when the short side already has the requested length
        if min(height, width) == self.size:
            return img

        if width < height:
            # Portrait: width is the short side
            target_hw = [int(self.size * height / width), self.size]
        else:
            # Landscape or square: height is the short side
            target_hw = [self.size, int(self.size * width / height)]
        return TVTF.resize(img, target_hw, self.interpolation)

# Functions for prediction in mode whole or sliding window

In [10]:
def predict_whole(model, img: Tensor, text_features: Tensor) -> Tensor:
    """Score every patch of one image against every class text embedding.

    Args:
        model: Model object forwarded to ``encode_image``.
        img: A single image of shape [3, H, W] (no batch dimension).
        text_features: Class embeddings of shape [num_classes, D]; they are used
            as-is, i.e. expected to be L2-normalized by the caller.

    Returns:
        Tensor of shape [num_classes, h, w] holding the cosine similarities on the
        low-resolution patch grid; upsampling is left to the caller.
    """
    # The unpacking doubles as a check that `img` has exactly three dimensions
    _channels, _img_h, _img_w = img.shape

    # Add a batch dimension; only the features from the additional blocks are used
    _, patch_grid = encode_image(model, img[None])  # [1, h, w, D]
    _batch, _grid_h, _grid_w, _dim = patch_grid.shape

    # Drop the batch dimension and L2-normalize each patch feature so that the
    # dot product with the (already normalized) text features is a cosine similarity
    unit_patches = F.normalize(patch_grid.squeeze(0), p=2, dim=-1)  # [h, w, D]
    similarities = torch.einsum("cd,hwd->chw", text_features, unit_patches)  # [num_classes, h, w]
    return similarities

def predict_slide(model, img: Tensor, text_features: Tensor, side: int, stride: int) -> Tensor:
    """Sliding-window inference that averages per-window softmax scores at image resolution.

    The image is covered with square windows of ``side`` pixels whose origins are
    ``stride`` pixels apart. Each window is scored with ``predict_whole``, the
    low-resolution scores are upsampled to the window size, turned into
    pseudo-probabilities with a softmax over classes, and accumulated. Pixels
    covered by several windows receive the mean of their window scores.

    Args:
        model: Model object forwarded to ``predict_whole``.
        img: A single image of shape [3, H, W].
        text_features: Class embeddings of shape [num_classes, D].
        side: Window side length in pixels.
        stride: Step in pixels between consecutive window origins.

    Returns:
        Tensor of shape [num_classes, H, W], allocated on the same device as ``img``.
    """
    # Iterate over overlapping windows, accumulate predictions at the image resolution
    _, H, W = img.shape
    num_classes, _ = text_features.shape
    # Allocate the accumulators next to the input instead of hard-coding "cuda",
    # so the function also works on CPU or on a non-default GPU.
    device = img.device
    probs = torch.zeros([num_classes, H, W], device=device)
    counts = torch.zeros([H, W], device=device)
    # Number of window positions needed to cover each axis (at least one)
    h_grids = max(H - side + stride - 1, 0) // stride + 1
    w_grids = max(W - side + stride - 1, 0) // stride + 1
    for i in range(h_grids):
        for j in range(w_grids):
            y1 = i * stride
            x1 = j * stride
            # Clamp the window to the image, then shift its origin back so that a
            # window touching the border keeps the full `side` length when possible
            y2 = min(y1 + side, H)
            x2 = min(x1 + side, W)
            y1 = max(y2 - side, 0)
            x1 = max(x2 - side, 0)

            # Compute cosine similarities for this window, same logic as predict_whole
            img_window = img[:, y1:y2, x1:x2]  # [3, H_win, W_win]
            cos = predict_whole(model, img_window, text_features)  # [num_classes, h, w]

            # Upsample to the window resolution and accumulate "probabilities"
            # NOTE: they aren't real probabilities, just the result of applying softmax to cosine similarities
            cos = F.interpolate(
                cos.unsqueeze(0),
                size=img_window.shape[1:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)  # [num_classes, H_win, W_win]
            probs[:, y1:y2, x1:x2] += cos.softmax(dim=0)  # [num_classes, H_win, W_win]
            counts[y1:y2, x1:x2] += 1
    # Average the accumulated scores over the number of windows covering each pixel
    probs /= counts

    # Return "probabilities" at the img resolution, they will be upsampled to the target resolution later
    return probs  # [num_classes, H, W]


# Prompt templates

In [4]:
# Reference: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
# Prompt-ensembling templates. Each template has a single `{0}` placeholder that is
# filled with a class name via `template.format(class_name)`; the text embeddings of
# all templates are then averaged into one embedding per class.
PROMPT_TEMPLATES = (
    "a bad photo of a {0}.",
    "a photo of many {0}.",
    "a sculpture of a {0}.",
    "a photo of the hard to see {0}.",
    "a low resolution photo of the {0}.",
    "a rendering of a {0}.",
    "graffiti of a {0}.",
    "a bad photo of the {0}.",
    "a cropped photo of the {0}.",
    "a tattoo of a {0}.",
    "the embroidered {0}.",
    "a photo of a hard to see {0}.",
    "a bright photo of a {0}.",
    "a photo of a clean {0}.",
    "a photo of a dirty {0}.",
    "a dark photo of the {0}.",
    "a drawing of a {0}.",
    "a photo of my {0}.",
    "the plastic {0}.",
    "a photo of the cool {0}.",
    "a close-up photo of a {0}.",
    "a black and white photo of the {0}.",
    "a painting of the {0}.",
    "a painting of a {0}.",
    "a pixelated photo of the {0}.",
    "a sculpture of the {0}.",
    "a bright photo of the {0}.",
    "a cropped photo of a {0}.",
    "a plastic {0}.",
    "a photo of the dirty {0}.",
    "a jpeg corrupted photo of a {0}.",
    "a blurry photo of the {0}.",
    "a photo of the {0}.",
    "a good photo of the {0}.",
    "a rendering of the {0}.",
    "a {0} in a video game.",
    "a photo of one {0}.",
    "a doodle of a {0}.",
    "a close-up photo of the {0}.",
    "a photo of a {0}.",
    "the origami {0}.",
    "the {0} in a video game.",
    "a sketch of a {0}.",
    "a doodle of the {0}.",
    "a origami {0}.",
    "a low resolution photo of a {0}.",
    "the toy {0}.",
    "a rendition of the {0}.",
    "a photo of the clean {0}.",
    "a photo of a large {0}.",
    "a rendition of a {0}.",
    "a photo of a nice {0}.",
    "a photo of a weird {0}.",
    "a blurry photo of a {0}.",
    "a cartoon {0}.",
    "art of a {0}.",
    "a sketch of the {0}.",
    "a embroidered {0}.",
    "a pixelated photo of a {0}.",
    "itap of the {0}.",
    "a jpeg corrupted photo of the {0}.",
    "a good photo of a {0}.",
    "a plushie {0}.",
    "a photo of the nice {0}.",
    "a photo of the small {0}.",
    "a photo of the weird {0}.",
    "the cartoon {0}.",
    "art of the {0}.",
    "a drawing of the {0}.",
    "a photo of the large {0}.",
    "a black and white photo of a {0}.",
    "the plushie {0}.",
    "a dark photo of a {0}.",
    "itap of a {0}.",
    "graffiti of the {0}.",
    "a toy {0}.",
    "itap of my {0}.",
    "a photo of a cool {0}.",
    "a photo of a small {0}.",
    "a tattoo of the {0}.",
)

# Load model

In [6]:
# Load the model
# NOTE: `sys` is imported in this cell (not in the top import cell) because the
# path tweak below has to run before the `dinov3` import that follows it.
import sys
sys.path.append(DINOv3_REPO_DIR)  # Make the local DINOv3 checkout importable

from dinov3.hub.dinotxt import dinov3_vitl16_dinotxt_tet1280d20h24l
# The hub entry point returns a (model, tokenizer) pair
model, tokenizer = dinov3_vitl16_dinotxt_tet1280d20h24l()
model.to("cuda", non_blocking=True)
model.eval()
# Rebind `tokenizer` to its `tokenize` method so it can be called directly as
# `tokenizer(text)` when the text features are built
tokenizer = tokenizer.tokenize

# Configuration

In [None]:
@dataclasses.dataclass
class Configuration:
    """Settings of the zero-shot segmentation evaluation run."""

    # Key into the `DATASETS` registry
    dataset: str = "cityscapes" # cityscapes, ade20k

    # Inference strategy; any other value raises ValueError in the inference loop
    mode: str = "slide"  # whole (whole image), slide (sliding window inference)
    resize: int = 512  # Short side of the input images

    # Only used for mode=slide
    side: int = 384  # Side length of the square sliding window, in pixels
    stride: int = 192  # Step between consecutive window origins, in pixels

# Local setup
# NOTE(review): presumably patches tensor printing so the f-strings below show compact summaries -- confirm
lovely_tensors.monkey_patch()
# Silence warnings whose message starts with "xFormers"
warnings.filterwarnings("ignore", message="xFormers")
# Round-trip the dataclass through OmegaConf to obtain a `Configuration` instance with the defaults
cfg: Configuration = OmegaConf.to_object(
    OmegaConf.structured(Configuration),
)
print(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

# Inference

In [None]:
# Load dataset
transform = TVT.Compose(
    [
        ShortSideResize(cfg.resize, TVT.InterpolationMode.BICUBIC),
        TVT.ToTensor(),
        NORMALIZE_IMAGENET,
    ]
)
dataset = DATASETS[cfg.dataset](transform)
class_names = dataset.CLASS_NAMES
print(f"Dataset: {len(dataset)} images, {len(class_names)} classes")
# batch_size=None turns off automatic batching, so the loader yields one unbatched
# (img, target) pair per step -- exactly what the segmentation loop below unpacks.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None, # TODO Adapt
    num_workers=1,
    shuffle=False,
    pin_memory=True,
    multiprocessing_context="spawn",
)
dataloder = dataloader  # Alias: keeps the original (misspelled) name available to any code that still uses it

# Prepare text features: prompts x class names
# Pure inference: no_grad avoids building an autograd graph for every forward pass
text_feats = []
with torch.no_grad():
    for class_name in tqdm.tqdm(class_names, desc="Class names", unit="name", ncols=0):
        text = [template.format(class_name) for template in PROMPT_TEMPLATES]
        tokens = tokenizer(text).to("cuda", non_blocking=True)
        feats = model.encode_text(tokens)  # [num_prompts, 2D]
        feats = feats[:, feats.shape[1] // 2 :]  # The 1st half of the features corresponds to the CLS token, drop it
        feats = F.normalize(feats, p=2, dim=-1)  # Normalize each text embedding
        feats = feats.mean(dim=0)  # Average over all prompt embeddings per class
        feats = F.normalize(feats, p=2, dim=-1)  # Normalize again
        text_feats.append(feats)
text_feats = torch.stack(text_feats)  # [num_classes, D]
print(f"Text features: {text_feats}")

# Loop over dataset, perform segmentation and compute metrics
# Pixels labelled 255 in the target are excluded from the metric
miou = MulticlassJaccardIndex(len(class_names), average="macro", ignore_index=255).to("cuda")
with torch.no_grad():
    for idx, (img, target) in enumerate(tqdm.tqdm(dataloader, desc="Segmentation", unit="img", ncols=0)):
        _, H, W = img.shape
        H_target, W_target = target.shape
        img = img.to("cuda", non_blocking=True)  # [3, H, W]
        target = target.to("cuda", non_blocking=True)  # [H_target, W_target]
        if idx == 0:
            tqdm.tqdm.write(f"Image:  {img}")
            tqdm.tqdm.write(f"Target: {target}")

        if cfg.mode == "whole":
            pred = predict_whole(model, img, text_feats)  # [num_classes, h, w] (patch-grid resolution)
        elif cfg.mode == "slide":
            pred = predict_slide(model, img, text_feats, cfg.side, cfg.stride)  # [num_classes, H, W]
        else:
            raise ValueError(f"Unknown mode {cfg.mode}")
        if idx == 0:
            tqdm.tqdm.write(f"Pred:   {pred}")

        # Interpolate to the target resolution and take argmax
        pred = F.interpolate(pred.unsqueeze(0), size=(H_target, W_target), mode="bilinear", align_corners=False)
        pred = pred.squeeze(0).argmax(dim=0)  # [H_target, W_target]
        miou.update(pred.unsqueeze(0), target.unsqueeze(0))

# Compute metrics (evaluate the metric once and reuse the value)
miou_percent = 100 * miou.compute().item()
print(f"Configuration {cfg}")
print(f"Segmentation mIoU: {miou_percent}")

data = [['mIoU'], [miou_percent]]