# In [None]:
import torch
import torch.nn as nn

class DINOv2Extractor(nn.Module):
    """
    DINOv2 feature extractor that returns the CLS-token embedding.

    The backbone is loaded through ``torch.hub`` from the
    ``facebookresearch/dinov2`` repository and is kept in eval mode for the
    whole lifetime of the wrapper (see :meth:`train`).

    Output: CLS token feature (384-D for ViT-S/14).

    Args:
        variant: Entry-point name handed to ``torch.hub.load``.
        local_ckpt_path: Not supported. Any non-empty value raises
            ``ValueError``; leave it empty to load via ``torch.hub``.

    Raises:
        ValueError: If ``local_ckpt_path`` is non-empty.
    """

    def __init__(self, variant: str = "dinov2_vits14", local_ckpt_path: str = ""):
        super().__init__()
        self.variant = variant

        # Guard clause: local checkpoints are intentionally unsupported.
        if local_ckpt_path:
            raise ValueError("Local checkpoint loading depends on how you saved the model. Use torch.hub for simplicity.")

        self.model = torch.hub.load("facebookresearch/dinov2", variant)
        self.model.eval()

    def train(self, mode: bool = True):
        """
        Set the wrapper's training flag while keeping the backbone in eval mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a ``.train()`` call on this module (or on any parent that
        contains it) would switch the backbone back to training mode even
        though ``forward`` only ever runs it without gradients.

        Args:
            mode: Training flag applied to this wrapper itself.

        Returns:
            This module, matching the ``nn.Module.train`` convention.
        """
        super().train(mode)
        # Re-pin the backbone after the recursive mode switch above.
        self.model.eval()
        return self

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the normalized CLS-token feature for a batch of images.

        Runs under ``torch.no_grad()``, so the result carries no autograd
        history.

        Args:
            x: Image batch, expected as [B,3,224,224].

        Returns:
            The ``"x_norm_clstoken"`` entry of the backbone's
            ``forward_features`` output, shaped [B, dim].
        """
        feats = self.model.forward_features(x)
        return feats["x_norm_clstoken"]  # [B, dim]
