# Finetuning the Vision Transformer (ViT)

https://huggingface.co/docs/transformers/v4.27.1/model_doc/vit 

## Data

CIFAR10 dataset has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

In [1]:
from datasets import load_dataset

cifar10dataset = load_dataset("cifar10", split="train[:5000]")

In [2]:
#print(type(cifar10dataset))
#print(cifar10dataset.keys())

In [3]:
cifar10dataset = cifar10dataset.train_test_split(test_size=0.2)

In [4]:
cifar10dataset["train"][0]

{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
 'label': 2}

In [5]:
print(type(cifar10dataset))
print(cifar10dataset.keys())
print(type(cifar10dataset["train"]))
print(type(cifar10dataset["train"][0]))
print(cifar10dataset["train"][0].keys())

<class 'datasets.dataset_dict.DatasetDict'>
dict_keys(['train', 'test'])
<class 'datasets.arrow_dataset.Dataset'>
<class 'dict'>
dict_keys(['img', 'label'])


In [6]:
labels = cifar10dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [7]:
id2label[str(9)]

'truck'

## Training

In [8]:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

2023-08-10 14:55:18.724071: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [10]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["img"]]
    del examples["img"]
    return examples

In [11]:
cifar10dataset = cifar10dataset.with_transform(transforms)

In [12]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

In [13]:
import evaluate

accuracy = evaluate.load("accuracy")

In [14]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [15]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
training_args = TrainingArguments(
    output_dir="cifar10_vit",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8, # memory error with 16
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=cifar10dataset["train"],
    eval_dataset=cifar10dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()



  0%|          | 0/375 [00:00<?, ?it/s]

{'loss': 2.2843, 'learning_rate': 1.3157894736842106e-05, 'epoch': 0.08}
{'loss': 2.2422, 'learning_rate': 2.6315789473684212e-05, 'epoch': 0.16}
{'loss': 2.1939, 'learning_rate': 3.9473684210526316e-05, 'epoch': 0.24}
{'loss': 2.0765, 'learning_rate': 4.9703264094955494e-05, 'epoch': 0.32}
{'loss': 1.9237, 'learning_rate': 4.821958456973294e-05, 'epoch': 0.4}
{'loss': 1.789, 'learning_rate': 4.673590504451038e-05, 'epoch': 0.48}
{'loss': 1.5902, 'learning_rate': 4.525222551928784e-05, 'epoch': 0.56}
{'loss': 1.4639, 'learning_rate': 4.3768545994065286e-05, 'epoch': 0.64}
{'loss': 1.3287, 'learning_rate': 4.228486646884273e-05, 'epoch': 0.72}
{'loss': 1.2626, 'learning_rate': 4.080118694362018e-05, 'epoch': 0.8}
{'loss': 1.1657, 'learning_rate': 3.9317507418397627e-05, 'epoch': 0.88}
{'loss': 1.1251, 'learning_rate': 3.783382789317508e-05, 'epoch': 0.96}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 1.0026872158050537, 'eval_accuracy': 0.876, 'eval_runtime': 19.1198, 'eval_samples_per_second': 52.302, 'eval_steps_per_second': 6.538, 'epoch': 1.0}
{'loss': 0.9672, 'learning_rate': 3.635014836795252e-05, 'epoch': 1.04}
{'loss': 0.9249, 'learning_rate': 3.4866468842729974e-05, 'epoch': 1.12}
{'loss': 0.9008, 'learning_rate': 3.338278931750742e-05, 'epoch': 1.2}
{'loss': 0.841, 'learning_rate': 3.189910979228487e-05, 'epoch': 1.28}
{'loss': 0.8292, 'learning_rate': 3.0415430267062318e-05, 'epoch': 1.36}
{'loss': 0.8332, 'learning_rate': 2.8931750741839762e-05, 'epoch': 1.44}
{'loss': 0.752, 'learning_rate': 2.744807121661721e-05, 'epoch': 1.52}
{'loss': 0.7919, 'learning_rate': 2.5964391691394662e-05, 'epoch': 1.6}
{'loss': 0.7498, 'learning_rate': 2.4480712166172106e-05, 'epoch': 1.68}
{'loss': 0.6989, 'learning_rate': 2.2997032640949558e-05, 'epoch': 1.76}
{'loss': 0.6843, 'learning_rate': 2.1513353115727002e-05, 'epoch': 1.84}
{'loss': 0.6911, 'learning_rate': 2.00296

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.6348932981491089, 'eval_accuracy': 0.901, 'eval_runtime': 19.2151, 'eval_samples_per_second': 52.042, 'eval_steps_per_second': 6.505, 'epoch': 2.0}
{'loss': 0.6432, 'learning_rate': 1.706231454005935e-05, 'epoch': 2.08}
{'loss': 0.5861, 'learning_rate': 1.5578635014836794e-05, 'epoch': 2.16}
{'loss': 0.6165, 'learning_rate': 1.4094955489614246e-05, 'epoch': 2.24}
{'loss': 0.6443, 'learning_rate': 1.2611275964391692e-05, 'epoch': 2.32}
{'loss': 0.5798, 'learning_rate': 1.112759643916914e-05, 'epoch': 2.4}
{'loss': 0.533, 'learning_rate': 9.643916913946588e-06, 'epoch': 2.48}
{'loss': 0.5872, 'learning_rate': 8.160237388724036e-06, 'epoch': 2.56}
{'loss': 0.5816, 'learning_rate': 6.676557863501484e-06, 'epoch': 2.64}
{'loss': 0.5828, 'learning_rate': 5.192878338278932e-06, 'epoch': 2.72}
{'loss': 0.5558, 'learning_rate': 3.7091988130563796e-06, 'epoch': 2.8}
{'loss': 0.5587, 'learning_rate': 2.225519287833828e-06, 'epoch': 2.88}
{'loss': 0.5291, 'learning_rate': 7.4183976

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.5763659477233887, 'eval_accuracy': 0.906, 'eval_runtime': 19.4613, 'eval_samples_per_second': 51.384, 'eval_steps_per_second': 6.423, 'epoch': 3.0}
{'train_runtime': 625.2, 'train_samples_per_second': 19.194, 'train_steps_per_second': 0.6, 'train_loss': 1.0144184761047363, 'epoch': 3.0}


TrainOutput(global_step=375, training_loss=1.0144184761047363, metrics={'train_runtime': 625.2, 'train_samples_per_second': 19.194, 'train_steps_per_second': 0.6, 'train_loss': 1.0144184761047363, 'epoch': 3.0})

Pretty good accyracy for only a subset of dataset and 3 epochs.