In [1]:
!pip install timm

Collecting timm
  Downloading timm-1.0.15-py3-none-any.whl.metadata (52 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.0/52.0 kB[0m [31m721.2 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Collecting huggingface_hub (from timm)
  Downloading huggingface_hub-0.29.1-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl.metadata (3.8 kB)
Downloading timm-1.0.15-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading huggingface_hub-0.29.1-py3-none-any.whl (468 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.0/468.0 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl (408 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m408.9/408.9 kB[0m [31m20.4 MB/s[0

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm  # Make sure timm is installed: pip install timm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =======================
# 1. Data Preparation
# =======================
# CIFAR-10 images are upscaled to 224x224 because most pre-trained models
# (e.g. DINOv2) expect larger inputs; pixel values use ImageNet normalization.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DATA_ROOT = './data'
BATCH_SIZE = 64
NUM_WORKERS = 2


def _build_transform():
    """Return the resize -> tensor -> normalize pipeline used for both splits."""
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)


# Train and test share the same deterministic preprocessing (no augmentation).
transform_train = _build_transform()
transform_test = _build_transform()

train_dataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test
)

# Only the training split is shuffled; the test split keeps its original order.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)


Files already downloaded and verified
Files already downloaded and verified


In [19]:
# =======================
# 2. Model Setup
# =======================
# Load the pre-trained DINOv2 ViT-B/14 backbone from torch.hub; it is used purely
# as a feature extractor, so no classification head is attached to it here.
feature_extractor = torch.hub.load('facebookresearch/dinov2', "dinov2_vitb14")
# The training loop moves every batch to `device`, so the backbone must live
# there too; otherwise the forward pass fails with a device mismatch on CUDA.
feature_extractor.to(device)
feature_extractor.eval()

# Freeze the backbone parameters so only the linear probe is trained.
for param in feature_extractor.parameters():
    param.requires_grad = False

# Read the feature dimension from the backbone instead of hard-coding it, so a
# different DINOv2 variant can be swapped in; 768 (ViT-B/14) is the fallback.
num_features = getattr(feature_extractor, "embed_dim", 768)
num_classes = 10  # CIFAR-10 has 10 classes

# Create a linear classifier on top of the frozen features.
linear_classifier = nn.Linear(num_features, num_classes)
linear_classifier.to(device)


Using cache found in /Users/rkovalch/.cache/torch/hub/facebookresearch_dinov2_main


Linear(in_features=768, out_features=10, bias=True)

In [20]:
# =======================
# 3. Training Setup
# =======================
# Hyperparameters for the linear probe.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
num_epochs = 10

# Cross-entropy over the class logits produced by the linear classifier.
criterion = nn.CrossEntropyLoss()

# Only the linear classifier's parameters are handed to the optimizer.
optimizer = optim.SGD(
    linear_classifier.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)


In [25]:
# =======================
# 4. Training Loop (Linear Probe)
# =======================
# tqdm.auto selects the notebook-friendly progress bar when one is available and
# falls back to the plain text bar otherwise.
from tqdm.auto import tqdm

# NOTE: the backbone is frozen and the preprocessing is deterministic, so the
# extracted features are identical in every epoch. Caching them once would avoid
# repeating the expensive backbone forward pass `num_epochs` times.
for epoch in tqdm(range(num_epochs), desc="epochs progress"):
    linear_classifier.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    total = 0           # number of samples seen
    correct = 0         # number of correct predictions

    # leave=False removes the per-batch bar once an epoch finishes, so the cell
    # output is not flooded with one progress line per batch.
    for inputs, labels in tqdm(train_loader, desc="train loader", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        # Extract features using the frozen DINOv2 backbone; no gradients are
        # needed for it because only the linear classifier is optimized.
        with torch.no_grad():
            features_dict = feature_extractor.forward_features(inputs)
            # Use the normalized class token as the image representation.
            features = features_dict["x_norm_clstoken"]
            # Alternative: average-pool the patch tokens instead. In DINOv2 they
            # are shaped (batch, num_patches, dim), so pool over the token axis:
            # features = features_dict["x_norm_patchtokens"].mean(dim=1)

        # Forward pass through the linear classifier.
        outputs = linear_classifier(features)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # keep the epoch average correct when the last batch is smaller.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total * 100
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")


epochs progress:   0%|          | 0/10 [00:00<?, ?it/s]
train loader:   0%|          | 0/782 [00:00<?, ?it/s][A
train loader:   0%|          | 1/782 [00:07<1:38:39,  7.58s/it][A
train loader:   0%|          | 2/782 [00:11<1:09:34,  5.35s/it][A
train loader:   0%|          | 3/782 [00:15<1:00:09,  4.63s/it][A
train loader:   1%|          | 4/782 [00:18<55:53,  4.31s/it]  [A
train loader:   1%|          | 5/782 [00:22<53:31,  4.13s/it][A
train loader:   1%|          | 6/782 [00:26<52:00,  4.02s/it][A
train loader:   1%|          | 7/782 [00:30<50:58,  3.95s/it][A
train loader:   1%|          | 8/782 [00:34<50:09,  3.89s/it][A
train loader:   1%|          | 9/782 [00:37<49:40,  3.86s/it][A
train loader:   1%|▏         | 10/782 [00:41<49:15,  3.83s/it][A
train loader:   1%|▏         | 11/782 [00:45<49:08,  3.82s/it][A
train loader:   2%|▏         | 12/782 [00:49<49:02,  3.82s/it][A
train loader:   2%|▏         | 13/782 [00:53<48:51,  3.81s/it][A
train loader:   2%|▏         |

KeyboardInterrupt: 