This example requires the following dependencies to be installed:
pip install lightly

In [2]:
# !pip install lightly

Note: The model and training settings do not follow the reference settings
from the paper. The settings are chosen such that the example can easily be
run on a small dataset with a single GPU.

In [3]:
import copy

In [4]:
import torch
import torchvision
from torch import nn

In [5]:
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [20]:
import glob
import zipfile
from pathlib import Path
from PIL import Image

from torch.utils.data import Dataset
from huggingface_hub import snapshot_download


class RawImageDataset(Dataset):
    """Dataset that loads images directly from raw files."""

    def __init__(self, root_dir, transform=None, image_extensions=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        if image_extensions is None:
            image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPEG', '*.JPG', '*.PNG']

        # Find all image files
        self.image_paths = []
        print(f"Searching for images in: {self.root_dir}")

        for pattern in image_extensions:
            found = glob.glob(str(self.root_dir / '**' / pattern), recursive=True)
            self.image_paths.extend(found)
            if found:
                print(f"  Found {len(found)} {pattern} files")

        self.image_paths.sort()
        print(f"Total images found: {len(self.image_paths)}")

        if len(self.image_paths) == 0:
            print("\nWarning: No images found. Directory structure (first 20 items):")
            for item in sorted(self.root_dir.rglob('*'))[:20]:
                print(f"  {item}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # import pdb
        # pdb.set_trace()
        img_path = self.image_paths[idx]

        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            img = Image.new('RGB', (96, 96), color='black')

        if self.transform:
            img = self.transform(img)
            # import pdb; pdb.set_trace()
            # print(img.shape,"old image shape")
            # img = img[0]
            # print(img.shape,"new image shape")

        return img


def download_and_extract_dataset(repo_id, cache_dir=None, max_workers=4):
    """Download and extract dataset from HuggingFace."""

    print(f"Downloading dataset from {repo_id}...")

    try:
        local_dir = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            cache_dir=cache_dir,
            max_workers=max_workers,
            resume_download=True,
        )
        print(f"Dataset downloaded to: {local_dir}")
    except Exception as e:
        print(f"Error during download: {e}")
        print("Retrying with single worker...")
        local_dir = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            cache_dir=cache_dir,
            max_workers=1,
            resume_download=True,
        )
        print(f"Dataset downloaded to: {local_dir}")

    # Extract zip files if present
    local_path = Path(local_dir)
    zip_files = list(local_path.glob('*.zip'))

    if zip_files:
        print(f"\nFound {len(zip_files)} zip files. Extracting...")
        extract_dir = local_path / 'extracted'
        extract_dir.mkdir(exist_ok=True)

        # for zip_file in zip_files:
        #     print(f"  Extracting {zip_file.name}...")
        #     try:
        #         with zipfile.ZipFile(zip_file, 'r') as zf:
        #             zf.extractall(extract_dir)
        #         print("    ✓ Extracted successfully")
        #     except Exception as e:
        #         print(f"    ✗ Error: {e}")

        return extract_dir
    else:
        print("No zip files found, using directory as-is")
        return local_path



In [21]:
# Download and extract dataset
data_dir = download_and_extract_dataset(
    repo_id="tsbpp/fall2025_deeplearning",
    cache_dir=None,
    max_workers=4
)

# Create transform
transform = DINOTransform()
# transform = get_mae_transform()

# Create dataset
dataset = RawImageDataset(data_dir, transform=transform)
print(f"\nDataset ready with {len(dataset)} images")

Downloading dataset from tsbpp/fall2025_deeplearning...




Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

Dataset downloaded to: /root/.cache/huggingface/hub/datasets--tsbpp--fall2025_deeplearning/snapshots/7b14dd4385d982457822e8e96c5081a30da146d8

Found 5 zip files. Extracting...
Searching for images in: /root/.cache/huggingface/hub/datasets--tsbpp--fall2025_deeplearning/snapshots/7b14dd4385d982457822e8e96c5081a30da146d8/extracted
  Found 500000 *.jpg files
Total images found: 500000

Dataset ready with 500000 images


In [22]:
class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

In [23]:
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
input_dim = 512
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# input_dim = backbone.embed_dim

