Script used to FineTune a pretrained ViT TorchVision Model

In [1]:
import torchvision
import torch

from torchvision.io import read_image
from torchvision.models import *
from torch import nn
from collections import OrderedDict
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Initialize the model with pretrained weights (Image1k), then change the MLP head (classification) to a new one for our dataset

In [2]:
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)

#generate new head (for 10 classes)
heads: OrderedDict[str, nn.Linear] = OrderedDict()
heads["head"] = nn.Linear(model.hidden_dim, 10)
seq_heads = nn.Sequential(heads)
nn.init.zeros_(seq_heads.head.weight)
nn.init.zeros_(seq_heads.head.bias)

#add the head
model.heads = seq_heads
model.to(device)

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 65.4MB/s]


VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

Download dataset Cifar10, and prepare it for training

In [3]:
preprocess = weights.transforms(antialias=True)

training_set = torchvision.datasets.CIFAR10("Cifar10/train", transform=preprocess, train=True, download=True)
test_set = torchvision.datasets.CIFAR10("Cifar10/train", transform=preprocess, train=False, download=True)

train, valid =torch.utils.data.random_split(training_set, [45000,5000])

training_loader = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
validation_loader = torch.utils.data.DataLoader(valid, batch_size=8, shuffle=False)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to Cifar10/train/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 44022771.82it/s]


Extracting Cifar10/train/cifar-10-python.tar.gz to Cifar10/train
Files already downloaded and verified


Set params for training (changeable)

In [4]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0002, momentum=0.9)
epochs = 3

Functions to train one epoch and for accuracy

In [5]:
def accuracy_score(correct, total):
    return correct/total

In [8]:
def train_one_epoch(epoch_index):
    correct = 0
    instances = 45000

    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        #accuracy score
        out = torch.argmax(outputs, dim=1)
        correct += ((out == labels).count_nonzero().item())

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    #last batch (only 500 passes)
    last_loss = running_loss / 500
    print('  batch {} loss: {}'.format(4500, last_loss))
    return last_loss, accuracy_score(correct, instances)

Actual code to train (finetune) the whole ViT

In [7]:
epoch_number = 0
best_vloss = 1_000_000.

for epoch in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss, accuracy_epoch = train_one_epoch(epoch_number)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels = vdata
            vinputs, vlabels = vinputs.to(device), vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    print(f"Accuracy score: {accuracy_epoch}")
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}'.format(epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:


UnboundLocalError: ignored