In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
import timm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA reports itself available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# CIFAR-10 class labels (ten names, one per line).
cifar10_classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [3]:
# Build the pretrained ViT-Tiny (patch 16, 224x224 input) with a 10-way
# classification head, then move it onto the selected device.
model = timm.create_model(
    "vit_tiny_patch16_224",
    pretrained=True,
    num_classes=10,
)
model = model.to(device)

In [4]:
# Per-image preprocessing, applied in order:
#   1. resize to 224x224,
#   2. convert to a tensor,
#   3. normalize with mean 0.5 / std 0.5.
# NOTE(review): mean/std are single-element tuples; this presumably relies on
# torchvision broadcasting them over all channels — confirm against the docs.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [5]:
# CIFAR-10 train and test splits, downloaded into ./data when missing and
# passed through the shared `transform` pipeline.
train_data = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
test_data = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Batches of 64; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False,
)

In [6]:
# Training objective: cross-entropy between the model outputs and the labels.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# One pass over train_loader: update the weights on every batch and keep a
# running count of correct predictions to report training accuracy at the end.
print("Training...")
model.train()

train_correct = 0
train_total = 0

for x, y in train_loader:
    x = x.to(device)
    y = y.to(device)

    # Forward pass and loss for this batch.
    out = model(x)
    loss = criterion(out, y)

    # Clear stale gradients, backpropagate, then apply the update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Count the samples whose highest-scoring class equals the label.
    train_correct += (out.argmax(dim=1) == y).sum().item()
    train_total += y.shape[0]

# Fraction correct, expressed as a percentage.
train_acc = 100 * (train_correct / train_total)
print(f"Training Accuracy: {train_acc:.2f}%")

Training...


In [None]:
# Score the model on test_loader with gradient tracking switched off and
# report the percentage of correctly classified samples.
print("Evaluating...")
model.eval()

test_correct = 0
test_total = 0

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)

        out = model(x)

        # Count the samples whose highest-scoring class equals the label.
        test_correct += (out.argmax(dim=1) == y).sum().item()
        test_total += y.shape[0]

# Fraction correct, expressed as a percentage.
test_acc = 100 * (test_correct / test_total)
print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# The checkpoint is written to <current working directory>/saved_models/vit_cifar10.pth.
base_dir = os.getcwd()  # working directory of the running process
save_dir = os.path.join(base_dir, "saved_models")
model_path = os.path.join(save_dir, "vit_cifar10.pth")

# Create the target folder if needed; an already-existing folder is not an error.
os.makedirs(save_dir, exist_ok=True)

# Write out the model's state_dict.
torch.save(model.state_dict(), model_path)

print(f"Model saved successfully at relative path: {os.path.relpath(model_path)}")