In [1]:
import timm

backbone = timm.create_model('vit_base_patch16_224.dino', pretrained=True, num_classes=0)
print(backbone)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [2]:
import torch.nn as nn

class DinoViTForCIFAR10(nn.Module):
    def __init__(self, backbone, num_classes=10):
        super().__init__()
        self.backbone = backbone
        feat_dim = backbone.embed_dim  # typically 768
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)  # Output: (B, 768)
        return self.head(features)   # Output: (B, 10)


In [3]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm



if __name__ == "__main__":

    transform = transforms.Compose([
        transforms.Resize(224),            # Resize CIFAR-10 (32x32) to ViT expected size (224x224)
        transforms.ToTensor(),             # Convert image to PyTorch tensor [0, 1]
        transforms.Normalize((0.5,)*3, (0.5,)*3)  # Normalize to [-1, 1] (mean=0.5, std=0.5 for each channel)
    ])

    train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_ds  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)
    test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DinoViTForCIFAR10(backbone).to(device)

    optimizer = optim.Adam(model.head.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(5):
        model.train()
        total, correct = 0, 0
        for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            imgs, labels = imgs.to(device), labels.to(device)

            logits = model(imgs)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += labels.size(0)
            correct += (logits.argmax(1) == labels).sum().item()

        print(f"Train Acc: {100*correct/total:.2f}%")


Files already downloaded and verified
Files already downloaded and verified


Epoch 1: 100%|██████████| 3125/3125 [30:00<00:00,  1.74it/s]


Train Acc: 92.09%


Epoch 2: 100%|██████████| 3125/3125 [29:53<00:00,  1.74it/s]


Train Acc: 94.03%


Epoch 3: 100%|██████████| 3125/3125 [29:46<00:00,  1.75it/s]


Train Acc: 94.56%


Epoch 4: 100%|██████████| 3125/3125 [29:44<00:00,  1.75it/s]


Train Acc: 94.89%


Epoch 5: 100%|██████████| 3125/3125 [29:42<00:00,  1.75it/s]

Train Acc: 95.27%





In [5]:
# Save only model weights (recommended)
torch.save(model.state_dict(), "vit_cifar10_dino.pth")


In [4]:
model.eval()
total, correct = 0, 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        total += labels.size(0)
        correct += (logits.argmax(1) == labels).sum().item()

print(f"Test Accuracy: {100*correct/total:.2f}%")


Test Accuracy: 93.78%
