DINOv2 distilled
(ViT-B/14)

# Installation

In [None]:
# Installation
!pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers
!pip install -q scikit-learn
!pip install -q accelerate
!pip install -q torchmetrics
!pip install -q torch torchvision transformers scikit-learn accelerate

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from transformers import AutoImageProcessor, AutoModel, AutoModelForImageClassification
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import numpy as np
from tqdm import tqdm
import random
import os
import torch.optim as optim # import optim
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from umap import UMAP  # pip install umap-learn
import matplotlib.pyplot as plt

In [None]:
# Report whether a GPU is visible and, if so, which one.
if torch.cuda.is_available():
    print("CUDA Available:", True)
    print("Device:", torch.cuda.get_device_name(0))
else:
    print("CUDA Available:", False)
    print("Device:", "CPU")

In [None]:
# Set random seed to ensure reproducible results
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch CPU generator and the current CUDA device's generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# 1. Linear probe: training only a linear classifier on frozen DINOv2 features

In [None]:
# Section-1 preprocessing: resize every image to the 224x224 input size used
# for ViT/DINOv2, then convert it to a float tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Download CIFAR-10 (train + test) using the resize/ToTensor pipeline.
full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Hold out 10% of the official training split for validation.
val_ratio = 0.1
val_size = int(len(full_train) * val_ratio)
train_size = len(full_train) - val_size

# random_split draws from torch's RNG; give it a dedicated seeded generator so
# the train/val partition is identical on every run.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(full_train, [train_size, val_size], generator=split_generator)

# Set batch size
batch_size = 64

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained DINOv2 ViT-B/14 backbone (base model, no classification head) and
# its matching image preprocessor, used here as a feature extractor.
DINOV2_CHECKPOINT = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(DINOV2_CHECKPOINT)
model = AutoModel.from_pretrained(DINOV2_CHECKPOINT).to(device)
model.eval()  # inference mode (e.g. disables dropout)

