In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms


# Seed torch's global RNG so Restart & Run All reproduces later random
# initialisation (e.g. the new classifier head created in the next cell).
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = './data/cifar-10-finetune'
BATCH_SIZE = 32
NUM_WORKERS = 4
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# DataLoader just warns about it, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

# Resize to the 224x224 input size of ViT-B/16 and normalise with the ImageNet
# channel statistics that the pretrained weights were trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# A dedicated seeded generator fixes the shuffle order independently of any
# other consumer of the global RNG.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# weight set it selected for vit_b_16.
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)

# Replace the ImageNet classification head with a fresh 10-way head for
# CIFAR-10 *before* moving the model, so one .to(device) covers every parameter.
model.heads = nn.Linear(model.heads[0].in_features, 10)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Every parameter is trained (full fine-tuning). Adam at lr=3e-3 is far higher
# than is typical for fine-tuning a pretrained ViT with Adam and risks
# destroying the pretrained features; lr=0.003 comes from the SGD + momentum
# 0.9 fine-tuning recipe used for ViT, so use that optimizer with it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)



In [3]:
def train(model, criterion, optimizer, loader, epochs=10, device=None):
    """Train ``model`` in place for ``epochs`` passes over ``loader``.

    Args:
        model: the ``nn.Module`` to optimise; it is switched to training mode.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        loader: iterable yielding ``(images, labels)`` batches.
        epochs: number of full passes over ``loader``.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, so the function does not rely on a
            notebook-global ``device`` variable.

    Returns:
        A list with the sample-weighted mean training loss of each epoch.

    Raises:
        ValueError: if ``loader`` yields no samples in an epoch.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        seen = 0
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Assumes criterion returns the per-batch mean (true for the default
            # nn.CrossEntropyLoss used here); weighting by batch size keeps a
            # short final batch from skewing the epoch average.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            seen += batch_size

        if seen == 0:
            raise ValueError("train: loader yielded no samples")
        epoch_loss = running_loss / seen
        history.append(epoch_loss)
        # Report the epoch mean rather than the last (noisy) batch's loss.
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
    return history

def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    The model is put in eval mode and gradients are disabled. The mode is
    not restored afterwards; ``train`` switches back explicitly.

    Args:
        model: classifier whose outputs are per-class scores along dim 1.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, so the function does not rely on a
            notebook-global ``device`` variable.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``loader`` yields no samples (previously a bare
            ``ZeroDivisionError``).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return 100 * correct / total

In [None]:
# Fine-tune for a single pass over the training split, then score the held-out split.
train(model=model, criterion=criterion, optimizer=optimizer, loader=train_loader, epochs=1)
accuracy = evaluate(model=model, loader=test_loader)
print(f"Test Accuracy: {accuracy}%")