This example requires the following dependencies to be installed:
pip install lightly torch torchvision pillow huggingface_hub wandb tqdm

In [None]:
# !pip install lightly
# NOTE(review): `!export` runs in a throwaway subshell, so this line does not
# change CUDA_VISIBLE_DEVICES for the notebook kernel. The GPU is effectively
# selected by the explicit "cuda:3" device string used further down.
!export CUDA_VISIBLE_DEVICES=3

Note: The model and training settings do not follow the reference settings
from the paper. The settings are chosen such that the example can easily be
run on a small dataset with a single GPU.

In [None]:
import copy

In [None]:
import torch
import torchvision
from torch import nn

In [None]:
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [None]:
import glob
import zipfile
from pathlib import Path
from PIL import Image

from torch.utils.data import Dataset
from huggingface_hub import snapshot_download


class RawImageDataset(Dataset):
    """Dataset that loads images directly from raw files.

    Image files are discovered recursively under ``root_dir`` when the dataset
    is constructed and are opened lazily, one per ``__getitem__`` call.

    Args:
        root_dir: Directory that is searched recursively for image files.
        transform: Optional callable applied to every loaded RGB image; its
            return value is what ``__getitem__`` returns.
        image_extensions: Optional list of glob patterns (e.g. ``['*.png']``).
            Defaults to common JPEG/PNG patterns in lower and upper case.
    """

    def __init__(self, root_dir, transform=None, image_extensions=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        if image_extensions is None:
            image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPEG', '*.JPG', '*.PNG']

        print(f"Searching for images in: {self.root_dir}")

        # Collect into a set: where glob matching is case-insensitive (e.g. on
        # Windows) '*.jpg' and '*.JPG' match the same files, which would
        # otherwise put every image into the dataset twice.
        unique_paths = set()
        for pattern in image_extensions:
            found = glob.glob(str(self.root_dir / '**' / pattern), recursive=True)
            unique_paths.update(found)
            if found:
                print(f"  Found {len(found)} {pattern} files")

        # Sorted so that an index always refers to the same file across runs.
        self.image_paths = sorted(unique_paths)
        print(f"Total images found: {len(self.image_paths)}")

        if len(self.image_paths) == 0:
            # Show what is actually on disk to help debug a wrong root_dir.
            print("\nWarning: No images found. Directory structure (first 20 items):")
            for item in sorted(self.root_dir.rglob('*'))[:20]:
                print(f"  {item}")

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given.

        An unreadable file is reported on stdout and replaced by a black
        96x96 RGB placeholder so that one corrupt image does not abort a run.
        """
        img_path = self.image_paths[idx]

        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been converted into an independent RGB image.
            with Image.open(img_path) as handle:
                img = handle.convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            img = Image.new('RGB', (96, 96), color='black')

        if self.transform:
            img = self.transform(img)

        return img


def download_and_extract_dataset(repo_id, cache_dir=None, max_workers=4):
    """Download a dataset snapshot from HuggingFace and unpack its zip archives.

    Args:
        repo_id: Dataset repository id passed to ``snapshot_download``.
        cache_dir: Optional cache directory passed to ``snapshot_download``.
        max_workers: Number of download workers for the first attempt. If that
            attempt raises, the download is retried once with a single worker.

    Returns:
        ``<snapshot>/extracted`` when the snapshot's top level contains
        ``*.zip`` files (their contents are extracted there), otherwise the
        snapshot directory itself.
    """

    def _download(workers):
        # Single place for the download arguments so that the first attempt
        # and the retry cannot drift apart.
        # NOTE(review): `resume_download` may be deprecated in recent
        # huggingface_hub releases - confirm against the installed version.
        return snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            cache_dir=cache_dir,
            max_workers=workers,
            resume_download=True,
        )

    print(f"Downloading dataset from {repo_id}...")

    try:
        local_dir = _download(max_workers)
    except Exception as e:
        print(f"Error during download: {e}")
        print("Retrying with single worker...")
        local_dir = _download(1)
    print(f"Dataset downloaded to: {local_dir}")

    # Extract zip files if present
    local_path = Path(local_dir)
    zip_files = sorted(local_path.glob('*.zip'))

    if not zip_files:
        print("No zip files found, using directory as-is")
        return local_path

    print(f"\nFound {len(zip_files)} zip files. Extracting...")
    extract_dir = local_path / 'extracted'
    extract_dir.mkdir(exist_ok=True)

    for zip_file in zip_files:
        # A marker file records a completed extraction so that re-running the
        # cell does not unpack (and overwrite) the same archive again.
        marker = extract_dir / f".{zip_file.name}.done"
        if marker.exists():
            print(f"  Skipping {zip_file.name} (already extracted)")
            continue

        print(f"  Extracting {zip_file.name}...")
        try:
            with zipfile.ZipFile(zip_file, 'r') as zf:
                zf.extractall(extract_dir)
        except (zipfile.BadZipFile, OSError) as e:
            # Best effort: report the broken archive and continue with the rest.
            print(f"    ✗ Error: {e}")
        else:
            marker.touch()
            print("    ✓ Extracted successfully")

    return extract_dir



In [None]:
# Download and extract dataset
# (disabled: the images are read from a local folder instead; re-enable this
# call to fetch the data from the HuggingFace Hub)
# data_dir = download_and_extract_dataset(
#     repo_id="tsbpp/fall2025_deeplearning",
#     cache_dir=None,
#     max_workers=4
# )
data_dir = Path('./data/devel')
# Create transform
# The training loop treats each sample as a list of crops whose first two
# entries are the global views, so the transform is expected to return such a
# list per image.
transform = DINOTransform()
# transform = get_mae_transform()

# Create dataset
# Scans data_dir recursively for image files; the samples carry no labels.
dataset = RawImageDataset(data_dir, transform=transform)
print(f"\nDataset ready with {len(dataset)} images")

In [None]:
class DINO(torch.nn.Module):
    """Student/teacher network pair used for DINO training.

    The student consists of ``backbone`` plus a projection head. The teacher
    is a deep copy of the backbone with its own projection head, and all
    teacher parameters have gradients deactivated.

    Args:
        backbone: Feature extractor; its output is flattened from dim 1 on
            before it is fed to the projection head.
        input_dim: Size of the flattened backbone output.
    """

    def __init__(self, backbone, input_dim):
        super().__init__()
        # Positional DINOProjectionHead sizes shared by student and teacher.
        head_sizes = (input_dim, 512, 64, 2048)

        # Student branch.
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(*head_sizes, freeze_last_layer=30)

        # Teacher branch: independent copy of the backbone plus a fresh head.
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(*head_sizes)
        for teacher_part in (self.teacher_backbone, self.teacher_head):
            deactivate_requires_grad(teacher_part)

    @staticmethod
    def _embed(backbone, head, x):
        """Run ``x`` through ``backbone``, flatten the features, apply ``head``."""
        features = backbone(x)
        return head(features.flatten(start_dim=1))

    def forward(self, x):
        """Return the student projection of ``x``."""
        return self._embed(self.student_backbone, self.student_head, x)

    def forward_teacher(self, x):
        """Return the teacher projection of ``x``."""
        return self._embed(self.teacher_backbone, self.teacher_head, x)

In [None]:
# ResNet-34 with its last child module (the fully connected classifier)
# removed, so the backbone ends at the pooled feature map.
resnet = torchvision.models.resnet34()
trunk_layers = list(resnet.children())[:-1]
backbone = nn.Sequential(*trunk_layers)
# Feature width = input width of the removed classifier (512 for ResNet-34).
input_dim = resnet.fc.in_features
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# input_dim = backbone.embed_dim

In [None]:
# Build the student/teacher pair around the backbone; input_dim must equal the
# size of the flattened backbone output.
model = DINO(backbone, input_dim)

In [None]:
# Preferred GPU index. Fall back to the first visible GPU when the machine
# exposes fewer devices (otherwise "cuda:3" fails with an invalid device
# ordinal), and to the CPU when CUDA is not available at all.
preferred_gpu = 3
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > preferred_gpu:
    device = f"cuda:{preferred_gpu}"
else:
    device = "cuda:0"
model.to(device)

In [None]:
print("Using device:", device)
# `device` carries an index suffix (e.g. "cuda:3"), so it never equals the
# bare string "cuda"; match on the prefix and report the device that is
# actually selected rather than device 0.
if device.startswith("cuda"):
    print("Current device index:", torch.device(device).index)
    print("Device name:", torch.cuda.get_device_name(device))

In [None]:
def count_params(module):
    """Return the total number of elements over all parameters of ``module``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total

# Parameter counts for the whole model and for each of its four sub-modules.
total_params = count_params(model)
student_backbone_params, student_head_params = (
    count_params(part) for part in (model.student_backbone, model.student_head)
)
teacher_backbone_params, teacher_head_params = (
    count_params(part) for part in (model.teacher_backbone, model.teacher_head)
)

student_total = student_backbone_params + student_head_params
teacher_total = teacher_backbone_params + teacher_head_params

# (label, value) rows, printed in order with thousands separators.
param_report = [
    ("Total params (student + teacher + heads): ", total_params),
    ("  Student backbone: ", student_backbone_params),
    ("  Student head:     ", student_head_params),
    ("  Student TOTAL:    ", student_total),
    ("  Teacher backbone: ", teacher_backbone_params),
    ("  Teacher head:     ", teacher_head_params),
    ("  Teacher TOTAL:    ", teacher_total),
]
for label, value in param_report:
    print(f"{label}{value:,}")


In [None]:
# NOTE(review): this rebinds `transform` to a second DINOTransform instance.
# `dataset` was already constructed with the instance created in the dataset
# cell, so this assignment does not change what the dataset applies.
transform = DINOTransform()

In [None]:
# we ignore object detection annotations by setting target_transform to return 0
# (only referenced by the commented-out VOCDetection dataset below)
def target_transform(t):
    """Discard the annotation ``t`` and return the constant label 0."""
    return 0

In [None]:
# dataset = torchvision.datasets.VOCDetection(
#     "datasets/pascal_voc",
#     download=True,
#     transform=transform,
#     target_transform=target_transform,
# )

# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

In [None]:
# Batches of augmented samples from `dataset`. The training loop iterates over
# each batch as a list of crop tensors and moves every entry to the device.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    # The final incomplete batch is dropped, so a dataset smaller than one
    # batch yields no batches at all.
    drop_last=True,
    num_workers=8,
)

