# In [None]:
# resnet_minimal_transfer.py
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.models import ResNet18_Weights

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# ===== 1) Quick data: FakeData (3 classes, 224x224 images) =====
# Use the preprocessing preset that ships with the pretrained ResNet-18 weights.
weights = ResNet18_Weights.DEFAULT
preprocess = weights.transforms()  # resize/crop, tensor conversion, normalization, etc.

NUM_CLASSES = 3
train_ds = datasets.FakeData(
    size=600,
    image_size=(3, 224, 224),
    num_classes=NUM_CLASSES,
    transform=preprocess,
)
# FakeData derives every sample from the seed `index + random_offset`. With the
# default offset of 0 on both datasets, the validation samples would be exact
# copies of the first training samples, so shift the validation seeds past the
# training range to keep the two splits disjoint.
val_ds = datasets.FakeData(
    size=120,
    image_size=(3, 224, 224),
    num_classes=NUM_CLASSES,
    transform=preprocess,
    random_offset=len(train_ds),
)

# Shuffle only the training split; validation order stays fixed.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

# ===== 2) Model: pretrained ResNet-18 with the final fc layer replaced =====
model = models.resnet18(weights=weights)  # backbone loaded with the pretrained weights
# Swap the classifier head for one sized to our own number of classes.
in_feats = model.fc.in_features
model.fc = nn.Linear(in_features=in_feats, out_features=NUM_CLASSES)
# Move the model only after the new head is attached so the head lands on
# DEVICE too; nn.Module.to() updates the module in place.
model.to(DEVICE)

# (Optional) Freeze the backbone stages when the dataset is small:
# for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
#     for p in stage.parameters():
#         p.requires_grad = False
# -> When freezing, the optimizer should only receive model.fc.parameters().

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# ===== 3) Compact train/eval loop =====
def run_epoch(dl, train=True):
    """Run one full pass over ``dl`` and return ``(mean_loss, accuracy)``.

    Relies on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``DEVICE``.

    Args:
        dl: Iterable yielding ``(inputs, targets)`` batches.
        train: When true, switch the model to training mode and update its
            weights after every batch. When false, switch it to eval mode and
            only measure loss and accuracy.

    Returns:
        A ``(loss, accuracy)`` tuple, both averaged per sample over the pass.

    Raises:
        ValueError: If ``dl`` yields no samples.
    """
    model.train(train)
    total, correct, loss_sum = 0, 0, 0.0
    # Autograd is only needed when we backpropagate. Turning it off for the
    # evaluation pass avoids building a graph for every batch.
    with torch.set_grad_enabled(train):
        for x, y in dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            if train:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                loss.backward()
                optimizer.step()
            batch_size = x.size(0)
            # Weight the batch loss by its size so the final figure is a
            # per-sample average (assumes criterion uses mean reduction).
            loss_sum += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_size
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("run_epoch received a data loader with no samples")
    return loss_sum / total, correct / total

EPOCHS = 3
# Each epoch is one training pass followed by one validation pass; the tuple
# is evaluated left to right, so training always runs before validation.
for ep in range(1, EPOCHS + 1):
    (tr_loss, tr_acc), (va_loss, va_acc) = (
        run_epoch(train_dl, train=True),
        run_epoch(val_dl, train=False),
    )
    print(f"[Epoch {ep}] train_loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={va_loss:.4f} acc={va_acc:.3f}")

# ===== 4) Inference: predict a single batch of images =====
model.eval()
x, y = next(iter(val_dl))  # first validation batch; y stays on the CPU for printing
with torch.no_grad():
    x = x.to(DEVICE)
    logits = model(x)
    # Predicted class = index of the largest logit; bring it back to the CPU.
    pred = logits.argmax(dim=1).cpu()

truth_head = y[:8].tolist()
pred_head = pred[:8].tolist()
print("Ground truth (first 8):", truth_head)
print("Predictions  (first 8):", pred_head)