In [24]:
model = DINO(backbone, input_dim)

  WeightNorm.apply(module, name, dim)


In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

DINO(
  (student_backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [48]:
def count_params(module):
    return sum(p.numel() for p in module.parameters())

total_params = count_params(model)
student_backbone_params = count_params(model.student_backbone)
student_head_params = count_params(model.student_head)
teacher_backbone_params = count_params(model.teacher_backbone)
teacher_head_params = count_params(model.teacher_head)

student_total = student_backbone_params + student_head_params
teacher_total = teacher_backbone_params + teacher_head_params

print(f"Total params (student + teacher + heads): {total_params:,}")
print(f"  Student backbone: {student_backbone_params:,}")
print(f"  Student head:     {student_head_params:,}")
print(f"  Student TOTAL:    {student_total:,}")
print(f"  Teacher backbone: {teacher_backbone_params:,}")
print(f"  Teacher head:     {teacher_head_params:,}")
print(f"  Teacher TOTAL:    {teacher_total:,}")


Total params (student + teacher + heads): 23,735,552
  Student backbone: 11,176,512
  Student head:     691,264
  Student TOTAL:    11,867,776
  Teacher backbone: 11,176,512
  Teacher head:     691,264
  Teacher TOTAL:    11,867,776


In [26]:
transform = DINOTransform()

In [27]:
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

In [28]:
# dataset = torchvision.datasets.VOCDetection(
#     "datasets/pascal_voc",
#     download=True,
#     transform=transform,
#     target_transform=target_transform,
# )

# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

In [42]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

In [43]:
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)

In [44]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [45]:
epochs = 100

In [46]:
import torch
from pathlib import Path
import wandb

# ---------- Drive setup ----------
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_ROOT = Path("/content/drive/MyDrive")
    IS_COLAB = True
    print("✓ Running on Colab, Drive mounted.")
except Exception:
    DRIVE_ROOT = Path("./saved_models")
    IS_COLAB = False
    print("⚠️ Not on Colab, using local folder ./saved_models")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Running on Colab, Drive mounted.


In [None]:
# ---------- wandb init ----------
# ---------- Project / save dir ----------
PROJECT_NAME = "dino-v1"  # wandb project AND folder name
save_dir = DRIVE_ROOT / PROJECT_NAME
save_dir.mkdir(parents=True, exist_ok=True)

wandb.init(
    entity="amogh-gulati-new-york-university",
    project=PROJECT_NAME,
    name="dino-run-1",      # change run name if you like
)

print("Starting Training")
import time
from tqdm import tqdm
global_step = 0
step_start = time.time()
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)

    for views in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):   # views is that list you just inspected
        # EMA update for teacher
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)

        # move all crops to GPU
        views = [v.to(device) for v in views]

        # first two are global crops for the teacher
        global_views = views[:2]

        # teacher only on global crops
        teacher_out = [model.forward_teacher(v) for v in global_views]

        # student on all crops (global + local)
        student_out = [model.forward(v) for v in views]

        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()

        loss.backward()
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

        # ---- wandb STEP LOGGING ----
        wandb.log(
            {
                "loss/step": loss.item(),
                "time/step_sec": time.time() - step_start,
                "step": global_step,
                "epoch": epoch,
            },
            step=global_step,
        )

        global_step += 1
        step_start = time.time()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

    # ---- wandb logging ----
    wandb.log({
        "loss/train": avg_loss,
        "epoch": epoch,
    })

    # ---- Save checkpoint to Drive (always same filename) ----
    ckpt_path = save_dir / f"{PROJECT_NAME}_latest.pt"
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "avg_loss": avg_loss,
        },
        ckpt_path,
    )
    print(f"✓ Saved checkpoint: {ckpt_path}")

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/step,▅████▇▆▆▆▆▆▆▆▄▄▃▃▄▄▄▃▃▃▃▂▁▃▄▂▂▂▃▃▃▂▃▃▂▃▁
step,▁▁▁▁▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇███
time/step_sec,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,0.0
loss/step,3.72751
step,465.0
time/step_sec,0.78308


Starting Training


Epoch 1/100: 100%|██████████| 3906/3906 [50:50<00:00,  1.28it/s]


