In [1]:
import torch
import torchvision
from torchvision.models import vit_b_16, ViT_B_16_Weights
import numpy as np

In [2]:
# Choose the compute backend in order of preference: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    # MPS is only probed when CUDA is unavailable, as in an if/elif chain.
    dev = "mps" if torch.backends.mps.is_available() else "cpu"

device = torch.device(dev)
device

device(type='mps')

In [3]:
# Pretrained checkpoint; it is used both to initialise the model and to obtain
# the matching input transforms.
weights = ViT_B_16_Weights.IMAGENET1K_V1
# Alternative checkpoint (disabled):
# weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1

# Preprocessing transforms bundled with the chosen weights; passed as `transform`
# to the CIFAR-10 datasets.
preprocess = weights.transforms()

load CIFAR-10

In [4]:
# Both splits share the storage location, download flag and preprocessing;
# only the `train` flag differs.
cifar_kwargs = dict(root="./data", download=True, transform=preprocess)
trainset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

Optionally use a random subset of the training set for faster training (this makes little sense when training for more than one epoch).

In [5]:
# subset_size = 1000
# indices = np.random.choice(len(trainset), subset_size, replace=False)
# trainset = torch.utils.data.Subset(trainset, indices)

In [6]:
batch_size = 64
# Settings common to both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

In [7]:
# Build ViT-B/16 initialised from the pretrained `weights` selected earlier.
model = vit_b_16(weights=weights)
model  # display the module tree

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

modify the last classification layer to output 10 classes

In [8]:
# Swap the classification layer for a freshly initialised one with 10 outputs,
# keeping the input width of the layer it replaces.
num_classes = 10
old_head = model.heads.head
model.heads.head = torch.nn.Linear(old_head.in_features, num_classes)
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

freeze all layers except classifier head

(Fine-tuning the backbone as well is better, but training then takes much longer.)

In [9]:
# First switch gradients off for the whole network, then switch them back on
# for the classifier head only. Order matters: the head is part of `model`.
for module, trainable in ((model, False), (model.heads.head, True)):
    module.requires_grad_(trainable)

In [10]:
# Number of parameters to be fitted, i.e. those that still require gradients.
trainable_params = (p for p in model.parameters() if p.requires_grad)
sum(p.numel() for p in trainable_params)

7690

In [11]:
# Hand the optimizer only the parameters that are actually trained. The frozen
# backbone never receives gradients, so registering it would just make every
# step()/zero_grad() iterate over parameters that are never updated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params)

In [12]:
# Move the model to the device selected at the top of the notebook.
model = model.to(device)

In [13]:
def fit(epochs, model, optimizer, train_dl):
    """Train `model` with cross-entropy loss and report the mean loss per epoch.

    Args:
        epochs: Number of full passes over `train_dl`.
        model: Module mapping an input batch to class logits.
        optimizer: Optimizer already constructed over the parameters to update.
        train_dl: Iterable yielding `(inputs, targets)` mini-batches.

    Returns:
        The same `model` object, trained in place.
    """
    loss_func = torch.nn.CrossEntropyLoss()

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # loop over epochs
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0

        # loop over mini-batches
        for X_mb, y_mb in train_dl:
            X_mb, y_mb = X_mb.to(target_device), y_mb.to(target_device)

            # Clear gradients first so anything accumulated before this call
            # cannot leak into the first update.
            optimizer.zero_grad()

            y_hat = model(X_mb)
            loss = loss_func(y_hat, y_mb)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually seen from `train_dl` (not a global
        # loader); an empty loader reports NaN instead of dividing by zero.
        mean_loss = running_loss / n_batches if n_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")

    return model

In [14]:
# Run training for `epochs` passes over the training loader.
epochs = 3
model = fit(epochs, model, optimizer, trainloader)

Epoch [1/3], Loss: 0.2424
Epoch [2/3], Loss: 0.1471
Epoch [3/3], Loss: 0.1291


evaluate

In [15]:
# Measure top-1 accuracy on the test loader with gradients disabled.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, targets in testloader:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images)
        # The predicted class is the index of the largest logit in each row.
        predicted = logits.argmax(dim=1)
        correct += predicted.eq(targets).sum().item()
        total += targets.shape[0]

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 95.08%
