## 1. Import Modules and Data
ViT is pre-trained on ImageNet-1k, ImageNet-21k, and JFT, which are too large to reproduce here, and the details of pre-training differ from those of fine-tuning. We therefore only show how to fine-tune ViT on the CIFAR-10 dataset.

In [1]:
from data import load_data
import config

dataset = "cifar10"

train_dataloader, valid_dataloader, test_dataloader = load_data(
    dataset, splits=["train", "dev", "test"]
)

base_lr = config.base_lr[dataset]
# we will not use the following setting, see config.py
# total_steps = config.total_steps[dataset]

num_epochs = config.num_epochs
total_steps = num_epochs * len(train_dataloader)
num_epochs, total_steps

Files already downloaded and verified
Files already downloaded and verified


(7, 10724)

## 2. Build Model
As suggested in the original paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929), the pre-trained classifier should be replaced with a zero-initialized $D\times K$ feedforward layer, where $D$ is the hidden size and $K$ is the number of downstream classes.
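One consequence of zero initialization is easy to check in isolation: every logit starts at zero, so the initial prediction is uniform over the $K$ classes and the initial cross-entropy loss is $\ln K$. The sketch below is only an illustration; it uses random features as a stand-in for the ViT output rather than the real model.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

D, K = 768, 10  # hidden size of ViT-Base and number of CIFAR-10 classes
head = nn.Linear(D, K, bias=False)
nn.init.zeros_(head.weight)

# Random stand-ins for the encoder features and labels (illustration only).
features = torch.randn(4, D)
targets = torch.tensor([0, 3, 5, 9])

logits = head(features)  # all zeros, whatever the input is
loss = nn.functional.cross_entropy(logits, targets)

print(logits.abs().max().item())  # 0.0
print(loss.item(), math.log(K))   # both are ln(10), about 2.3026
```

A training loss that starts near 2.30 on CIFAR-10 is therefore expected and not a sign of a broken data pipeline.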

In [2]:
import torch
import torch.nn as nn
from modules import ViTForImageClassification

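# Fall back to the CPU when no GPU is available (fine-tuning will be slow there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")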
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model.classifier = nn.Linear(model.classifier.in_features, 10, bias=False)
nn.init.zeros_(model.classifier.weight)
model.to(device)

Loading weights from pretrained ViT: vit-base-patch16-224


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ModuleDict(
            (attention): SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ModuleDict(
              (dense): Linear(in_features=768, out_features=768, bias=True)
            )
          )
          (layernorm_before): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (output): ModuleDict(
            (dense): Linear(in_

## 3. Train Model
For fine-tuning, the paper uses SGD with a momentum of 0.9 and runs a small grid search over learning rates in $\{0.001, 0.003, 0.01, 0.03\}$. For the search, a small sub-split of the training set (2% for CIFAR) is held out as a development set and the model is trained on the remaining data. For the final results, the model is trained on the entire training set and evaluated on the test data.
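The internals of `load_data` are not shown in this notebook, so the following is only a sketch of how a 2% development split could be carved out with `torch.utils.data.random_split`. The `TensorDataset` of indices is a hypothetical stand-in for the 50,000-image CIFAR-10 training set.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the CIFAR-10 training set: 50,000 items, indices only.
full_train = TensorDataset(torch.arange(50_000))

dev_fraction = 0.02
dev_size = round(len(full_train) * dev_fraction)
train_size = len(full_train) - dev_size

# A seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(0)
train_subset, dev_subset = random_split(
    full_train, [train_size, dev_size], generator=generator
)

print(len(train_subset), len(dev_subset))  # 49000 1000
```

Fixing the seed matters here because every learning rate in the grid should be compared on the same development examples.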

All models are fine-tuned with cosine learning rate decay, a batch size of 512, no weight decay, and gradient clipping at global norm 1.
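To make these settings concrete, here is a minimal sketch on a tiny stand-in model (not ViT). It records the learning rate produced by `CosineAnnealingLR` at each update and clips gradients to a global norm of 1. The model, the dummy loss, and `num_updates` are illustrative assumptions. Note that `T_max` counts calls to `scheduler.step()`, that is, optimizer updates.

```python
import math

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Linear(4, 2)  # tiny stand-in model
base_lr, num_updates = 0.03, 100
optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_updates)

lrs = []
for step in range(num_updates):
    lrs.append(optimizer.param_groups[0]["lr"])
    # Dummy loss, scaled up so that the raw gradients tend to be large.
    loss = 100 * model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    # Rescale all gradients together so their global L2 norm is at most 1.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

# Global gradient norm of the last update, measured after clipping.
grad_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

# With eta_min = 0 the schedule follows 0.5 * base_lr * (1 + cos(pi * t / T_max)).
expected_mid = 0.5 * base_lr * (1 + math.cos(math.pi * 50 / num_updates))
print(lrs[0], lrs[50], expected_mid)
print(grad_norm.item())
```

The learning rate starts at `base_lr`, reaches half of it midway, and approaches zero at the end of the schedule.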

In [3]:
import torch.optim as optim
import config
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
initial_sd = {k: v.cpu() for k, v in model.state_dict().items()}


def evaluate(model, dataloader):
    model.eval()
    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating..."):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    accuracy = accuracy_score(all_targets, all_predictions)
    return accuracy


def train(model, lr, total_steps, optimizer, scheduler):
    model.train()
    current_step = 0
    criterion = nn.CrossEntropyLoss()
    pbar = tqdm(total=total_steps, desc=f"Training for lr={lr}")
    while current_step < total_steps:
        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            if current_step >= total_steps:
                break

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss = loss / config.accumulate_grad_batches
            loss.backward()

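            # Update after every `accumulate_grad_batches` micro-batches, and on
            # the final step so that leftover gradients are not discarded.
            if (
                batch_idx + 1
            ) % config.accumulate_grad_batches == 0 or current_step + 1 == total_steps: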
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            current_step += 1
            pbar.update(1)

    pbar.close()


def configure_optimizer(lr, total_steps):
    optimizer = optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
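    # scheduler.step() runs once per optimizer update, so T_max counts updates,
    # not micro-batches.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, total_steps // config.accumulate_grad_batches)
    )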
    return optimizer, scheduler
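The `train` function divides each loss by `config.accumulate_grad_batches` before calling `backward()`. The sketch below checks, on a small stand-in model, why that scaling is needed: gradients accumulated over equally sized micro-batches then match the gradient of one large batch. The sizes are illustrative assumptions; the values actually used live in `config.py`, which is not shown here. For example, 16 micro-batches of 32 would give the paper's batch size of 512.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(6, 3)  # small stand-in model
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(16, 6)
targets = torch.randint(0, 3, (16,))

# Reference: a single backward pass over the full batch of 16.
model.zero_grad()
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 4, each loss divided by the number of
# micro-batches so that the summed gradients average over all 16 examples.
accumulate_grad_batches = 4
model.zero_grad()
for x, y in zip(
    inputs.chunk(accumulate_grad_batches), targets.chunk(accumulate_grad_batches)
):
    (criterion(model(x), y) / accumulate_grad_batches).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True
```

The equivalence holds because `CrossEntropyLoss` averages over the batch by default; with unequal micro-batch sizes the match would only be approximate.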

We run a grid search over the learning rates suggested in the paper to find the best one. To save time, each candidate is trained for only a fraction of the full schedule; `search_ratio` sets that fraction of `total_steps`.

In [4]:
best_lr = 0
best_acc = 0

search_ratio = 0.05
search_steps = int(total_steps * search_ratio)

for lr in base_lr:
    model.load_state_dict({k: v.to(device) for k, v in initial_sd.items()})
    train(model, lr, search_steps, *configure_optimizer(lr, search_steps))
    acc = evaluate(model, valid_dataloader)
    print(f"Learning rate: {lr}, Validation Accuracy: {acc:.4f}")

    if acc > best_acc:
        best_acc = acc
        best_lr = lr

print(f"Best learning rate: {best_lr}")

Training for lr=0.001:   0%|          | 0/536 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/32 [00:00<?, ?it/s]

Learning rate: 0.001, Validation Accuracy: 0.8980


Training for lr=0.003:   0%|          | 0/536 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/32 [00:00<?, ?it/s]

Learning rate: 0.003, Validation Accuracy: 0.9320


Training for lr=0.01:   0%|          | 0/536 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/32 [00:00<?, ?it/s]

Learning rate: 0.01, Validation Accuracy: 0.9510


Training for lr=0.03:   0%|          | 0/536 [00:00<?, ?it/s]

Evaluating...:   0%|          | 0/32 [00:00<?, ?it/s]

Learning rate: 0.03, Validation Accuracy: 0.9620
Best learning rate: 0.03


A learning rate of 0.03 gives the best validation accuracy in this search, so we use it to train on the full training set (the training and development splits combined).

In [None]:
from torch.utils.data import ConcatDataset, DataLoader

best_lr = 0.03
train_dataloader = DataLoader(
    ConcatDataset([train_dataloader.dataset, valid_dataloader.dataset]),
    batch_size=config.batch_size,
    shuffle=True,
)

model.load_state_dict({k: v.to(device) for k, v in initial_sd.items()})
train(model, best_lr, total_steps, *configure_optimizer(best_lr, total_steps))
evaluate(model, test_dataloader)