[Reference](https://medium.com/@pankaj_pandey/16f5d285c38d)

In [1]:
"""
train_seg_head.py — Minimal training loop on frozen DINOv3 features.
Pass datasets yielding (image, mask) pairs (e.g. ADE20K/Cityscapes) to `main(train_set, val_set)`.
"""
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoImageProcessor

# Hugging Face Hub checkpoint loaded (frozen) as the feature backbone.
# "vitb16" presumably denotes ViT-B with 16x16 patches — TODO confirm it
# matches the segmentation head's upsampling factor.
MODEL_ID = "facebook/dinov3-vitb16-pretrain-lvd1689m"
# Number of classes the segmentation head predicts. 150 presumably targets
# ADE20K (named in the module docstring); change it for other datasets.
N_CLASSES = 150
class LinearSegHead(nn.Module):
    """Per-patch linear classifier followed by bilinear upsampling.

    A 1x1 convolution maps every ``in_ch``-dimensional patch feature to
    ``n_classes`` logits; the resulting logit map is then upsampled back to
    pixel resolution.
    """

    def __init__(self, in_ch, n_classes, patch_size=16):
        """
        Args:
            in_ch: channel count of the incoming feature map.
            n_classes: number of output logit channels (classes).
            patch_size: upsampling factor used by ``forward`` when no explicit
                output ``size`` is given. The default of 16 matches a ViT with
                16x16 patches (e.g. a ``vitb16`` checkpoint).
        """
        super().__init__()
        self.proj = nn.Conv2d(in_ch, n_classes, 1)
        self.up = nn.Upsample(scale_factor=patch_size, mode="bilinear", align_corners=False)

    def forward(self, fmap, size=None):
        """Return class logits for ``fmap`` of shape (B, in_ch, h, w).

        Args:
            fmap: patch feature map, (B, in_ch, h, w).
            size: optional (H, W) output size. When omitted the logits are
                upsampled by ``patch_size`` to (B, n_classes, h*p, w*p); when
                given they are resized to exactly (B, n_classes, H, W), which
                is useful when labels are not at the backbone's input size.
        """
        logits = self.proj(fmap)
        if size is None:
            return self.up(logits)
        return nn.functional.interpolate(
            logits, size=tuple(size), mode="bilinear", align_corners=False
        )
def extract_fmap(model, proc, image):
    """Run the frozen backbone and return its patch tokens as a 2-D feature map.

    Args:
        model: backbone whose output exposes ``last_hidden_state`` and whose
            config has ``num_register_tokens``. The token sequence is assumed
            to be laid out as [CLS, registers..., patch tokens...].
        proc: image processor called as ``proc(images=..., return_tensors="pt")``.
        image: a single image or a batch, in whatever form ``proc`` accepts.

    Returns:
        Tensor of shape (B, C, H_p, W_p), where H_p x W_p is the patch grid.

    Raises:
        ValueError: if the number of patch tokens does not fit the inferred grid.
    """
    inputs = proc(images=image, return_tensors="pt").to(model.device)
    # no_grad rather than inference_mode: inference tensors cannot be saved
    # for backward, so passing them into a trainable head would fail once the
    # head's loss is back-propagated. no_grad still skips the backbone graph.
    with torch.no_grad():
        out = model(**inputs)
    n_prefix = 1 + model.config.num_register_tokens  # CLS + register tokens
    grid = out.last_hidden_state[:, n_prefix:, :]
    B, N, C = grid.shape
    # Derive the (possibly non-square) patch grid from the actual input size
    # when the config exposes the patch size; otherwise assume a square grid.
    patch = getattr(model.config, "patch_size", None)
    pixel_values = inputs.get("pixel_values") if hasattr(inputs, "get") else None
    if patch and pixel_values is not None:
        H, W = pixel_values.shape[-2] // patch, pixel_values.shape[-1] // patch
    else:
        H = W = int(N ** 0.5)
    if H * W != N:
        raise ValueError(f"cannot map {N} patch tokens onto a {H}x{W} grid")
    return grid.reshape(B, H, W, C).permute(0, 3, 1, 2)
def _align_logits(logits, mask):
    """Bilinearly resize logits (B, K, h, w) to mask's spatial size (B, H, W) if they differ."""
    if logits.shape[-2:] != mask.shape[-2:]:
        logits = nn.functional.interpolate(
            logits, size=tuple(mask.shape[-2:]), mode="bilinear", align_corners=False
        )
    return logits


def _mean_iou(head, model, proc, loader, n_classes, ignore_index):
    """Return the mean IoU of ``head`` over ``loader``.

    IoU is averaged over classes whose union is non-empty; pixels labelled
    ``ignore_index`` or outside [0, n_classes) are skipped. Returns NaN when
    no pixel was scored (e.g. an empty loader).
    """
    head.eval()
    conf = None
    with torch.no_grad():
        for img, mask in loader:
            fmap = extract_fmap(model, proc, img)
            target = mask.to(fmap.device).long()
            pred = _align_logits(head(fmap), target).argmax(dim=1)
            valid = (target != ignore_index) & (target >= 0) & (target < n_classes)
            # Flattened (true, pred) pairs -> confusion-matrix counts.
            idx = target[valid] * n_classes + pred[valid]
            counts = torch.bincount(idx, minlength=n_classes * n_classes)
            conf = counts if conf is None else conf + counts
    if conf is None:
        return float("nan")
    conf = conf.reshape(n_classes, n_classes).double()
    inter = conf.diag()
    union = conf.sum(0) + conf.sum(1) - inter
    present = union > 0
    if not present.any():
        return float("nan")
    return (inter[present] / union[present]).mean().item()


def main(train_set, val_set, epochs=10, lr=1e-3, *, batch_size=4,
         ignore_index=-100, out_path="dinov3_seg_head.pth"):
    """Train a linear segmentation head on frozen DINOv3 features.

    Args:
        train_set: dataset indexable as ``train_set[i] -> (image, mask)``.
        val_set: dataset yielding (image, mask) pairs, scored with mIoU
            after every epoch.
        epochs: number of passes over ``train_set``.
        lr: AdamW learning rate for the head.
        batch_size: DataLoader batch size for both splits.
        ignore_index: mask label excluded from the loss and from mIoU
            (the default -100 is CrossEntropyLoss's own default).
        out_path: where the head's ``state_dict`` is saved after training.

    Returns:
        The trained ``LinearSegHead``.

    Side effects: loads ``MODEL_ID`` via ``from_pretrained``, prints one
    line of metrics per epoch and writes ``out_path``.
    """
    proc = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID, device_map="auto")
    model.eval()
    model.requires_grad_(False)  # the backbone stays frozen; only the head trains
    # Infer the head's input channel count from one sample.
    # NOTE(review): assumes the dataset yields images the processor can
    # normalise itself (i.e. not already normalised) — confirm for your data.
    x0, _ = train_set[0]
    in_ch = extract_fmap(model, proc, x0).shape[1]
    head = LinearSegHead(in_ch, N_CLASSES).to(model.device)
    opt = optim.AdamW(head.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    for ep in range(epochs):
        head.train()
        running, steps = 0.0, 0
        for img, mask in train_loader:
            fmap = extract_fmap(model, proc, img)
            target = mask.to(fmap.device).long()
            # Logits come out at the processor's resolution; masks are at the
            # dataset's resolution, so match them before computing the loss.
            logits = _align_logits(head(fmap), target)
            batch_loss = criterion(logits, target)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
            steps += 1
        miou = _mean_iou(head, model, proc, val_loader, N_CLASSES, ignore_index)
        print(f"epoch {ep + 1}/{epochs}  train_loss={running / max(steps, 1):.4f}  val_mIoU={miou:.4f}")
    torch.save(head.state_dict(), out_path)
    return head