# Finetuning the Vision Transformer (ViT)

https://huggingface.co/docs/transformers/v4.27.1/model_doc/vit 

## Data

CIFAR10 dataset has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

In [1]:
from datasets import load_dataset

cifar10dataset = load_dataset("cifar10", split="train[:5000]")
cifar10dataset = cifar10dataset.train_test_split(test_size=0.2)

In [2]:
#print(type(cifar10dataset))
#print(cifar10dataset.keys())

In [4]:
cifar10dataset["train"][0]

{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
 'label': 2}

In [5]:
print(type(cifar10dataset))
print(cifar10dataset.keys())
print(type(cifar10dataset["train"]))
print(type(cifar10dataset["train"][0]))
print(cifar10dataset["train"][0].keys())

<class 'datasets.dataset_dict.DatasetDict'>
dict_keys(['train', 'test'])
<class 'datasets.arrow_dataset.Dataset'>
<class 'dict'>
dict_keys(['img', 'label'])


In [6]:
labels = cifar10dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [7]:
id2label[str(9)]

'truck'

## Training on a sample dataset

In [8]:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

2023-08-10 14:55:18.724071: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["img"]]
    del examples["img"]
    return examples

cifar10dataset = cifar10dataset.with_transform(transforms)

In [12]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [13]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [15]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
training_args = TrainingArguments(
    output_dir="cifar10_vit",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8, # memory error with 16
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cifar10dataset["train"],
    eval_dataset=cifar10dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()



  0%|          | 0/375 [00:00<?, ?it/s]

{'loss': 2.2843, 'learning_rate': 1.3157894736842106e-05, 'epoch': 0.08}
{'loss': 2.2422, 'learning_rate': 2.6315789473684212e-05, 'epoch': 0.16}
{'loss': 2.1939, 'learning_rate': 3.9473684210526316e-05, 'epoch': 0.24}
{'loss': 2.0765, 'learning_rate': 4.9703264094955494e-05, 'epoch': 0.32}
{'loss': 1.9237, 'learning_rate': 4.821958456973294e-05, 'epoch': 0.4}
{'loss': 1.789, 'learning_rate': 4.673590504451038e-05, 'epoch': 0.48}
{'loss': 1.5902, 'learning_rate': 4.525222551928784e-05, 'epoch': 0.56}
{'loss': 1.4639, 'learning_rate': 4.3768545994065286e-05, 'epoch': 0.64}
{'loss': 1.3287, 'learning_rate': 4.228486646884273e-05, 'epoch': 0.72}
{'loss': 1.2626, 'learning_rate': 4.080118694362018e-05, 'epoch': 0.8}
{'loss': 1.1657, 'learning_rate': 3.9317507418397627e-05, 'epoch': 0.88}
{'loss': 1.1251, 'learning_rate': 3.783382789317508e-05, 'epoch': 0.96}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 1.0026872158050537, 'eval_accuracy': 0.876, 'eval_runtime': 19.1198, 'eval_samples_per_second': 52.302, 'eval_steps_per_second': 6.538, 'epoch': 1.0}
{'loss': 0.9672, 'learning_rate': 3.635014836795252e-05, 'epoch': 1.04}
{'loss': 0.9249, 'learning_rate': 3.4866468842729974e-05, 'epoch': 1.12}
{'loss': 0.9008, 'learning_rate': 3.338278931750742e-05, 'epoch': 1.2}
{'loss': 0.841, 'learning_rate': 3.189910979228487e-05, 'epoch': 1.28}
{'loss': 0.8292, 'learning_rate': 3.0415430267062318e-05, 'epoch': 1.36}
{'loss': 0.8332, 'learning_rate': 2.8931750741839762e-05, 'epoch': 1.44}
{'loss': 0.752, 'learning_rate': 2.744807121661721e-05, 'epoch': 1.52}
{'loss': 0.7919, 'learning_rate': 2.5964391691394662e-05, 'epoch': 1.6}
{'loss': 0.7498, 'learning_rate': 2.4480712166172106e-05, 'epoch': 1.68}
{'loss': 0.6989, 'learning_rate': 2.2997032640949558e-05, 'epoch': 1.76}
{'loss': 0.6843, 'learning_rate': 2.1513353115727002e-05, 'epoch': 1.84}
{'loss': 0.6911, 'learning_rate': 2.00296

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.6348932981491089, 'eval_accuracy': 0.901, 'eval_runtime': 19.2151, 'eval_samples_per_second': 52.042, 'eval_steps_per_second': 6.505, 'epoch': 2.0}
{'loss': 0.6432, 'learning_rate': 1.706231454005935e-05, 'epoch': 2.08}
{'loss': 0.5861, 'learning_rate': 1.5578635014836794e-05, 'epoch': 2.16}
{'loss': 0.6165, 'learning_rate': 1.4094955489614246e-05, 'epoch': 2.24}
{'loss': 0.6443, 'learning_rate': 1.2611275964391692e-05, 'epoch': 2.32}
{'loss': 0.5798, 'learning_rate': 1.112759643916914e-05, 'epoch': 2.4}
{'loss': 0.533, 'learning_rate': 9.643916913946588e-06, 'epoch': 2.48}
{'loss': 0.5872, 'learning_rate': 8.160237388724036e-06, 'epoch': 2.56}
{'loss': 0.5816, 'learning_rate': 6.676557863501484e-06, 'epoch': 2.64}
{'loss': 0.5828, 'learning_rate': 5.192878338278932e-06, 'epoch': 2.72}
{'loss': 0.5558, 'learning_rate': 3.7091988130563796e-06, 'epoch': 2.8}
{'loss': 0.5587, 'learning_rate': 2.225519287833828e-06, 'epoch': 2.88}
{'loss': 0.5291, 'learning_rate': 7.4183976

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.5763659477233887, 'eval_accuracy': 0.906, 'eval_runtime': 19.4613, 'eval_samples_per_second': 51.384, 'eval_steps_per_second': 6.423, 'epoch': 3.0}
{'train_runtime': 625.2, 'train_samples_per_second': 19.194, 'train_steps_per_second': 0.6, 'train_loss': 1.0144184761047363, 'epoch': 3.0}


