In [None]:
# Environment setup — run once in a terminal (or prefix each line with % to run it as %pip in this kernel):
# pip install -U pip
# pip install --index-url https://download.pytorch.org/whl/cu126 torch torchvision
# pip install timm einops
# pip install transformers accelerate
# python linear_probe_cifar10_vit.py --model vit_base_patch16_224 --epochs 10 --batch_train 128
# pip install "huggingface_hub[hf_xet]"


In [1]:
import torch, torchvision
# Environment report: PyTorch version, the CUDA version this build targets,
# and — only when a GPU is visible — the name of device 0 and the device count.
print(f"{torch.__version__} | CUDA build: {torch.version.cuda}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} | count: {torch.cuda.device_count()}")


2.8.0+cu126 | CUDA build: 12.6
CUDA available: True
GPU: NVIDIA GeForce RTX 4070 | count: 1


In [1]:
# file: sanity_vit.py
import torch, timm
from torchvision import transforms
from PIL import Image
from urllib.request import urlopen

  from .autonotebook import tqdm as notebook_tqdm


In [5]:

import io

# Smoke test: load a pretrained timm ViT and push one real image through it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
model = timm.create_model("vit_base_patch16_224", pretrained=True).to(device).eval()

# Build eval preprocessing from the checkpoint's own pretrained config
# (input size, interpolation, crop ratio, mean/std) instead of hard-coding
# torchvision's ImageNet statistics. timm ViT checkpoints ship their own
# normalisation settings, which for this model differ from the ImageNet
# mean/std — print `data_cfg` to inspect them.
data_cfg = timm.data.resolve_model_data_config(model)
tf = timm.data.create_transform(**data_cfg, is_training=False)

IMG_URL = "https://images.unsplash.com/photo-1518791841217-8f162f1e1131?w=512"
# Read the body fully inside `with` so the HTTP connection is closed promptly;
# the timeout keeps a stalled download from hanging the kernel forever.
with urlopen(IMG_URL, timeout=30) as resp:
    img = Image.open(io.BytesIO(resp.read())).convert("RGB")
x = tf(img).unsqueeze(0).to(device)

with torch.no_grad():
    y = model(x)                   # [1, 1000] logits
print("ok, logits shape:", tuple(y.shape))

device: cuda
ok, logits shape: (1, 1000)


In [2]:
# === ViT on CIFAR-10 with torchvision  ===
# Works in a single notebook cell. Toggle FREEZE_BACKBONE for linear-probe vs finetune.

import time, math, platform, torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# ----------------- config -----------------
# FREEZE_BACKBONE=True trains only the classification head (linear probe);
# False fine-tunes every parameter. LR / WD are picked per mode below.
FREEZE_BACKBONE = True
MODEL_NAME = "vit_b_16"         # other options: vit_b_32 (faster), vit_l_16 (heavier)
EPOCHS = 10
BATCH = 64                      # safe for 12GB with AMP
TEST_B = 256                    # eval batch size
MOM = 0.9
SEED = 42
if FREEZE_BACKBONE:
    LR, WD = 0.1, 0.0           # head-only training
else:
    LR, WD = 5e-4, 0.05         # whole-network fine-tune

# ----------------- setup -----------------
torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device={device} | model={MODEL_NAME} | freeze_backbone={FREEZE_BACKBONE}")

# Trade a little float32 matmul precision for speed where supported;
# the hasattr guard keeps older torch builds working.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# ----------------- data -----------------
# torchvision ViT expects 224x224 inputs normalised with ImageNet statistics;
# Resize(224) upsamples the small CIFAR-10 images to that size.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
tf_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
tf_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = datasets.CIFAR10("./data", train=True,  download=True, transform=tf_train)
test_ds  = datasets.CIFAR10("./data", train=False, download=True, transform=tf_test)
num_workers = 2 if platform.system() == "Windows" else 4
pin = (device == "cuda")


def _make_loader(ds, batch_size, shuffle):
    """Build a DataLoader with this notebook's shared worker / pinning settings.

    `persistent_workers` and `prefetch_factor` are only valid when worker
    processes exist (DataLoader raises ValueError otherwise), so they are
    passed only when num_workers > 0 — setting num_workers = 0 for debugging
    then just works.
    """
    worker_opts = {"persistent_workers": True, "prefetch_factor": 2} if num_workers > 0 else {}
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=pin, **worker_opts)


train_ld = _make_loader(train_ds, BATCH, shuffle=True)
test_ld  = _make_loader(test_ds, TEST_B, shuffle=False)

