<a href="https://colab.research.google.com/github/DotBion/introml-dl/blob/main/fldl25_dinov2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from PIL import Image
from tqdm.auto import tqdm
import timm
from datasets import load_dataset
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Paths for checkpoints
base_dir = "/content/drive/MyDrive/fldl_ssl"
ckpt_dir = os.path.join(base_dir, "checkpoints_dinov2_hf")
os.makedirs(ckpt_dir, exist_ok=True)

print("Checkpoint dir:", ckpt_dir)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
torch.backends.cudnn.benchmark = True

Mounted at /content/drive
Checkpoint dir: /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf
Device: cuda


In [None]:
from datasets import load_dataset

ds = load_dataset("tsbpp/fall2025_deeplearning", split="train")  # example
print(ds[0])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


cc3m_96px_part1.zip:   0%|          | 0.00/521M [00:00<?, ?B/s]

cc3m_96px_part2.zip:   0%|          | 0.00/521M [00:00<?, ?B/s]

cc3m_96px_part3.zip:   0%|          | 0.00/520M [00:00<?, ?B/s]

cc3m_96px_part4.zip:   0%|          | 0.00/521M [00:00<?, ?B/s]

cc3m_96px_part5.zip:   0%|          | 0.00/521M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=96x96 at 0x7CBD42475A60>}


In [None]:

# How many samples to use for training (for Colab)
limit_n = 50_000   # you can increase later (e.g. 100_000 or full)

#ds = load_dataset(hf_dataset_id, split=hf_split)
print("Original HF dataset size:", len(ds))

if limit_n is not None and limit_n < len(ds):
    ds = ds.select(range(limit_n))

print("Using HF dataset size:", len(ds))
print("Example keys:", ds[0].keys())

Original HF dataset size: 500000
Using HF dataset size: 50000
Example keys: dict_keys(['image'])