TrainOutput(global_step=375, training_loss=1.0144184761047363, metrics={'train_runtime': 625.2, 'train_samples_per_second': 19.194, 'train_steps_per_second': 0.6, 'train_loss': 1.0144184761047363, 'epoch': 3.0})

Pretty good accyracy for only a subset of dataset and 3 epochs. Now lets train the model on full dataset and see what accuracy we acheave.

## Training on full dataset

In [1]:
from datasets import load_dataset

cifar10dataset = load_dataset("cifar10", split="train")
cifar10dataset = cifar10dataset.train_test_split(test_size=0.2)

labels = cifar10dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [2]:
from transformers import AutoImageProcessor
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import DefaultDataCollator
import evaluate
import numpy as np


checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])


def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["img"]]
    del examples["img"]
    return examples

cifar10dataset = cifar10dataset.with_transform(transforms)

data_collator = DefaultDataCollator()


accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

2023-08-11 13:22:23.367673: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)


training_args = TrainingArguments(
    output_dir="cifar10_vit_fulldataset",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8, # memory error with 16
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cifar10dataset["train"],
    eval_dataset=cifar10dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/7500 [00:00<?, ?it/s]

{'loss': 2.2828, 'learning_rate': 6.666666666666667e-07, 'epoch': 0.01}
{'loss': 2.2906, 'learning_rate': 1.3333333333333334e-06, 'epoch': 0.02}
{'loss': 2.278, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.02}
{'loss': 2.2837, 'learning_rate': 2.666666666666667e-06, 'epoch': 0.03}
{'loss': 2.2677, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.04}
{'loss': 2.2653, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.05}
{'loss': 2.24, 'learning_rate': 4.666666666666667e-06, 'epoch': 0.06}
{'loss': 2.2499, 'learning_rate': 5.333333333333334e-06, 'epoch': 0.06}
{'loss': 2.2307, 'learning_rate': 6e-06, 'epoch': 0.07}
{'loss': 2.2147, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.08}
{'loss': 2.1864, 'learning_rate': 7.333333333333334e-06, 'epoch': 0.09}
{'loss': 2.172, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.1}
{'loss': 2.1325, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.1}
{'loss': 2.1309, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.11}
{'loss': 2.

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.3983166813850403, 'eval_accuracy': 0.8878, 'eval_runtime': 187.3272, 'eval_samples_per_second': 53.383, 'eval_steps_per_second': 6.673, 'epoch': 1.0}
{'loss': 0.4203, 'learning_rate': 4.6222222222222224e-05, 'epoch': 1.01}
{'loss': 0.3852, 'learning_rate': 4.6148148148148154e-05, 'epoch': 1.02}
{'loss': 0.338, 'learning_rate': 4.607407407407408e-05, 'epoch': 1.02}
{'loss': 0.3507, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.03}
{'loss': 0.3259, 'learning_rate': 4.592592592592593e-05, 'epoch': 1.04}
{'loss': 0.3766, 'learning_rate': 4.585185185185185e-05, 'epoch': 1.05}
{'loss': 0.3715, 'learning_rate': 4.577777777777778e-05, 'epoch': 1.06}
{'loss': 0.4094, 'learning_rate': 4.5703703703703706e-05, 'epoch': 1.06}
{'loss': 0.3696, 'learning_rate': 4.5629629629629636e-05, 'epoch': 1.07}
{'loss': 0.4317, 'learning_rate': 4.555555555555556e-05, 'epoch': 1.08}
{'loss': 0.3515, 'learning_rate': 4.548148148148149e-05, 'epoch': 1.09}
{'loss': 0.4006, 'learning_rate': 4.540

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.323506236076355, 'eval_accuracy': 0.9005, 'eval_runtime': 186.8173, 'eval_samples_per_second': 53.528, 'eval_steps_per_second': 6.691, 'epoch': 2.0}
{'loss': 0.2986, 'learning_rate': 3.6962962962962966e-05, 'epoch': 2.01}
{'loss': 0.2455, 'learning_rate': 3.688888888888889e-05, 'epoch': 2.02}
{'loss': 0.2893, 'learning_rate': 3.681481481481482e-05, 'epoch': 2.02}
{'loss': 0.3079, 'learning_rate': 3.674074074074074e-05, 'epoch': 2.03}
{'loss': 0.1966, 'learning_rate': 3.6666666666666666e-05, 'epoch': 2.04}
{'loss': 0.237, 'learning_rate': 3.6592592592592596e-05, 'epoch': 2.05}
{'loss': 0.2133, 'learning_rate': 3.651851851851852e-05, 'epoch': 2.06}
{'loss': 0.2126, 'learning_rate': 3.644444444444445e-05, 'epoch': 2.06}
{'loss': 0.2996, 'learning_rate': 3.637037037037037e-05, 'epoch': 2.07}
{'loss': 0.1878, 'learning_rate': 3.62962962962963e-05, 'epoch': 2.08}
{'loss': 0.2906, 'learning_rate': 3.6222222222222225e-05, 'epoch': 2.09}
{'loss': 0.2003, 'learning_rate': 3.61481

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.24086838960647583, 'eval_accuracy': 0.9237, 'eval_runtime': 187.9114, 'eval_samples_per_second': 53.217, 'eval_steps_per_second': 6.652, 'epoch': 3.0}
{'loss': 0.2268, 'learning_rate': 2.7703703703703706e-05, 'epoch': 3.01}
{'loss': 0.1756, 'learning_rate': 2.7629629629629632e-05, 'epoch': 3.02}
{'loss': 0.1872, 'learning_rate': 2.7555555555555555e-05, 'epoch': 3.02}
{'loss': 0.2433, 'learning_rate': 2.7481481481481482e-05, 'epoch': 3.03}
{'loss': 0.2581, 'learning_rate': 2.7407407407407408e-05, 'epoch': 3.04}
{'loss': 0.2343, 'learning_rate': 2.733333333333333e-05, 'epoch': 3.05}
{'loss': 0.2189, 'learning_rate': 2.725925925925926e-05, 'epoch': 3.06}
{'loss': 0.2442, 'learning_rate': 2.7185185185185184e-05, 'epoch': 3.06}
{'loss': 0.1812, 'learning_rate': 2.7111111111111114e-05, 'epoch': 3.07}
{'loss': 0.1923, 'learning_rate': 2.7037037037037037e-05, 'epoch': 3.08}
{'loss': 0.2177, 'learning_rate': 2.696296296296296e-05, 'epoch': 3.09}
{'loss': 0.2621, 'learning_rate':

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.23980094492435455, 'eval_accuracy': 0.9211, 'eval_runtime': 190.1542, 'eval_samples_per_second': 52.589, 'eval_steps_per_second': 6.574, 'epoch': 4.0}
{'loss': 0.1608, 'learning_rate': 1.8444444444444445e-05, 'epoch': 4.01}
{'loss': 0.229, 'learning_rate': 1.837037037037037e-05, 'epoch': 4.02}
{'loss': 0.2165, 'learning_rate': 1.8296296296296298e-05, 'epoch': 4.02}
{'loss': 0.1501, 'learning_rate': 1.8222222222222224e-05, 'epoch': 4.03}
{'loss': 0.2129, 'learning_rate': 1.814814814814815e-05, 'epoch': 4.04}
{'loss': 0.2636, 'learning_rate': 1.8074074074074074e-05, 'epoch': 4.05}
{'loss': 0.1724, 'learning_rate': 1.8e-05, 'epoch': 4.06}
{'loss': 0.1948, 'learning_rate': 1.7925925925925927e-05, 'epoch': 4.06}
{'loss': 0.1536, 'learning_rate': 1.7851851851851853e-05, 'epoch': 4.07}
{'loss': 0.1855, 'learning_rate': 1.777777777777778e-05, 'epoch': 4.08}
{'loss': 0.1509, 'learning_rate': 1.7703703703703706e-05, 'epoch': 4.09}
{'loss': 0.1794, 'learning_rate': 1.7629629629629

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.2027403563261032, 'eval_accuracy': 0.9357, 'eval_runtime': 188.1285, 'eval_samples_per_second': 53.155, 'eval_steps_per_second': 6.644, 'epoch': 5.0}
{'loss': 0.1498, 'learning_rate': 9.185185185185186e-06, 'epoch': 5.01}
{'loss': 0.2139, 'learning_rate': 9.111111111111112e-06, 'epoch': 5.02}
{'loss': 0.173, 'learning_rate': 9.037037037037037e-06, 'epoch': 5.02}
{'loss': 0.1746, 'learning_rate': 8.962962962962963e-06, 'epoch': 5.03}
{'loss': 0.1617, 'learning_rate': 8.88888888888889e-06, 'epoch': 5.04}
{'loss': 0.169, 'learning_rate': 8.814814814814815e-06, 'epoch': 5.05}
{'loss': 0.162, 'learning_rate': 8.740740740740741e-06, 'epoch': 5.06}
{'loss': 0.1978, 'learning_rate': 8.666666666666668e-06, 'epoch': 5.06}
{'loss': 0.1938, 'learning_rate': 8.592592592592593e-06, 'epoch': 5.07}
{'loss': 0.1746, 'learning_rate': 8.518518518518519e-06, 'epoch': 5.08}
{'loss': 0.1394, 'learning_rate': 8.444444444444446e-06, 'epoch': 5.09}
{'loss': 0.1816, 'learning_rate': 8.3703703703

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 0.2116428166627884, 'eval_accuracy': 0.9343, 'eval_runtime': 188.6502, 'eval_samples_per_second': 53.008, 'eval_steps_per_second': 6.626, 'epoch': 6.0}
{'train_runtime': 12210.5226, 'train_samples_per_second': 19.655, 'train_steps_per_second': 0.614, 'train_loss': 0.33515866717497506, 'epoch': 6.0}


TrainOutput(global_step=7500, training_loss=0.33515866717497506, metrics={'train_runtime': 12210.5226, 'train_samples_per_second': 19.655, 'train_steps_per_second': 0.614, 'train_loss': 0.33515866717497506, 'epoch': 6.0})

We got the eval_accuracy of 0.9343 from training on full dataset.        

**The accuracy has been increased by approximately 3.12%, reflecting the improvement obtained through using more data and further tuning.**