# ----------------- model -----------------
# Load ImageNet-pretrained weights for whichever torchvision ViT MODEL_NAME
# names. `.DEFAULT` resolves per architecture (for vit_b_16 it is the same
# IMAGENET1K_V1 checkpoint), so every option listed in the config gets a
# pretrained backbone — a linear probe on a randomly initialised one is
# meaningless.
# NOTE(review): vit_h_14's DEFAULT weights appear to target a larger input
# resolution than the 224px pipeline in the data section — confirm before using it.
weights = models.get_model_weights(MODEL_NAME).DEFAULT
model = models.get_model(MODEL_NAME, weights=weights).to(device)

# Replace the classifier with a 10-way head (torchvision ViT keeps it at heads.head).
in_feats = model.heads.head.in_features
model.heads.head = nn.Linear(in_feats, 10).to(device)

# Linear probe: freeze the whole network, then re-enable only the new head.
if FREEZE_BACKBONE:
    model.requires_grad_(False)
    model.heads.head.requires_grad_(True)

# ----------------- optim/loss/amp -----------------
# Only parameters left trainable by the freeze step are optimised.
params = [p for p in model.parameters() if p.requires_grad]
if FREEZE_BACKBONE:
    # Linear probe: SGD + Nesterov momentum, constant LR (no scheduler).
    opt = optim.SGD(params, lr=LR, momentum=MOM, nesterov=True, weight_decay=WD)
    sched = None
else:
    # Full fine-tune: AdamW with cosine decay, stepped once per epoch over EPOCHS.
    opt = optim.AdamW(params, lr=LR, weight_decay=WD)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
crit = nn.CrossEntropyLoss()

# Mixed precision only on CUDA. Prefer the device-generic torch.amp API and
# fall back to torch.cuda.amp on older PyTorch: those builds either have no
# `torch.amp.GradScaler` attribute at all (AttributeError) or reject the
# device argument (TypeError), so both must trigger the fallback.
use_amp = (device == "cuda")
try:
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    def autocast(): return torch.amp.autocast("cuda", enabled=use_amp)
except (AttributeError, TypeError):
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    def autocast(): return torch.cuda.amp.autocast(enabled=use_amp)

# ----------------- eval -----------------
@torch.inference_mode()
def evaluate():
    """Top-1 accuracy (percent) of the global `model` on `test_ld`.

    The forward pass runs inside the same `autocast()` helper used for
    training, so it is mixed precision exactly when `use_amp` is on.
    """
    model.eval()
    n_seen = 0
    n_right = 0
    with autocast():
        for inputs, labels in test_ld:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            hits = model(inputs).argmax(1) == labels
            n_seen += labels.size(0)
            n_right += hits.sum().item()
    return 100.0 * n_right / n_seen

# ----------------- train -----------------
# One pass over train_ld per epoch followed by a test-set evaluation.
# `run` accumulates per-batch losses, so run / i is the mean loss so far.
t0 = time.time()
num_batches = len(train_ld)  # same as ceil(len(train_ds) / BATCH): drop_last is not set
for ep in range(1, EPOCHS + 1):
    model.train()
    run = 0.0
    for i, (x, y) in enumerate(train_ld, start=1):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        with autocast():
            loss = crit(model(x), y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        run += loss.item()
        # Rewrite the same console line every 20 steps and on the final step.
        if i == num_batches or i % 20 == 0:
            print(f"\rEp {ep:02d} [{i}/{num_batches}] loss={run/i:.4f}", end="")
    print()

    if sched is not None:
        sched.step()
    acc = evaluate()
    print(f"Epoch {ep:02d} | train_loss={run/num_batches:.4f} | test_acc={acc:.2f}%")
print(f"Done in {time.time()-t0:.1f}s")


device=cuda | model=vit_b_16 | freeze_backbone=True
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\Pham Huy/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:08<00:00, 40.3MB/s] 


Ep 01 [782/782] loss=0.2750
Epoch 01 | train_loss=0.2750 | test_acc=94.06%
Ep 02 [782/782] loss=0.2141
Epoch 02 | train_loss=0.2141 | test_acc=92.86%
Ep 03 [782/782] loss=0.1924
Epoch 03 | train_loss=0.1924 | test_acc=94.03%
Ep 04 [782/782] loss=0.1770
Epoch 04 | train_loss=0.1770 | test_acc=94.35%
Ep 05 [782/782] loss=0.1639
Epoch 05 | train_loss=0.1639 | test_acc=93.77%
Ep 06 [782/782] loss=0.1541
Epoch 06 | train_loss=0.1541 | test_acc=94.33%
Ep 07 [782/782] loss=0.1543
Epoch 07 | train_loss=0.1543 | test_acc=94.14%
Ep 08 [782/782] loss=0.1480
Epoch 08 | train_loss=0.1480 | test_acc=94.08%
Ep 09 [782/782] loss=0.1454
Epoch 09 | train_loss=0.1454 | test_acc=94.14%
Ep 10 [782/782] loss=0.1427
Epoch 10 | train_loss=0.1427 | test_acc=94.20%
Done in 809.0s


