In [3]:
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

transform = transforms.Compose([ToTensor()])

train_data = datasets.cifar.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

testset = datasets.cifar.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)


  1%|          | 983k/170M [00:02<07:41, 367kB/s]  


KeyboardInterrupt: 

In [None]:
from torch.utils.data import DataLoader

train_loader= DataLoader(train_data, batch_size=64, shuffle=True)
test_loader= DataLoader(testset, batch_size=64, shuffle=False)

In [3]:
for images, labels in train_loader:
    print(images.shape)
    print(labels.shape)
    break   

torch.Size([64, 3, 32, 32])
torch.Size([64])


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Resize
%pip install timm
import timm


Note: you may need to restart the kernel to use updated packages.


In [14]:
transform = transforms.Compose([
    Resize((224,224)),
    ToTensor()
])
train_data = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=320, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=320, shuffle=False)


In [15]:
model = timm.create_model('vit_tiny_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 10)   # CIFAR-10 has 10 classes
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)


In [17]:
def train_epoch(model, loader, optimizer, criterion, device,verbose=True):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_loader):

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        if verbose:
            print(f"Batch {batch_idx}: Loss={loss.item():.4f}")
    return total_loss / len(loader), correct / total


In [18]:
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(loader), correct / total


In [20]:
EPOCHS = 5

for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")


Batch 0: Loss=5.0013
Batch 1: Loss=3.3481
Batch 2: Loss=3.3477
Batch 3: Loss=3.1672
Batch 4: Loss=2.5967
Batch 5: Loss=2.3636
Batch 6: Loss=2.2329
Batch 7: Loss=2.1767
Batch 8: Loss=2.2170
Batch 9: Loss=2.0879
Batch 10: Loss=2.0011
Batch 11: Loss=1.9630
Batch 12: Loss=1.9251
Batch 13: Loss=1.7881
Batch 14: Loss=1.8431
Batch 15: Loss=1.7435
Batch 16: Loss=1.5356
Batch 17: Loss=1.4781
Batch 18: Loss=1.5184
Batch 19: Loss=1.3247
Batch 20: Loss=1.3933
Batch 21: Loss=1.3327
Batch 22: Loss=1.2962
Batch 23: Loss=1.1718
Batch 24: Loss=1.0678
Batch 25: Loss=1.0914
Batch 26: Loss=1.2115
Batch 27: Loss=0.9798
Batch 28: Loss=0.9929
Batch 29: Loss=0.9630
Batch 30: Loss=0.8552
Batch 31: Loss=0.8126
Batch 32: Loss=0.8144
Batch 33: Loss=0.8233
Batch 34: Loss=0.7232
Batch 35: Loss=0.5960
Batch 36: Loss=0.6709
Batch 37: Loss=0.5340
Batch 38: Loss=0.5771
Batch 39: Loss=0.6147
Batch 40: Loss=0.5066
Batch 41: Loss=0.6037
Batch 42: Loss=0.5128
Batch 43: Loss=0.6193
Batch 44: Loss=0.4815
Batch 45: Loss=0.541

In [1]:
torch.save(model.state_dict(), "./models/vit_cifar10.pth")


NameError: name 'torch' is not defined