In [1]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.models import ViT_B_16_Weights, vit_b_16
from torchvision.datasets.cifar import CIFAR10

# Seed NumPy's global RNG and PyTorch's default generator with the same value
# so that shuffling and weight initialisation repeat across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7a75327b9270>

#### Classes

#### Train

In [28]:
# Drop a model left over from an earlier run of this notebook, if there is
# one, so its memory can be reclaimed before a new model is built below.
# The guard keeps this cell runnable on a fresh kernel, where `model` does not
# exist yet and a bare `del model` would raise NameError.
if "model" in globals():
    del model
# Hand cached, currently unused GPU memory back to the CUDA driver.
torch.cuda.empty_cache()

In [36]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [46]:
# Number of output classes for the replacement classification head.
NUM_CLASSES = 10

# Load ViT-B/16 with ImageNet-1k weights. `weights` is passed by keyword:
# torchvision deprecated the positional form and warns when it is used.
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Swap the pretrained head for a fresh linear layer sized for NUM_CLASSES.
model.heads = nn.Sequential(
    nn.Linear(model.heads.head.in_features, NUM_CLASSES)
)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the last encoder layer and the head
for param in model.encoder.layers[-1].parameters():
    param.requires_grad = True
for param in model.heads.parameters():
    param.requires_grad = True

In [None]:
# Move the model onto `device`; the training and test cells send every batch
# to this same device before calling the model.
model.to(device)

In [48]:
# Preprocessing and loading settings shared by the train and test splits.
IMAGE_SIZE = 224
# NOTE(review): these look like CIFAR-10 channel statistics, while the
# backbone is loaded with ImageNet-pretrained weights — confirm this is intended.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)
BATCH_SIZE = 64


def _make_transform():
    """Return the resize -> tensor -> normalise pipeline used by both splits."""
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
    return transforms.Compose(steps)


# Both splits currently use the same deterministic preprocessing; they stay
# separate objects so augmentation can later be added to training only.
transform_train = _make_transform()
transform_test = _make_transform()

train_set = CIFAR10(
    root='./datasets', train=True, download=True, transform=transform_train
)
test_set = CIFAR10(
    root='./datasets', train=False, download=True, transform=transform_test
)

# Shuffle only the training split; evaluation order is kept fixed.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [49]:
n_epochs = 1
lr = 0.0001

# Make sure the model is on the same device the batches are sent to, so this
# cell does not depend on a separate cell having been run first.
model.to(device)

# Give the optimizer only the parameters that are actually trainable (the
# last encoder layer and the head); frozen parameters never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=lr)
criterion = CrossEntropyLoss()

# Put the model in training mode explicitly (matters after an evaluation pass).
model.train()

for epoch in range(n_epochs):
    loss_sum = 0.0  # sum of per-sample losses over the epoch
    n_seen = 0      # number of samples processed so far
    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight each batch by its size so a short final batch does not skew
        # the epoch average.
        loss_sum += batch_loss * y.size(0)
        n_seen += y.size(0)

        if i % 100 == 0:
            print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    train_loss = loss_sum / max(n_seen, 1)
    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")

Batch 0/782 loss: 2.497
Batch 100/782 loss: 0.111
Batch 200/782 loss: 0.240
Batch 300/782 loss: 0.091
Batch 400/782 loss: 0.153
Batch 500/782 loss: 0.083
Batch 600/782 loss: 0.106
Batch 700/782 loss: 0.173
Epoch 1/1 loss: 0.202


In [50]:
# Test loop
# Switch to evaluation mode so train-only layers (e.g. dropout) are disabled.
model.eval()

correct, total = 0, 0
loss_sum = 0.0  # sum of per-sample losses over the whole test set

# Gradients are not needed for evaluation.
with torch.no_grad():
    for x, y in tqdm(test_loader, desc="Testing"):
        x, y = x.to(device), y.to(device)
        y_hat = model(x)

        # `criterion` returns the batch mean; multiplying by the batch size
        # gives a per-sample sum, so the short final batch is weighted correctly.
        loss_sum += criterion(y_hat, y).item() * y.size(0)

        correct += (y_hat.argmax(dim=1) == y).sum().item()
        total += y.size(0)

# max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
test_loss = loss_sum / max(total, 1)
print(f"Test loss: {test_loss:.2f}")
print(f"Test accuracy: {correct / max(total, 1) * 100:.2f}%")

Testing: 100%|██████████| 157/157 [02:08<00:00,  1.22it/s]

Test loss: 0.13
Test accuracy: 95.67%



