# Dataset visualization

This notebook helps you inspect the raw samples and the corresponding transformed outputs for the datasets supported by the repository. Update the configuration section to point to your local dataset copies before executing the visualization cells.

## How to configure a dataset

Datasets are specified with the `dataset_str` format understood by the data-loader utilities (`DatasetName:key=value:key=value`). Typical arguments for the built-in datasets are:

- `ImageNet`: `root`, `extra`, `split` (choose from `TRAIN`, `VAL`, `TEST`).
- `ImageNet22k`: `root`, `extra`, `split` (`TRAIN` or `VAL`).
- `ADE20K`: `root`, `split` (`TRAIN` or `VAL`).
- `CocoCaptions`: `root`, `split` (`TRAIN` or `VAL`).
- `NYU`: `root`, `split` (`TRAIN`, `VAL`, or `TEST`).

Additional keyword arguments accepted by a dataset class can be appended in the same `key=value` form. Before instantiating the dataset, the notebook checks that the `root` and `extra` paths (when provided) exist on the filesystem.

## Available transform presets

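The registry defined below wraps the transformation builders shipped with the repository. Pick one of the preset names, or set `CUSTOM_TRANSFORM` to your own callable that accepts `(image, target)` and returns `(image, target)`. Use `CUSTOM_DENORMALIZE_OVERRIDES` to tell the notebook how to undo that transform's normalization for display.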

In [None]:
import logging
import os
import random
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from dinov3.data import transforms as data_transforms
from dinov3.data.loaders import _parse_dataset_str
from dinov3.eval.depth import transforms as depth_transforms
from dinov3.eval.segmentation import transforms as seg_transforms

logging.getLogger("dinov3").setLevel(logging.WARNING)

# torchvision's tv_tensors module is optional: if it cannot be imported,
# tv_tensors is None and TV_TENSOR_TYPES is empty, so the isinstance checks
# that use it simply never match.
try:
    from torchvision import tv_tensors

    # Collect the tv_tensors classes, keeping only torch.Tensor subclasses.
    # The namespace also exports non-tensor classes (e.g. the BoundingBoxFormat
    # enum); those must not be matched, because matched objects are handed to
    # torch.as_tensor.
    TV_TENSOR_TYPES: Tuple[type, ...] = tuple(
        cls
        for cls in vars(tv_tensors).values()
        if isinstance(cls, type)
        and cls.__module__.startswith("torchvision.")
        and issubclass(cls, torch.Tensor)
    )
except Exception:
    # Deliberately broad: besides ImportError, a torch/torchvision mismatch can
    # surface as other exception types at import time, and the notebook still
    # works without tv_tensors support.
    tv_tensors = None
    TV_TENSOR_TYPES = tuple()


In [None]:
@dataclass(frozen=True)
class TransformSpec:
    """Immutable description of a transform preset and how to display its output.

    ``visualize_samples`` passes ``denormalize_mean``, ``denormalize_std`` and
    ``value_scale`` to the display helpers, which use them to map transformed
    images back to a viewable range.
    """

    # Called with keyword arguments; returns an (image, target) -> (image, target) transform.
    factory: Callable[..., Callable[[Any, Any], Tuple[Any, Any]]]
    # Human-readable summary, printed when the presets are listed.
    description: str
    # Per-channel statistics used to undo normalization (x * std + mean) for display;
    # when either one is None no denormalization is applied.
    denormalize_mean: Optional[Sequence[float]] = None
    denormalize_std: Optional[Sequence[float]] = None
    # Denormalized values are divided by this before clamping to [0, 1]
    # (e.g. 255.0 when the statistics are expressed in the 0-255 range).
    value_scale: float = 1.0


# ImageNet statistics as exported by dinov3.data.transforms (presumably in the
# [0, 1] pixel range, since they are rescaled by 255 below).
IMAGENET_MEAN = data_transforms.IMAGENET_DEFAULT_MEAN
IMAGENET_STD = data_transforms.IMAGENET_DEFAULT_STD
# The same statistics expressed in the 0-255 range, paired with
# value_scale=255.0 by the segmentation presets.
SEGMENTATION_MEAN, SEGMENTATION_STD = (
    tuple(255.0 * channel for channel in IMAGENET_MEAN),
    tuple(255.0 * channel for channel in IMAGENET_STD),
)


def _identity_factory(**_: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    def _apply(image, target):
        return image, target

    return _apply


def _classification_train_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Wrap the classification training transform into an ``(image, target)`` callable.

    ``kwargs`` are forwarded to ``make_classification_train_transform``. The
    resulting transform is applied to the image only; the target passes through.
    """
    image_pipeline = data_transforms.make_classification_train_transform(**kwargs)
    return lambda image, target: (image_pipeline(image), target)


def _classification_eval_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Wrap the classification evaluation transform into an ``(image, target)`` callable.

    ``kwargs`` are forwarded to ``make_classification_eval_transform``. The
    resulting transform is applied to the image only; the target passes through.
    """
    image_pipeline = data_transforms.make_classification_eval_transform(**kwargs)
    return lambda image, target: (image_pipeline(image), target)


def _segmentation_train_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Build the segmentation training transform; returned as-is (per the annotation, an ``(image, target)`` callable)."""
    builder = seg_transforms.make_segmentation_train_transforms
    return builder(**kwargs)


def _segmentation_eval_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Build the segmentation evaluation transform; returned as-is (per the annotation, an ``(image, target)`` callable)."""
    builder = seg_transforms.make_segmentation_eval_transforms
    return builder(**kwargs)


def _depth_train_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Build the depth training transform; returned as-is (per the annotation, an ``(image, target)`` callable)."""
    builder = depth_transforms.make_depth_train_transforms
    return builder(**kwargs)


def _depth_eval_factory(**kwargs: Any) -> Callable[[Any, Any], Tuple[Any, Any]]:
    """Build the depth evaluation transform; returned as-is (per the annotation, an ``(image, target)`` callable)."""
    builder = depth_transforms.make_depth_eval_transforms
    return builder(**kwargs)


# Preset name -> TransformSpec. Keyword construction preserves the order below,
# which is the order the presets are listed in.
TRANSFORM_REGISTRY: Dict[str, TransformSpec] = dict(
    identity=TransformSpec(
        factory=_identity_factory,
        description="Return the original sample without applying any transform.",
    ),
    classification_train=TransformSpec(
        factory=_classification_train_factory,
        description="Default classification training pipeline from dinov3.data.transforms.",
        denormalize_mean=IMAGENET_MEAN,
        denormalize_std=IMAGENET_STD,
    ),
    classification_eval=TransformSpec(
        factory=_classification_eval_factory,
        description="Evaluation-time classification preprocessing (resize + center crop).",
        denormalize_mean=IMAGENET_MEAN,
        denormalize_std=IMAGENET_STD,
    ),
    # Segmentation presets use 0-255 statistics, hence value_scale=255.0.
    segmentation_train=TransformSpec(
        factory=_segmentation_train_factory,
        description="Segmentation training augmentations from dinov3.eval.segmentation.transforms.",
        denormalize_mean=SEGMENTATION_MEAN,
        denormalize_std=SEGMENTATION_STD,
        value_scale=255.0,
    ),
    segmentation_eval=TransformSpec(
        factory=_segmentation_eval_factory,
        description="Segmentation evaluation preprocessing with optional TTA.",
        denormalize_mean=SEGMENTATION_MEAN,
        denormalize_std=SEGMENTATION_STD,
        value_scale=255.0,
    ),
    depth_train=TransformSpec(
        factory=_depth_train_factory,
        description="Depth estimation training augmentations from dinov3.eval.depth.transforms.",
        denormalize_mean=IMAGENET_MEAN,
        denormalize_std=IMAGENET_STD,
    ),
    depth_eval=TransformSpec(
        factory=_depth_eval_factory,
        description="Depth estimation evaluation preprocessing (with optional flips).",
        denormalize_mean=IMAGENET_MEAN,
        denormalize_std=IMAGENET_STD,
    ),
)


def resolve_transform(
    transform_name: str,
    transform_kwargs: Optional[Dict[str, Any]] = None,
    *,
    custom_transform: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    custom_denorm_overrides: Optional[Dict[str, Any]] = None,
) -> Tuple[Callable[[Any, Any], Tuple[Any, Any]], TransformSpec]:
    """Return the ``(image, target)`` transform to apply and its display spec.

    Args:
        transform_name: Key in ``TRANSFORM_REGISTRY``; ignored when
            ``custom_transform`` is given.
        transform_kwargs: Keyword arguments passed to the preset's factory.
        custom_transform: User callable that takes precedence over the preset.
        custom_denorm_overrides: Optional ``"mean"``, ``"std"`` and
            ``"value_scale"`` entries that override the spec's denormalization
            settings (a custom transform defaults to no denormalization and a
            value scale of 1.0).

    Raises:
        KeyError: If no custom transform is given and ``transform_name`` is not
            a registered preset.
    """
    overrides = custom_denorm_overrides or {}

    if custom_transform is not None:
        custom_spec = TransformSpec(
            factory=lambda **_: custom_transform,
            description="Custom user-supplied transform",
            denormalize_mean=overrides.get("mean"),
            denormalize_std=overrides.get("std"),
            value_scale=overrides.get("value_scale", 1.0),
        )
        return custom_transform, custom_spec

    if transform_name not in TRANSFORM_REGISTRY:
        available = list(TRANSFORM_REGISTRY)
        raise KeyError(f"Unknown transform preset '{transform_name}'. Available keys: {available}")

    preset = TRANSFORM_REGISTRY[transform_name]
    transform_fn = preset.factory(**(transform_kwargs or {}))
    if not overrides:
        return transform_fn, preset

    # Only the keys actually present in the overrides replace the preset values.
    adjusted = replace(
        preset,
        denormalize_mean=overrides.get("mean", preset.denormalize_mean),
        denormalize_std=overrides.get("std", preset.denormalize_std),
        value_scale=overrides.get("value_scale", preset.value_scale),
    )
    return transform_fn, adjusted


In [None]:
# List every registered preset together with its one-line description.
print(
    "Available transform presets:",
    *(f" - {key}: {entry.description}" for key, entry in TRANSFORM_REGISTRY.items()),
    sep="\n",
)


In [None]:
# --- Configuration ---
# Update the dataset string to match your local setup before running the notebook.
# Format: "DatasetName:key=value:key=value" (see the notes at the top of the notebook).
DATASET_STR: str = "ImageNet:root=/path/to/imagenet:extra=/path/to/imagenet_metadata:split=VAL"

# Name of a preset from TRANSFORM_REGISTRY and the keyword arguments for its factory.
TRANSFORM_NAME: str = "classification_eval"
TRANSFORM_KWARGS: Dict[str, Any] = {}

# Optional: a custom (image, target) -> (image, target) callable, which takes
# precedence over TRANSFORM_NAME, and/or overrides for the denormalization
# applied when displaying transformed images.
CUSTOM_TRANSFORM: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None  # e.g. lambda image, target: (image, target)
CUSTOM_DENORMALIZE_OVERRIDES: Optional[Dict[str, Any]] = None  # e.g. {"mean": (0.5, 0.5, 0.5), "std": (0.5, 0.5, 0.5), "value_scale": 1.0}

# How many random samples to show, the seed used to pick them, and whether to
# caption each panel with a short summary of its target.
NUM_SAMPLES: int = 6
RANDOM_SEED: int = 0
SHOW_TARGET_SUMMARY: bool = True


In [None]:
def parse_dataset_string(dataset_str: str) -> Tuple[type, Dict[str, Any]]:
    """Split a ``DatasetName:key=value:...`` string into a dataset class and its kwargs.

    The kwargs come back as a fresh ``dict`` so callers may modify it freely.
    """
    parsed_cls, parsed_kwargs = _parse_dataset_str(dataset_str)
    kwargs_copy = dict(parsed_kwargs)
    return parsed_cls, kwargs_copy


def validate_dataset_locations(dataset_kwargs: Dict[str, Any]) -> None:
    """Check that the ``root`` and ``extra`` paths in ``dataset_kwargs`` exist.

    Keys that are absent, ``None``, empty, or not path-like (``str`` /
    ``os.PathLike``) are skipped.

    Raises:
        FileNotFoundError: Lists every missing path as a ``key=path`` line.
    """
    missing = []
    for key in ("root", "extra"):
        value = dataset_kwargs.get(key)
        if not isinstance(value, (str, os.PathLike)):
            continue
        path_str = os.fspath(value)
        if path_str and not os.path.exists(path_str):
            missing.append(f"{key}={path_str}")
    if missing:
        raise FileNotFoundError(
            "The following dataset paths could not be located:\n" + "\n".join(missing)
        )


def instantiate_dataset(dataset_cls: type, dataset_kwargs: Dict[str, Any], transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None):
    """Construct ``dataset_cls`` from ``dataset_kwargs`` plus the given ``transforms``.

    ``dataset_kwargs`` itself is left untouched; ``transforms`` replaces any
    ``transforms`` entry it may already contain.
    """
    return dataset_cls(**{**dataset_kwargs, "transforms": transforms})


def ensure_list(obj: Any) -> List[Any]:
    """Return ``obj`` as a new list: lists and tuples are copied, anything else is wrapped."""
    return list(obj) if isinstance(obj, (list, tuple)) else [obj]


def convert_single_image(
    image: Any,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
    value_scale: float = 1.0,
):
    """Convert one image-like object into something ``plt.imshow`` can draw.

    Tensors (including tensor subclasses) are handled by rank:

    * 4-D: treated as a batch; returns a list with one converted entry per item.
    * 3-D (assumed ``C, H, W``): floating-point tensors are denormalized with
      ``x * std + mean`` (only when ``len(mean)`` matches the channel count),
      divided by ``value_scale`` and clamped to ``[0, 1]``. Integer tensors are
      treated as raw pixel values and are neither denormalized nor rescaled;
      multi-channel ones are returned as ``uint8``. The result is ``H, W, C``,
      or ``H, W`` for a single channel.
    * 2-D: returned as an array divided by ``value_scale``.
    * 1-D: returned as an array unchanged.

    NumPy arrays are returned as-is and PIL images are converted with
    ``np.array``. Anything else, including tensors of other ranks, is passed to
    ``np.array`` on a best-effort basis (``None`` if that fails).
    """
    if isinstance(image, torch.Tensor):
        tensor = image.detach().cpu()
        if type(tensor) is not torch.Tensor:
            # Unwrap tensor subclasses (e.g. torchvision tv_tensors) so the math
            # below runs on a plain tensor.
            tensor = tensor.as_subclass(torch.Tensor)
        if tensor.dtype in (torch.float16, torch.bfloat16):
            # NumPy has no bfloat16, and fp32 keeps the mean/std arithmetic accurate.
            tensor = tensor.float()

        if tensor.ndim == 4:
            return [
                convert_single_image(item, mean=mean, std=std, value_scale=value_scale)
                for item in tensor
            ]

        if tensor.ndim == 3:
            if tensor.is_floating_point():
                if mean is not None and std is not None and len(mean) == tensor.shape[0]:
                    mean_tensor = torch.tensor(mean, dtype=tensor.dtype).view(-1, 1, 1)
                    std_tensor = torch.tensor(std, dtype=tensor.dtype).view(-1, 1, 1)
                    tensor = tensor * std_tensor + mean_tensor
                if value_scale:
                    tensor = tensor / float(value_scale)
                array = tensor.clamp(0.0, 1.0).permute(1, 2, 0).numpy()
            else:
                # Normalization always yields floating-point tensors (torchvision's
                # Normalize rejects integer input), so an integer tensor holds raw
                # pixel values. Denormalizing it would cast mean/std to the integer
                # dtype (a std of 0.229 becomes 0) and wipe out the image.
                array = tensor.permute(1, 2, 0).numpy()
                if array.shape[2] != 1:
                    # matplotlib interprets integer RGB(A) data as 0-255 values.
                    array = np.clip(array, 0, 255).astype(np.uint8)
            if array.shape[2] == 1:
                array = array[:, :, 0]
            return array

        if tensor.ndim == 2:
            array = tensor.numpy()
            if value_scale:
                array = array / float(value_scale)
            return array

        if tensor.ndim == 1:
            return tensor.numpy()

    if isinstance(image, np.ndarray):
        return image

    if isinstance(image, Image.Image):
        return np.array(image)

    # Best effort for anything else; callers show a placeholder for None.
    try:
        return np.array(image)
    except Exception:
        return None


def prepare_images_for_display(
    image: Any,
    mean: Optional[Sequence[float]] = None,
    std: Optional[Sequence[float]] = None,
    value_scale: float = 1.0,
) -> List[Optional[np.ndarray]]:
    """Flatten ``image`` (a single item, a list/tuple, or a batch) into displayable entries.

    Each element goes through ``convert_single_image``; when that returns a list
    (a batched tensor), its entries are spliced in individually. Entries may be
    ``None`` when conversion failed. At least one entry is always returned
    (``[None]`` when nothing was produced).
    """
    display_ready: List[Optional[np.ndarray]] = []
    for item in ensure_list(image):
        result = convert_single_image(item, mean=mean, std=std, value_scale=value_scale)
        display_ready += result if isinstance(result, list) else [result]
    return display_ready if display_ready else [None]


def summarize_target(target: Any) -> str:
    """Return a short, human-readable description of a dataset target.

    tv_tensors report their class name and shape, other tensors their shape and
    dtype, lists/tuples their length, and other objects exposing ``shape`` their
    type and shape; anything else falls back to ``str(target)``.
    """
    # Checked before the generic torch.Tensor branch so the tv_tensor class
    # name is reported.
    if TV_TENSOR_TYPES and isinstance(target, TV_TENSOR_TYPES):
        dims = tuple(torch.as_tensor(target).shape)
        return f"{target.__class__.__name__}{dims}"
    if isinstance(target, torch.Tensor):
        return f"Tensor{tuple(target.shape)} dtype={target.dtype}"
    if isinstance(target, (list, tuple)):
        return f"{type(target).__name__}(len={len(target)})"
    if not hasattr(target, "shape"):
        return str(target)
    return f"{type(target).__name__} shape={target.shape}"


def _draw_panel(axis, array: Optional[np.ndarray]) -> None:
    """Draw ``array`` on ``axis``, or a "No preview" placeholder if imshow cannot show it."""
    # imshow only accepts numeric (M, N), (M, N, 3) or (M, N, 4) data; anything
    # else (1-D arrays, object arrays, None) would raise, so show a placeholder.
    displayable = (
        isinstance(array, np.ndarray)
        and array.dtype.kind in "biuf"
        and (array.ndim == 2 or (array.ndim == 3 and array.shape[2] in (3, 4)))
    )
    if displayable:
        axis.imshow(array, cmap="gray" if array.ndim == 2 else None)
    else:
        axis.text(0.5, 0.5, "No preview", ha="center", va="center", transform=axis.transAxes)


def visualize_samples(
    raw_dataset,
    transformed_dataset,
    transform_spec: TransformSpec,
    indices: List[int],
    *,
    show_target_summary: bool = True,
):
    """Plot raw samples next to their transformed versions, one row per index.

    Column 0 shows the first raw image. The remaining columns show every image
    the transformed sample expands into (e.g. a list of views or a batched
    tensor), denormalized with the settings in ``transform_spec``. When
    ``show_target_summary`` is set, target summaries are appended to the panel
    titles; axis labels cannot be used because ``axis("off")`` hides them.

    Args:
        raw_dataset: Indexable dataset returning ``(image, target)`` without transforms.
        transformed_dataset: Same dataset with the transform applied.
        transform_spec: Supplies the denormalization settings for display.
        indices: Dataset indices to plot; must not be empty.
        show_target_summary: Whether to caption panels with ``summarize_target``.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    if not indices:
        raise ValueError("visualize_samples() needs at least one index to plot.")

    samples = []
    for idx in indices:
        raw_image, raw_target = raw_dataset[idx]
        transformed_image, transformed_target = transformed_dataset[idx]
        raw_prepared = prepare_images_for_display(raw_image)
        transformed_prepared = prepare_images_for_display(
            transformed_image,
            mean=transform_spec.denormalize_mean,
            std=transform_spec.denormalize_std,
            value_scale=transform_spec.value_scale,
        )
        samples.append((idx, raw_prepared, transformed_prepared, raw_target, transformed_target))

    max_transformed = max(1, max(len(images) for _, _, images, _, _ in samples))
    n_rows = len(samples)
    n_cols = 1 + max_transformed
    # squeeze=False always yields a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    for row_axes, (idx, raw_images, transformed_images, raw_target, transformed_target) in zip(axes, samples):
        raw_axis = row_axes[0]
        raw_axis.axis("off")
        _draw_panel(raw_axis, raw_images[0] if raw_images else None)
        raw_title = f"Raw (idx={idx})"
        if show_target_summary:
            raw_title += "\n" + summarize_target(raw_target)
        raw_axis.set_title(raw_title)

        for offset, axis in enumerate(row_axes[1:]):
            axis.axis("off")
            _draw_panel(axis, transformed_images[offset] if offset < len(transformed_images) else None)
            title = "Transformed" if offset == 0 else f"Transformed #{offset + 1}"
            if show_target_summary and offset == 0:
                title += "\n" + summarize_target(transformed_target)
            axis.set_title(title)

    fig.tight_layout()
    return fig


In [None]:
# Fail fast on an unusable sample count before touching the filesystem.
if NUM_SAMPLES < 1:
    raise ValueError(f"NUM_SAMPLES must be a positive integer, got {NUM_SAMPLES}.")

dataset_cls, dataset_kwargs = parse_dataset_string(DATASET_STR)
validate_dataset_locations(dataset_kwargs)
transform_callable, transform_spec = resolve_transform(
    TRANSFORM_NAME,
    TRANSFORM_KWARGS,
    custom_transform=CUSTOM_TRANSFORM,
    custom_denorm_overrides=CUSTOM_DENORMALIZE_OVERRIDES,
)

# Two instances of the same dataset: one untouched, one with the transform applied.
raw_dataset = instantiate_dataset(dataset_cls, dataset_kwargs, transforms=None)
transformed_dataset = instantiate_dataset(dataset_cls, dataset_kwargs, transforms=transform_callable)

print(f"Dataset class: {dataset_cls.__name__}")
for key, value in dataset_kwargs.items():
    print(f"  {key}: {value}")
# resolve_transform ignores TRANSFORM_NAME when a custom transform is supplied.
if CUSTOM_TRANSFORM is not None:
    print(f"Transform: {transform_spec.description}")
else:
    print(f"Transform preset: {TRANSFORM_NAME}")
if transform_spec.denormalize_mean is not None:
    print(f"  Denormalize mean: {transform_spec.denormalize_mean}")
if transform_spec.denormalize_std is not None:
    print(f"  Denormalize std: {transform_spec.denormalize_std}")
print(f"  Value scale: {transform_spec.value_scale}")

num_available = len(raw_dataset)
if num_available == 0:
    raise RuntimeError("The dataset is empty; nothing to visualize.")

num_to_show = min(NUM_SAMPLES, num_available)
if num_to_show < NUM_SAMPLES:
    print(f"Requested {NUM_SAMPLES} samples but only {num_available} are available. Showing {num_to_show}.")

# 1 <= num_to_show <= num_available, so sampling without replacement always succeeds.
indices = sorted(random.Random(RANDOM_SEED).sample(range(num_available), k=num_to_show))
print(f"Visualizing indices: {indices}")

fig = visualize_samples(
    raw_dataset,
    transformed_dataset,
    transform_spec,
    indices,
    show_target_summary=SHOW_TARGET_SUMMARY,
)
# Show the figure once; leaving a bare `fig` as the last expression makes the
# inline backend render it twice.
plt.show()