In [None]:
class MultiCropHFDataset(Dataset):
    def __init__(
        self,
        hf_ds,
        global_crops_scale=(0.4, 1.0),
        local_crops_scale=(0.05, 0.4),
        global_size=224,
        local_size=224,
        local_crops_number=4,  # fewer for Colab
    ):
        self.ds = hf_ds
        self.local_crops_number = local_crops_number

        flip_color = T.Compose(
            [
                T.RandomHorizontalFlip(p=0.5),
                T.RandomApply(
                    [T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
                ),
                T.RandomGrayscale(p=0.2),
            ]
        )

        normalize = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                ),
            ]
        )

        self.global_transfo1 = T.Compose(
            [
                T.RandomResizedCrop(
                    global_size,
                    scale=global_crops_scale,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                flip_color,
                T.GaussianBlur(kernel_size=21, sigma=(0.1, 2.0)),
                normalize,
            ]
        )

        self.global_transfo2 = T.Compose(
            [
                T.RandomResizedCrop(
                    global_size,
                    scale=global_crops_scale,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                flip_color,
                T.GaussianBlur(kernel_size=21, sigma=(0.1, 2.0)),
                T.RandomSolarize(threshold=0.5, p=0.2),
                normalize,
            ]
        )

        self.local_transfo = T.Compose(
            [
                T.RandomResizedCrop(
                    local_size,
                    scale=local_crops_scale,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                flip_color,
                T.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
                normalize,
            ]
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        img = self.ds[idx]["image"]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        img = img.convert("RGB")

        crops = []
        crops.append(self.global_transfo1(img))
        crops.append(self.global_transfo2(img))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(img))
        # list length = 2 + local_crops_number
        return crops


def collate_multi_crop(batch):
    """
    batch: list of [num_crops] of tensors [3,H,W]
    returns: list of [B,3,H,W] for each crop index
    """
    num_crops = len(batch[0])
    out = []
    for ci in range(num_crops):
        out.append(torch.stack([sample[ci] for sample in batch], dim=0))
    return out


local_crops_number = 4  # fewer for Colab

dataset = MultiCropHFDataset(
    ds,
    local_crops_number=local_crops_number,
)

batch_size = 128   # tune based on memory
num_workers = 4

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    collate_fn=collate_multi_crop,  # returns list of [B,3,H,W]
)

print("Dataset size:", len(dataset))
print("Batches per epoch:", len(loader))

# quick sanity check
crops_batch = next(iter(loader))
print("Number of crops:", len(crops_batch))
print("Shape of crop[0]:", crops_batch[0].shape)  # [B,3,H,W]


Dataset size: 50000
Batches per epoch: 390
Number of crops: 6
Shape of crop[0]: torch.Size([128, 3, 224, 224])


In [None]:
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1.0)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        x = self.last_layer(x)
        return x


def create_dinov2_vits14_student_teacher(out_dim=65536):
    # DINOv2-S architecture from timm, RANDOM init
    backbone = timm.create_model(
        "vit_small_patch14_dinov2",
        pretrained=False,   # IMPORTANT: random weights
        img_size=224,
        num_classes=0,      # feature extractor
    )
    embed_dim = backbone.num_features

    student_backbone = backbone
    teacher_backbone = deepcopy(backbone)

    student_head = DINOHead(embed_dim, out_dim=out_dim)
    teacher_head = DINOHead(embed_dim, out_dim=out_dim)

    student = nn.Sequential(student_backbone, student_head)
    teacher = nn.Sequential(teacher_backbone, teacher_head)

    return student, teacher, embed_dim


class DINOLoss(nn.Module):
    def __init__(
        self,
        out_dim,
        ncrops,
        warmup_teacher_temp,
        teacher_temp,
        warmup_teacher_temp_epochs,
        nepochs,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

        self.teacher_temp_schedule = torch.cat(
            (
                torch.linspace(
                    warmup_teacher_temp,
                    teacher_temp,
                    warmup_teacher_temp_epochs,
                ),
                torch.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
            )
        )
        self.ncrops = ncrops

    def forward(self, student_output, teacher_output, epoch):
        # student_output: list of [B,out_dim] for all crops
        # teacher_output: list of [B,out_dim] for 2 global
        student_out = torch.cat(student_output, dim=0)
        student_out = student_out / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        temp = self.teacher_temp_schedule[epoch]
        teacher_out = torch.cat(teacher_output, dim=0)
        teacher_out = F.softmax((teacher_out - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0.0
        n_loss_terms = 0

        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    continue
                loss = torch.sum(
                    -q * F.log_softmax(student_out[v], dim=-1), dim=-1
                )
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms

        # update center
        batch_center = torch.mean(teacher_out[0], dim=0, keepdim=True)
        self.center = (
            self.center * self.center_momentum
            + batch_center * (1 - self.center_momentum)
        ).detach()

        return total_loss


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=10):
    warmup_iters = warmup_epochs * niter_per_ep
    warmup_schedule = torch.linspace(0, base_value, warmup_iters)

    iters = torch.arange(niter_per_ep * (epochs - warmup_epochs), dtype=torch.float32)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1.0 + torch.cos(math.pi * iters / iters.max())
    )

    schedule = torch.cat((warmup_schedule, schedule))
    return schedule


In [None]:
epochs = 20          # start small
out_dim = 65536      # DINO default
ncrops = 2 + local_crops_number

student, teacher, embed_dim = create_dinov2_vits14_student_teacher(out_dim=out_dim)
student.to(device)
teacher.to(device)

for p in teacher.parameters():
    p.requires_grad = False

n_params = sum(p.numel() for p in student.parameters())
print("Student params:", n_params)

dino_loss = DINOLoss(
    out_dim=out_dim,
    ncrops=ncrops,
    warmup_teacher_temp=0.04,
    teacher_temp=0.07,
    warmup_teacher_temp_epochs=5,
    nepochs=epochs,
    student_temp=0.1,
    center_momentum=0.9,
).to(device)

params_groups = [
    {
        "params": (
            p
            for n, p in student.named_parameters()
            if p.requires_grad and "last_layer" not in n
        ),
        "weight_decay": 0.04,
    },
    {
        "params": (
            p
            for n, p in student.named_parameters()
            if p.requires_grad and "last_layer" in n
        ),
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(params_groups, lr=1e-3)

niter_per_ep = len(loader)
lr_schedule = cosine_scheduler(1e-3, 1e-6, epochs, niter_per_ep, warmup_epochs=5)
wd_schedule = cosine_scheduler(0.04, 0.4, epochs, niter_per_ep)
momentum_schedule = cosine_scheduler(0.996, 1.0, epochs, niter_per_ep)

# ----- resume support -----
start_epoch = 0
global_step = 0
last_ckpt_path = os.path.join(ckpt_dir, "dinov2_vits14_hf_last.pth")

if os.path.exists(last_ckpt_path):
    print("Found checkpoint:", last_ckpt_path)
    ckpt = torch.load(last_ckpt_path, map_location=device)
    student.load_state_dict(ckpt["student"])
    teacher.load_state_dict(ckpt["teacher"])
    optimizer.load_state_dict(ckpt["optimizer"])
    dino_loss.center = ckpt["center"]
    start_epoch = ckpt["epoch"] + 1
    global_step = ckpt.get("global_step", 0)
    print(f"Resuming from epoch {start_epoch}, global_step {global_step}")
else:
    print("No checkpoint found, starting from scratch.")


Student params: 39784576
Found checkpoint: /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_last.pth
Resuming from epoch 1, global_step 400


In [10]:
save_every = 200  # steps

for epoch in range(start_epoch, epochs):
    student.train()
    teacher.eval()
    running_loss = 0.0

    pbar = tqdm(enumerate(loader), total=len(loader), desc=f"Epoch {epoch+1}/{epochs}")
    for it, crops in pbar:
        # crops: list of [B,3,H,W] for each crop idx
        crops = [c.to(device, non_blocking=True) for c in crops]

        it_global = epoch * len(loader) + it

        # student: all crops
        student_out = [student(x).float() for x in crops]
        # teacher: only 2 global crops
        with torch.no_grad():
            teacher_out = [teacher(x).float() for x in crops[:2]]

        loss = dino_loss(student_out, teacher_out, epoch)

        # schedules
        lr = lr_schedule[it_global]
        wd = wd_schedule[it_global]
        for i, pg in enumerate(optimizer.param_groups):
            pg["lr"] = lr
            if i == 0:
                pg["weight_decay"] = wd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA teacher
        with torch.no_grad():
            m = momentum_schedule[it_global]
            for ps, pt in zip(student.parameters(), teacher.parameters()):
                pt.data = pt.data * m + ps.data * (1.0 - m)

        running_loss += loss.item()
        global_step += 1

        pbar.set_postfix({"loss": loss.item(), "lr": lr})

        # intermediate checkpoint
        if global_step % save_every == 0:
            ckpt = {
                "epoch": epoch,
                "global_step": global_step,
                "student": student.state_dict(),
                "teacher": teacher.state_dict(),
                "optimizer": optimizer.state_dict(),
                "center": dino_loss.center,
            }
            torch.save(ckpt, last_ckpt_path)
            print(f"\n[checkpoint] Saved intermediate → {last_ckpt_path}")

    avg_loss = running_loss / len(loader)
    print(f"Epoch {epoch+1}: avg loss = {avg_loss:.4f}")

    # end-of-epoch checkpoint
    epoch_ckpt_path = os.path.join(ckpt_dir, f"dinov2_vits14_hf_epoch{epoch+1:03d}.pth")
    ckpt = {
        "epoch": epoch,
        "global_step": global_step,
        "student": student.state_dict(),
        "teacher": teacher.state_dict(),
        "optimizer": optimizer.state_dict(),
        "center": dino_loss.center,
    }
    torch.save(ckpt, epoch_ckpt_path)
    torch.save(ckpt, last_ckpt_path)
    print(f"Saved epoch checkpoint → {epoch_ckpt_path} (and updated last)")


Epoch 2/20:   0%|          | 0/390 [00:00<?, ?it/s]


[checkpoint] Saved intermediate → /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_last.pth
Epoch 2: avg loss = 2.4307
Saved epoch checkpoint → /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_epoch002.pth (and updated last)


Epoch 3/20:   0%|          | 0/390 [00:00<?, ?it/s]


[checkpoint] Saved intermediate → /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_last.pth

[checkpoint] Saved intermediate → /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_last.pth
Epoch 3: avg loss = 2.3129
Saved epoch checkpoint → /content/drive/MyDrive/fldl_ssl/checkpoints_dinov2_hf/dinov2_vits14_hf_epoch003.pth (and updated last)


Epoch 4/20:   0%|          | 0/390 [00:00<?, ?it/s]

RuntimeError: DataLoader worker (pid(s) 156611, 156612, 156613, 156614) exited unexpectedly