This example requires the following dependencies to be installed:
pip install lightly

In [1]:
!pip install lightly

Collecting lightly
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.6.1-py3-none-any.whl.metadata (21 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading lightly-1.5.22-py3-none-any.whl (859 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m859.3/859.3 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m

Note: The model and training settings do not follow the reference settings
from the paper. The settings are chosen such that the example can easily be
run on a small dataset with a single GPU.

In [2]:
import copy

In [3]:
import torch
import torchvision
from torch import nn

In [4]:
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [5]:
class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

In [6]:
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
input_dim = 512
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# input_dim = backbone.embed_dim

In [7]:
model = DINO(backbone, input_dim)

  WeightNorm.apply(module, name, dim)


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

DINO(
  (student_backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [9]:
transform = DINOTransform()

In [10]:
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

In [14]:
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

--2026-02-03 19:06:13--  http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
Resolving host.robots.ox.ac.uk (host.robots.ox.ac.uk)... 129.67.94.50
Connecting to host.robots.ox.ac.uk (host.robots.ox.ac.uk)|129.67.94.50|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4516 (4.4K) [text/html]
Saving to: ‘VOCtrainval_11-May-2012.tar’


2026-02-03 19:06:14 (447 MB/s) - ‘VOCtrainval_11-May-2012.tar’ saved [4516/4516]



In [18]:
import shutil
shutil.rmtree("datasets/pascal_voc", ignore_errors=True)

dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)

# dataset = torchvision.datasets.VOCDetection(
#     "datasets/pascal_voc",
#     download=True,
#     transform=transform,
#     target_transform=target_transform,
# )
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

100%|██████████| 170M/170M [00:06<00:00, 25.1MB/s]


In [19]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)



In [20]:
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [22]:
epochs = 10

In [None]:
print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        views = batch[0]
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward_teacher(view) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training


In [None]:
model.eval()

embeddings = []
labels = []

with torch.no_grad():
    for x, y in viz_loader:
        x = x.to(device)

        # IMPORTANT: use STUDENT BACKBONE
        f = model.student_backbone(x)   # [B, 512, 1, 1]
        f = f.flatten(1)                # [B, 512]

        embeddings.append(f.cpu())
        labels.append(y)

embeddings = torch.cat(embeddings, dim=0).numpy()  # [10000, 512]
labels = torch.cat(labels, dim=0).numpy()


In [None]:
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
emb_2d = pca.fit_transform(embeddings)


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
scatter = plt.scatter(
    emb_2d[:, 0],
    emb_2d[:, 1],
    c=labels,
    cmap="tab10",
    s=5,
    alpha=0.7
)

plt.legend(*scatter.legend_elements(), title="Classes")
plt.title("DINO (Student Backbone) — PCA")
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.show()