In [None]:
# Previous configuration, kept for reference:
# criterion = DINOLoss(
#     output_dim=2048,
#     warmup_teacher_temp_epochs=5,
# )
# # move loss to correct device because it also contains parameters
# criterion = criterion.to(device)

# output_dim matches the last size (2048) passed to the projection heads in DINO.
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp=0.08,          # start higher
    teacher_temp=0.04,                 # end not too sharp
    warmup_teacher_temp_epochs=10,     # warm up longer
    student_temp=0.1,
    center_momentum=0.9,               # keep default-ish EMA on center
).to(device)

# Learning-rate scaling constants: base_lr = base_lr0 * global_batch_size / 256.
# NOTE(review): global_batch_size (256) differs from the DataLoader batch_size
# (128), and base_lr is only read by cosine_lr - the optimizer is created with
# its own literal lr. Confirm which values are intended.
global_batch_size = 256
base_lr0 = 5e-4
base_lr = base_lr0 * (global_batch_size / 256)

In [None]:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# AdamW over model.parameters(), which includes the teacher parameters whose
# gradients were deactivated when the DINO model was built.
optimizer = torch.optim.AdamW(
    model.parameters(),
    # Literal lr; base_lr (computed in the previous cell) is not used here yet.
    lr=3e-4, #TODO: update to use base_lr
    weight_decay=1e-4,
    betas=(0.9, 0.95),
)