epoch: 00, loss: 2.86400
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 2/100: 100%|██████████| 3906/3906 [50:47<00:00,  1.28it/s]


epoch: 01, loss: 3.80231
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 3/100: 100%|██████████| 3906/3906 [50:45<00:00,  1.28it/s]


epoch: 02, loss: 2.94557
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 4/100: 100%|██████████| 3906/3906 [50:43<00:00,  1.28it/s]


epoch: 03, loss: 2.58663
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 5/100: 100%|██████████| 3906/3906 [50:43<00:00,  1.28it/s]


epoch: 04, loss: 2.44541
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 6/100: 100%|██████████| 3906/3906 [50:42<00:00,  1.28it/s]


epoch: 05, loss: 2.35560
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 7/100: 100%|██████████| 3906/3906 [50:42<00:00,  1.28it/s]


epoch: 06, loss: 2.29149
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 8/100: 100%|██████████| 3906/3906 [50:40<00:00,  1.28it/s]


epoch: 07, loss: 2.24075
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 9/100: 100%|██████████| 3906/3906 [50:39<00:00,  1.29it/s]


epoch: 08, loss: 2.19772
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 10/100: 100%|██████████| 3906/3906 [50:41<00:00,  1.28it/s]


epoch: 09, loss: 2.15910
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 11/100: 100%|██████████| 3906/3906 [50:41<00:00,  1.28it/s]


epoch: 10, loss: 2.12437
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 12/100: 100%|██████████| 3906/3906 [50:41<00:00,  1.28it/s]


epoch: 11, loss: 2.09420
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 13/100: 100%|██████████| 3906/3906 [50:44<00:00,  1.28it/s]


epoch: 12, loss: 2.06601
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 14/100: 100%|██████████| 3906/3906 [50:41<00:00,  1.28it/s]


epoch: 13, loss: 2.04341
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 15/100: 100%|██████████| 3906/3906 [50:39<00:00,  1.29it/s]


epoch: 14, loss: 2.02369
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 16/100: 100%|██████████| 3906/3906 [50:35<00:00,  1.29it/s]


epoch: 15, loss: 2.00340
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 17/100: 100%|██████████| 3906/3906 [50:35<00:00,  1.29it/s]


epoch: 16, loss: 1.98550
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 18/100: 100%|██████████| 3906/3906 [50:32<00:00,  1.29it/s]


epoch: 17, loss: 1.96837
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 19/100: 100%|██████████| 3906/3906 [50:34<00:00,  1.29it/s]


epoch: 18, loss: 1.95378
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 20/100: 100%|██████████| 3906/3906 [50:36<00:00,  1.29it/s]


epoch: 19, loss: 1.93910
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 21/100: 100%|██████████| 3906/3906 [50:38<00:00,  1.29it/s]


epoch: 20, loss: 1.92521
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 22/100: 100%|██████████| 3906/3906 [50:37<00:00,  1.29it/s]


epoch: 21, loss: 1.91157
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 23/100: 100%|██████████| 3906/3906 [50:40<00:00,  1.28it/s]


epoch: 22, loss: 1.90027
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 24/100: 100%|██████████| 3906/3906 [50:37<00:00,  1.29it/s]


epoch: 23, loss: 1.88786
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 25/100: 100%|██████████| 3906/3906 [50:39<00:00,  1.29it/s]


epoch: 24, loss: 1.87855
✓ Saved checkpoint: /content/drive/MyDrive/dino-v1/dino-v1_latest.pt


Epoch 26/100:  92%|█████████▏| 3597/3906 [46:42<04:01,  1.28it/s]

In [36]:
views = next(iter(dataloader))
print(type(views), len(views))
for i, v in enumerate(views):
    print(i, v.shape)

<class 'list'> 8
0 torch.Size([64, 3, 224, 224])
1 torch.Size([64, 3, 224, 224])
2 torch.Size([64, 3, 96, 96])
3 torch.Size([64, 3, 96, 96])
4 torch.Size([64, 3, 96, 96])
5 torch.Size([64, 3, 96, 96])
6 torch.Size([64, 3, 96, 96])
7 torch.Size([64, 3, 96, 96])