In [None]:
# Feature-extraction helper: turns an (image, label) dataset into CLS embeddings.
def extract_embeddings_from_dataset(dataset, batch_size=1):
    """Embed every image in ``dataset`` with the DINOv2 backbone.

    Uses the notebook-level ``processor``, ``model`` and ``device``. Each image
    is represented by the token at position 0 (CLS) of ``last_hidden_state``.

    Args:
        dataset: iterable of ``(image, label)`` pairs. Images may be tensors
            (converted back to PIL before preprocessing) or PIL images
            (used as-is).
        batch_size: number of images per forward pass. The default of 1
            reproduces the original one-image-at-a-time behaviour; larger
            values are much faster on a GPU.

    Returns:
        ``(embeddings, labels)``: a 2-D array with one embedding row per image,
        and a 1-D array of labels in the same order.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``dataset`` yields no items.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    to_pil = transforms.ToPILImage()  # build once instead of once per image
    embeddings = []
    labels = []
    pending = []  # images waiting for the next forward pass

    def _embed_pending():
        # One forward pass over the buffered images; keep only their CLS tokens.
        inputs = processor(images=pending, return_tensors="pt").to(device)
        outputs = model(**inputs)
        embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
        pending.clear()

    with torch.no_grad():
        for img, label in tqdm(dataset):
            # ToPILImage rejects PIL input, so only convert tensors/arrays.
            if not isinstance(img, Image.Image):
                img = to_pil(img)
            pending.append(img)
            labels.append(label)
            if len(pending) == batch_size:
                _embed_pending()
        if pending:  # flush the final, possibly partial, batch
            _embed_pending()

    if not embeddings:
        raise ValueError("dataset is empty; nothing to embed")
    return np.vstack(embeddings), np.array(labels)

In [None]:
# Embed each split in turn (train, then val, then test) with the frozen backbone.
(
    (train_embeddings, train_labels),
    (val_embeddings, val_labels),
    (test_embeddings, test_labels),
) = [extract_embeddings_from_dataset(split) for split in (train_set, val_set, test_set)]

In [None]:
# Linear probe: fit a logistic-regression classifier on the CLS embeddings.
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(max_iter=1000)
clf.fit(train_embeddings, train_labels)

# Accuracy = fraction of predicted labels equal to the ground truth.
val_acc = np.mean(clf.predict(val_embeddings) == val_labels)
test_acc = np.mean(clf.predict(test_embeddings) == test_labels)

print("Validation Accuracy:", round(val_acc * 100, 2), "%")
print("Test Accuracy:", round(test_acc * 100, 2), "%")


# 2. Fine-tuning with an LR scheduler (selective fine-tuning)

In [None]:
# Section-2 preprocessing: resize to 224x224 but keep PIL images (no ToTensor);
# the Hugging Face processor converts the batches to tensors later.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda pil_img: pil_img.convert("RGB")),  # force 3-channel RGB
])

In [None]:
# Custom collate_fn that leaves the images in their original (PIL) format.
def collate_pil(batch):
    """Split a list of ``(image, label)`` pairs into a list of images and a label tensor.

    Images are returned untouched so an image processor can handle them later;
    labels are stacked with ``torch.tensor``.
    """
    imgs, targets = map(list, zip(*batch))
    return imgs, torch.tensor(targets)

In [None]:
random.seed(42)

# Download CIFAR-10 (train + test) using the PIL-preserving pipeline.
full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Hold out 10% of the official training split for validation.
val_ratio = 0.1
val_size = int(len(full_train) * val_ratio)
train_size = len(full_train) - val_size

# random.seed() does not affect random_split, which draws from torch's RNG;
# a dedicated seeded generator makes the train/val partition reproducible.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(full_train, [train_size, val_size], generator=split_generator)

In [None]:
# DataLoaders over PIL images (collate_pil keeps them unconverted): batches of
# 256, 2 worker processes, pinned host memory; only training data is shuffled.
loader_kwargs = dict(batch_size=256, collate_fn=collate_pil, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

In [None]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Pretrained DINOv2 backbone with a new 10-way classification head.
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-base",
    num_labels=10,
    ignore_mismatched_sizes=True
).to(device)

# Selective fine-tuning: freeze transformer blocks 0..NUM_FROZEN_LAYERS-1 BEFORE
# building the optimizer. Matching "encoder.layer." (with the trailing dot)
# guarantees the split below always finds a layer index.
# NOTE(review): parameters outside "encoder.layer.*" (e.g. the embeddings) stay
# trainable, so gradients still flow back through the frozen blocks; confirm
# this is intended.
NUM_FROZEN_LAYERS = 9
for name, param in model.named_parameters():
    if "encoder.layer." in name:
        layer_num = int(name.split("encoder.layer.")[1].split(".")[0])
        if layer_num < NUM_FROZEN_LAYERS:
            param.requires_grad = False

# Only pass trainable parameters to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=5e-5)
# `verbose` is deprecated for LR schedulers; the training loop prints the LR itself.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
loss_fn = nn.CrossEntropyLoss()

# Gradient scaler for float16 autocast. Enabled only on CUDA; when disabled,
# scale/step/update become pass-throughs.
scaler = GradScaler(device.type, enabled=device.type == "cuda")

In [None]:
def evaluate(model, loader, image_processor=None, eval_device=None):
    """Compute the top-1 accuracy of an image-classification model over a loader.

    Args:
        model: classifier whose outputs expose ``.logits``; the predicted class
            is the argmax over dim 1.
        loader: yields ``(images, label_tensor)`` batches, e.g. a DataLoader
            built with ``collate_pil``.
        image_processor: converts a batch of images into model inputs.
            Defaults to the notebook-level ``processor``.
        eval_device: device that inputs and labels are moved to. Defaults to
            the notebook-level ``device``.

    Returns:
        Fraction of correctly classified samples (a float in [0, 1]).

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).

    Note:
        Leaves ``model`` in eval mode; the training loop switches it back with
        ``model.train()`` at the start of each epoch.
    """
    image_processor = processor if image_processor is None else image_processor
    eval_device = device if eval_device is None else eval_device

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            inputs = image_processor(images=imgs, return_tensors="pt").to(eval_device)
            labels = labels.to(eval_device)
            preds = model(**inputs).logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return correct / total


In [None]:
best_val_acc = 0.0
no_improvement = 0
early_stop_patience = 5  # stop after this many consecutive epochs without a new best val accuracy
max_epochs = 40  # upper bound; early stopping normally ends training sooner
save_path = "/content/drive/MyDrive/rec_model/dinov2/dinov2_finetuned_cifar10.pth"

# Per-epoch history, one entry per completed epoch.
train_losses = []
train_accuracies = []
val_accuracies = []
lr_history = []  # learning rate actually used for that epoch's updates


def _train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``processor``, ``optimizer``, ``scaler``,
    ``loss_fn`` and ``device``. Autocast is enabled only when ``device`` is CUDA.

    Args:
        epoch: zero-based epoch index (used only for the progress-bar label).

    Returns:
        ``(mean loss over batches, training accuracy over all samples seen)``.
    """
    model.train()
    use_amp = device.type == "cuda"
    epoch_loss = 0.0
    correct_train = 0
    total_train = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        inputs = processor(images=imgs, return_tensors="pt").to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        with autocast(device.type, enabled=use_amp):
            outputs = model(**inputs)
            loss = loss_fn(outputs.logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        # Train accuracy tracking
        preds = outputs.logits.argmax(dim=1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    return epoch_loss / len(train_loader), correct_train / total_train


for epoch in range(max_epochs):
    # Read the LR before scheduler.step() so the log reflects this epoch's LR,
    # not the one the scheduler picks for the next epoch.
    epoch_lr = optimizer.param_groups[0]['lr']

    avg_loss, train_acc = _train_one_epoch(epoch)
    val_acc = evaluate(model, val_loader)

    # Log values
    train_losses.append(avg_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)
    lr_history.append(epoch_lr)

    # The plateau scheduler may lower the LR for subsequent epochs.
    scheduler.step(val_acc)

    print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    print(f"[Epoch {epoch+1}] Current LR: {epoch_lr:.6f}")
    next_lr = optimizer.param_groups[0]['lr']
    if next_lr < epoch_lr:
        print(f"[Epoch {epoch+1}] LR reduced to {next_lr:.6f} for the next epoch.")

    if val_acc > best_val_acc:
        print(f"New best val acc! Saving model to {save_path}")
        best_val_acc = val_acc
        no_improvement = 0
        torch.save(model.state_dict(), save_path)
    else:
        no_improvement += 1
        print(f"No improvement for {no_improvement} epoch(s).")

    if no_improvement >= early_stop_patience:
        print("Early stopping triggered.")
        break

In [None]:
# Save the complete model object. Reload the best checkpoint first: after
# early stopping, `model` holds the last epoch's weights, not the best ones.
# NOTE: pickled whole-model files need torch.load(..., weights_only=False)
# on PyTorch >= 2.6.
model.load_state_dict(torch.load(save_path, map_location=device))
torch.save(model, "/content/drive/MyDrive/rec_model/dinov2/dinov2_finetuned_entire_model.pt")

In [None]:
# Training and validation accuracy, one point per completed epoch.
for series, series_label in (
    (train_accuracies, "Train Accuracy"),
    (val_accuracies, "Validation Accuracy"),
):
    plt.plot(series, label=series_label)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training vs Validation Accuracy")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Mean training loss per epoch.
plt.plot(train_losses, label="Training Loss")
plt.gca().set(xlabel="Epoch", ylabel="Loss", title="Training Loss Curve")
plt.grid(True)
plt.show()


In [None]:
# Restore the best checkpoint (highest validation accuracy) and report test accuracy.
# map_location lets a GPU-saved checkpoint load onto whatever `device` is active.
model.load_state_dict(torch.load(save_path, map_location=device))
test_acc = evaluate(model, test_loader)
print(f"Final Test Accuracy: {test_acc:.4f}")

In [None]:
import pandas as pd

# Per-epoch training history as a table (epochs numbered from 1), saved as CSV.
history = {
    "epoch": list(range(1, len(train_losses) + 1)),
    "train_loss": train_losses,
    "train_acc": train_accuracies,
    "val_acc": val_accuracies,
    "lr": lr_history,
}
log_df = pd.DataFrame(history)

log_df.to_csv("/content/drive/MyDrive/rec_model/dinov2/training_log.csv", index=False)
log_df.head()

# 3. Loading the fine-tuned model and visualizing its outputs

In [None]:
# Rebuild the fine-tuned classifier and load the best checkpoint saved during training.
# (Re)define device so this cell does not rely on an earlier cell for it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModelForImageClassification.from_pretrained(
    "facebook/dinov2-base",
    num_labels=10,
    ignore_mismatched_sizes=True
).to(device)

# Load finetuned weight; map_location makes a GPU-saved checkpoint loadable on CPU too.
model.load_state_dict(
    torch.load("/content/drive/MyDrive/rec_model/dinov2/dinov2_finetuned_cifar10.pth", map_location=device)
)
model.eval()


In [None]:
# Rebuild the train/val/test DataLoaders: batch 256, PIL-preserving collate,
# 2 workers, pinned memory; only the training loader shuffles.
shared_loader_args = {
    "batch_size": 256,
    "collate_fn": collate_pil,
    "num_workers": 2,
    "pin_memory": True,
}
train_loader, val_loader, test_loader = (
    DataLoader(ds, shuffle=do_shuffle, **shared_loader_args)
    for ds, do_shuffle in ((train_set, True), (val_set, False), (test_set, False))
)

In [None]:
random.seed(42)

# CIFAR-10 class names, indexed by label id
class_names = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

# Figure layout: one value controls the sample count, grid, title and filename.
max_samples = 30
num_cols = 10

# Get a batch from the test set (test_loader uses shuffle=False, so this is
# always the same first batch).
imgs, labels = next(iter(test_loader))
imgs = list(imgs)
labels = labels.tolist()

# Pick random samples without replacement, capped at the batch size.
num_samples = min(max_samples, len(imgs))
indices = random.sample(range(len(imgs)), num_samples)
selected_imgs = [imgs[i] for i in indices]
true_labels = [labels[i] for i in indices]

# Inference
inputs = processor(images=selected_imgs, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
    preds = outputs.logits.argmax(dim=1).cpu().tolist()

# Plotting: each title reads "true class→predicted class".
num_rows = -(-num_samples // num_cols)  # ceiling division
plt.figure(figsize=(2 * num_cols, 2 * num_rows))
for i, img in enumerate(selected_imgs):
    plt.subplot(num_rows, num_cols, i + 1)
    plt.imshow(img)
    plt.title(f'{class_names[true_labels[i]]}→{class_names[preds[i]]}', fontsize=8)
    plt.axis("off")

plt.suptitle(f"DINOv2 Predictions on Test Images ({num_samples} samples)", fontsize=14)
plt.tight_layout()
plt.savefig(f"/content/drive/MyDrive/rec_model/dinov2/dinov2_pred_{num_samples}.png", dpi=300)
plt.show()


In [None]:
cls_chunks = []
label_list = []

with torch.no_grad():
    for batch_imgs, batch_lbls in test_loader:
        pixel_inputs = processor(images=batch_imgs, return_tensors="pt").to(device)
        # Call the classifier's backbone (`model.dinov2`) directly to get hidden
        # states instead of classification logits.
        hidden = model.dinov2(**pixel_inputs).last_hidden_state
        cls_chunks.append(hidden[:, 0, :].cpu())  # position 0 is the CLS token
        label_list.extend(batch_lbls.cpu().tolist())

features = torch.cat(cls_chunks).numpy()
labels = np.array(label_list)

print(f"Extracted features shape: {features.shape}")
print(f"Extracted labels shape: {labels.shape}")

In [None]:
# 2-D UMAP projection of the fine-tuned CLS features, one color per true class.
reducer = UMAP(n_components=2, random_state=42)
proj = reducer.fit_transform(features)
umap_x, umap_y = proj[:, 0], proj[:, 1]

plt.figure(figsize=(10, 7))
for cls_id in np.unique(labels):
    in_class = labels == cls_id
    plt.scatter(umap_x[in_class], umap_y[in_class], label=class_names[cls_id], s=10)
plt.legend()
plt.gca().set(
    title="UMAP Projection of DINOv2 CLS Features",
    xlabel="UMAP Feature 1",
    ylabel="UMAP Feature 2",
)
plt.grid(True)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/rec_model/dinov2/dinov2_umap.png', dpi=300)
plt.show()