In [None]:
warmup_epochs = 10
min_lr = 1e-6

def cosine_lr(epoch):
    """Return the learning-rate multiplier for ``epoch`` (0-based).

    The multiplier rises linearly from ``1 / warmup_epochs`` to 1 during the
    first ``warmup_epochs`` epochs and then follows a cosine curve from 1 down
    to ``min_lr / base_lr``, which is reached at ``epochs`` and held afterwards.

    Reads the notebook globals ``warmup_epochs``, ``min_lr``, ``base_lr`` and
    ``epochs`` at call time.
    """
    # Local import so the function does not depend on `math` having been
    # imported by some other cell.
    import math

    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / max(1, (epochs - warmup_epochs))
    # Clamp so the multiplier stays at its floor instead of rising again if
    # the scheduler is stepped beyond `epochs`.
    t = min(t, 1.0)
    floor = min_lr / base_lr
    return floor + 0.5 * (1 + math.cos(math.pi * t)) * (1 - floor)

# Use cosine_lr as a per-epoch multiplicative factor on the optimizer's lr.
# NOTE(review): LambdaLR appears to apply cosine_lr(0) when it is constructed,
# which would scale the optimizer's lr to 1/warmup_epochs of its configured
# value right away. The training loop has scheduler.step() commented out, so
# that factor would then never change - confirm this is intended.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)

In [None]:
def teacher_entropy(logits):
    """Return the mean entropy of the softmax distributions given by ``logits``.

    The softmax is taken over the last dimension, the entropy is summed over
    that dimension and then averaged over all remaining entries. The training
    loop passes teacher outputs annotated as shape (B, 2048).
    """
    eps = 1e-8  # keeps log() finite when a probability is exactly zero
    probs = logits.softmax(dim=-1)
    weighted_log = probs * torch.log(probs + eps)
    return -weighted_log.sum(dim=-1).mean()




