# In [None]:
# vit_finetune_fakedata.py
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision.models import ViT_B_16_Weights

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ===== Data: FakeData 224x224 (3 classes) =====
# Reuse the preprocessing transform shipped with the pretrained weights.
weights = ViT_B_16_Weights.DEFAULT
preprocess = weights.transforms()
NUM_CLASSES = 3
TRAIN_SIZE = 600
VAL_SIZE = 120

train_ds = datasets.FakeData(size=TRAIN_SIZE, image_size=(3, 224, 224),
                             num_classes=NUM_CLASSES, transform=preprocess)
# FakeData seeds every sample with `index + random_offset`. With the default
# offset of 0 the validation set would be an exact copy of the first VAL_SIZE
# training samples, so shift it past the training indices to keep the two
# splits disjoint.
val_ds = datasets.FakeData(size=VAL_SIZE, image_size=(3, 224, 224),
                           num_classes=NUM_CLASSES, transform=preprocess,
                           random_offset=TRAIN_SIZE)

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

# ===== Model: pretrained ViT-B/16 with the final head replaced =====
model = models.vit_b_16(weights=weights)
# The classifier layer lives at `model.heads.head`; keep its input width and
# swap in a fresh linear layer sized for our label set.
in_feats = model.heads.head.in_features
model.heads.head = nn.Linear(in_features=in_feats, out_features=NUM_CLASSES)
model = model.to(device=DEVICE)

# (Optional) Freeze the backbone when the dataset is small:
# for name, p in model.named_parameters():
#     if not name.startswith("heads"):
#         p.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=3e-4,
    weight_decay=0.05,
)

def run_epoch(dl, train=True):
    """Run one full pass over ``dl`` and return ``(mean_loss, accuracy)``.

    Relies on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``DEVICE``.

    Args:
        dl: Iterable yielding ``(inputs, targets)`` batches.
        train: If true, switch the model to training mode and take one
            optimizer step per batch. If false, switch it to eval mode and
            only measure loss and accuracy.

    Returns:
        A tuple ``(loss, acc)``: the per-sample average of the batch losses
        (each batch loss weighted by its batch size) and the fraction of
        samples whose arg-max prediction equals the target.

    Raises:
        ZeroDivisionError: If ``dl`` yields no samples.
    """
    model.train(train)
    total, correct, loss_sum = 0, 0, 0.0
    # Autograd is only needed when we back-propagate. Turning it off for the
    # evaluation pass avoids building a graph that is never used.
    with torch.set_grad_enabled(train):
        for x, y in dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            if train:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                loss.backward()
                optimizer.step()
            batch_size = x.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            loss_sum += loss.item() * batch_size
            correct += (logits.argmax(1) == y).sum().item()
            total += batch_size
    return loss_sum / total, correct / total

# Fine-tune for a fixed number of epochs; each epoch is one training pass
# followed by one validation pass, with both sets of metrics logged.
EPOCHS = 3
for ep in range(1, EPOCHS + 1):
    tr_loss, tr_acc = run_epoch(train_dl, train=True)
    va_loss, va_acc = run_epoch(val_dl, train=False)
    print(f"[Epoch {ep}] train_loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={va_loss:.4f} acc={va_acc:.3f}")

# ===== Quick inference on a single validation batch =====
model.eval()
x, y = next(iter(val_dl))
x = x.to(DEVICE)
with torch.no_grad():
    logits = model(x)
# Predicted class = index of the largest logit; move to CPU for printing.
pred = logits.argmax(dim=1).cpu()
print("GT[:8]:", y[:8].tolist())
print("PR[:8]:", pred[:8].tolist())