In [3]:
import os, torch
# Save the trained weights plus the metadata needed to rebuild the model.
# The arch label and file name follow MODEL_NAME / FREEZE_BACKBONE from the
# training cell (with the default config this is still
# checkpoints/vit_b_16_cifar10_linearprobe.pt), so a fine-tuned or non-B/16
# run is never saved under a vit_b_16 linear-probe label.
os.makedirs("checkpoints", exist_ok=True)
run_kind = "linearprobe" if FREEZE_BACKBONE else "finetune"
ckpt_path = f"checkpoints/{MODEL_NAME}_cifar10_{run_kind}.pt"
torch.save(
    {
        "state_dict": model.state_dict(),
        "arch": MODEL_NAME,
        "num_classes": model.heads.head.out_features,
        "freeze_backbone": FREEZE_BACKBONE,
    },
    ckpt_path,
)
print("Saved to", ckpt_path)


Saved to checkpoints/vit_b_16_cifar10_linearprobe.pt


In [7]:
from torchvision import models
import torch.nn as nn, torch

# Rebuild the architecture recorded in the checkpoint and restore its weights.
# weights_only=True: the file holds only tensors and plain metadata.
ckpt = torch.load("checkpoints/vit_b_16_cifar10_linearprobe.pt", map_location="cpu", weights_only=True)
arch = ckpt.get("arch", "vit_b_16")
# weights=None: the strict load below overwrites every parameter, so fetching
# the ImageNet weights first would only cost time (and possibly a download).
m = models.get_model(arch, weights=None)
# Size the head from the saved tensor so any class count round-trips.
num_classes = ckpt["state_dict"]["heads.head.weight"].shape[0]
m.heads.head = nn.Linear(m.heads.head.in_features, num_classes)
m.load_state_dict(ckpt["state_dict"], strict=True)
m.eval()
print("ok, model loaded")

ok, model loaded


In [5]:
# Re-score the in-memory model on the CIFAR-10 test split. Needs `model`,
# `device`, `test_ld` and `evaluate` from the training cell in this kernel.
acc = evaluate()
print("CIFAR-10 test accuracy: {:.2f}%".format(acc))


CIFAR-10 test accuracy: 94.20%


In [6]:
import torch

import contextlib


@torch.inference_mode()
def evaluate(net=None, loader=None, dev=None, use_amp: bool = False) -> float:
    """Return top-1 accuracy (in percent) of a classifier over a data loader.

    Called with no arguments it scores the notebook globals ``model`` on
    ``test_ld`` using ``device`` in full precision — the same behaviour as
    before — but each piece can now be overridden, e.g.
    ``evaluate(m.to(device))`` to score a model reloaded from a checkpoint.

    Args:
        net: model to score; defaults to the global ``model``.
        loader: iterable of ``(inputs, labels)`` batches; defaults to ``test_ld``.
        dev: device batches are moved to; defaults to the global ``device``.
        use_amp: run the forward pass under ``torch.autocast`` (default False).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if the loader yields no samples.
    """
    # Resolve defaults at call time (not def time) so re-running earlier
    # cells that rebind `model` / `test_ld` / `device` is picked up.
    net = model if net is None else net
    loader = test_ld if loader is None else loader
    dev = device if dev is None else dev

    amp_ctx = (torch.autocast(device_type=torch.device(dev).type)
               if use_amp else contextlib.nullcontext())

    net.eval()
    total = correct = 0
    with amp_ctx:
        for x, y in loader:
            x, y = x.to(dev, non_blocking=True), y.to(dev, non_blocking=True)
            pred = net(x).argmax(1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    if total == 0:
        raise ValueError("evaluate(): loader produced no samples")
    return 100.0 * correct / total

# Score the current global `model` on `test_ld` with the evaluate() above.
acc = evaluate()
print("CIFAR-10 test accuracy:", f"{acc:.2f}%")


CIFAR-10 test accuracy: 94.20%


In [None]:
# Disabled alternative kept for reference. The triple-quoted text below is a
# bare string expression, so running this cell executes none of it (Jupyter
# only echoes the string). Un-quoted, it would rebind `model` to torchvision's
# vit_b_32 with ImageNet weights and a fresh 10-class head on `device`.
""" from torchvision import models
weights = models.ViT_B_32_Weights.IMAGENET1K_V1
model = models.vit_b_32(weights=weights).to(device)
model.heads.head = nn.Linear(model.heads.head.in_features, 10).to(device)"""

# Swap to vit_b_32 (¼ tokens: 32px patches on a 224px image give 7x7 = 49
# tokens vs 14x14 = 196 for vit_b_16) and keep the rest of the pipeline the same.
# NOTE(review): "vit_tiny"/"vit_small" look like timm model names rather than
# torchvision.models builders — confirm which library to use before swapping.