In [None]:
# Total number of training epochs; read by cosine_lr, by the teacher momentum
# schedule and by the training loop.
epochs = 250

In [None]:
import torch
from pathlib import Path
import wandb

# ---------- Drive setup ----------
# try:
#     from google.colab import drive
#     drive.mount('/content/drive')
#     DRIVE_ROOT = Path("/content/drive/MyDrive")
#     IS_COLAB = True
#     print("✓ Running on Colab, Drive mounted.")
# except Exception:
#     DRIVE_ROOT = Path("./saved_models")
#     IS_COLAB = False
#     print("⚠️ Not on Colab, using local folder ./saved_models")

In [None]:
import time

from tqdm import tqdm

# ---------- Project / save dir ----------
PROJECT_NAME = "dino-v1"  # wandb project AND folder name
DRIVE_ROOT = "outputs"
save_dir = Path(DRIVE_ROOT) / Path(PROJECT_NAME)
save_dir.mkdir(parents=True, exist_ok=True)

# Fail early with a clear message instead of a ZeroDivisionError when the
# per-epoch average is computed.
num_batches = len(dataloader)
if num_batches == 0:
    raise RuntimeError(
        "The dataloader yields no batches; the dataset must contain at least "
        "one full batch."
    )

# ---------- wandb init ----------
wandb.init(
    entity="lquan9",
    project=PROJECT_NAME,
    name="dino-resnet34-run-1",      # change run name if you like
)

ENTROPY_LOG_EVERY = 100  # optimisation steps between teacher-entropy measurements

print("Starting Training")
global_step = 0
step_start = time.time()
for epoch in range(epochs):
    total_loss = 0.0
    # Teacher EMA momentum follows a cosine schedule from 0.996 to 1.
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)

    for views in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        # EMA update for teacher
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)

        # move all crops to the training device
        views = [v.to(device) for v in views]

        # first two are global crops for the teacher
        global_views = views[:2]

        # teacher only on global crops
        teacher_out = [model.forward_teacher(v) for v in global_views]

        # student on all crops (global + local)
        student_out = [model.forward(v) for v in views]

        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()

        loss.backward()
        # Keep the last layer of the student head frozen during the first
        # epochs (the head was built with freeze_last_layer=30).
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

        # ---- wandb STEP LOGGING ----
        # Everything for one optimisation step goes into a single call with an
        # explicit step. Mixing calls with and without `step=` lets wandb's
        # internal counter run ahead of global_step, after which rows logged
        # with the smaller explicit step are dropped.
        step_metrics = {
            "loss/step": loss.item(),
            "time/step_sec": time.time() - step_start,
            "step": global_step,
            "epoch": epoch,
        }
        if global_step % ENTROPY_LOG_EVERY == 0:
            # Entropy of the teacher distribution on the first global crop,
            # computed only on the steps where it is actually logged.
            with torch.no_grad():
                step_metrics["teacher_entropy"] = teacher_entropy(teacher_out[0]).item()
        wandb.log(step_metrics, step=global_step)

        global_step += 1
        step_start = time.time()

    # TODO: add scheduler - call scheduler.step() here, once per epoch.

    # Plain Python float: prints, logs and checkpoints cleanly, and does not
    # tie the checkpoint to the training device.
    avg_loss = (total_loss / num_batches).item()
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

    # ---- wandb logging ----
    # Attached to the last step of this epoch so the step axis stays monotonic.
    wandb.log(
        {
            "loss/train": avg_loss,
            "epoch": epoch,
        },
        step=global_step - 1,
    )

    # ---- Save checkpoint (always same filename) ----
    ckpt_path = save_dir / f"{PROJECT_NAME}_latest.pt"
    torch.save(
        {
            "epoch": epoch,
            "global_step": global_step,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            # The loss module is stateful (it is moved to the device like a
            # model), so its state is needed to resume training.
            "criterion_state": criterion.state_dict(),
            "avg_loss": avg_loss,
        },
        ckpt_path,
    )
    print(f"✓ Saved checkpoint: {ckpt_path}")

# Flush metrics that are still pending for the last step and close the run.
wandb.finish()

In [None]:
# Fetch a single batch and print the container type, the number of crops and
# the tensor shape of every crop.
views = next(iter(dataloader))
print(type(views), len(views))
for crop_index, crop in enumerate(views):
    print(crop_index, crop.shape)