In [1]:
import torch
import random
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
import timm
import torch.nn as nn


In [2]:


def get_dino_split(fraction=0.1, image_size=224, train_ratio=0.8):
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    full_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    indices = list(range(len(full_data)))
    random.shuffle(indices)
    subset_len = int(fraction * len(full_data))
    subset = Subset(full_data, indices[:subset_len])
    train_len = int(train_ratio * subset_len)
    test_len = subset_len - train_len
    return random_split(subset, [train_len, test_len])



In [3]:


def train_dino_classifier(epochs=3):
    image_size = 224
    train_set, test_set = get_dino_split(fraction=0.1, image_size=image_size)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

    # Load pretrained DINO backbone (no classification head)
    vit = timm.create_model("vit_base_patch16_224.dino", pretrained=True, num_classes=0)
    for param in vit.parameters():
        param.requires_grad = False
    vit.eval()

    classifier = nn.Sequential(
        vit,
        nn.Linear(vit.num_features, 10)
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    optimizer = torch.optim.Adam(classifier[-1].parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    device = next(classifier.parameters()).device

    # Train only the linear head
    classifier.train()
    for epoch in range(epochs):
        correct, total, total_loss = 0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = classifier(imgs)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += labels.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            total_loss += loss.item() * labels.size(0)

        acc = 100. * correct / total
        print(f"[DINO Classifier] Epoch {epoch+1}: Loss={total_loss:.2f}, Accuracy={acc:.2f}%")

    # Evaluation
    classifier.eval()
    correct = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = classifier(imgs)
            correct += (logits.argmax(1) == labels).sum().item()

    acc = 100. * correct / len(test_set)
    print(f"[DINO Classifier] Final Test Accuracy: {acc:.2f}%")


In [4]:

if __name__ == "__main__":
    train_dino_classifier()


Files already downloaded and verified
[DINO Classifier] Epoch 1: Loss=2740.09, Accuracy=77.88%
[DINO Classifier] Epoch 2: Loss=898.55, Accuracy=92.85%
[DINO Classifier] Epoch 3: Loss=615.13, Accuracy=95.50%
[DINO Classifier] Final Test Accuracy: 91.90%
