Fine-tune ViT on CIFAR-100

# Necessary Libraries

In [1]:
%pip install datasets evaluate

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31

# Model Checkpoint

In [2]:
# ViT-Base/16 at 224px, pretrained on ImageNet-21k. It ships without a task head
# (see the "newly initialized" classifier warning when the model is loaded),
# so a fresh CIFAR-100 head is trained on top of it.
model_checkpt = "google/vit-base-patch16-224-in21k"

# Load Dataset

In [3]:
from datasets import load_dataset

# Pull CIFAR-100 from the Hugging Face Hub; the result is a DatasetDict with
# "train" (50,000 examples) and "test" (10,000 examples) splits.
cifar_100_ds = load_dataset(
    "cifar100",
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.98k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/119M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [4]:
import evaluate

# Accuracy metric consumed by `compute_metrics` when the Trainer evaluates
metric = evaluate.load(
    "accuracy",
)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [6]:
# Lookup tables between the fine-grained class names and their integer ids
fine_label_names = cifar_100_ds["train"].features["fine_label"].names
label2id = dict(zip(fine_label_names, range(len(fine_label_names))))
id2label = {class_id: class_name for class_name, class_id in label2id.items()}

# Pre-process Dataset

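Note: Data augmentation matters more for Vision Transformers than for CNNs. ViTs lack the built-in inductive biases of convolutions (locality, translation equivariance), so without augmentation they overfit small datasets more easily.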

In [7]:
from transformers import AutoImageProcessor

# Load the preprocessing config that belongs to the checkpoint, so our resize and
# normalization settings agree with what the model saw during pretraining.
image_processor = AutoImageProcessor.from_pretrained(
    model_checkpt,
)

image_processor  # show target size and mean/std for inspection

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


ViTImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [8]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Normalize color channels with the mean/std from the checkpoint's processor config
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Derive the model's input resolution from the processor config. It is expressed either
# as an explicit height/width pair or as a shortest edge (optionally with a longest edge).
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None

elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")  # NOTE(review): computed but not passed to Resize below

else:
    # Fail here with a clear message instead of a NameError on `size` further down
    raise ValueError(f"Unsupported image_processor.size format: {image_processor.size}")


# Training: random resized crop + horizontal flip for augmentation, then tensor + normalize
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation: deterministic resize + center crop so evaluation is repeatable
validation_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)

In [9]:
def _transform_images(example_batch, transforms):
    """
    Add a "pixel_values" column holding `transforms` applied to each RGB-converted image.

    The batch dict is updated in place and also returned.
    """
    example_batch["pixel_values"] = [transforms(image.convert("RGB")) for image in example_batch["img"]]
    return example_batch


def preprocess_train(example_batch):
    """
    Run the training (augmenting) transforms over every image in a batch.
    """
    return _transform_images(example_batch, train_transforms)


def preprocess_validation(example_batch):
    """
    Run the deterministic validation transforms over every image in a batch.
    """
    return _transform_images(example_batch, validation_transforms)

In [10]:
# Hold out 20% of the original CIFAR-100 training split as a validation set.
# A fixed seed makes the split reproducible across kernel restarts, and stratifying on
# the fine label keeps the class proportions identical in both subsets.
SPLIT_SEED = 42
splits = cifar_100_ds["train"].train_test_split(
    test_size=0.2,
    seed=SPLIT_SEED,
    stratify_by_column="fine_label",
)

train_ds = splits["train"]
val_ds = splits["test"]

# set_transform runs the preprocessing lazily, each time examples are accessed, instead of
# precomputing and storing every tensor the way `map` would. A side effect is that the
# random training augmentation is re-sampled on every access.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_validation)

In [11]:
# Sanity-check the lazy transform: an example should now carry a (channels, height, width)
# "pixel_values" tensor at the model's crop size.
first_example = train_ds[0]
assert "pixel_values" in first_example, "set_transform did not add pixel_values"
assert tuple(first_example["pixel_values"].shape) == (3, *crop_size)
first_example

{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
 'fine_label': 21,
 'coarse_label': 11,
 'pixel_values': tensor([[[-0.2784, -0.2784, -0.2784,  ..., -0.2627, -0.2627, -0.2627],
          [-0.2784, -0.2784, -0.2784,  ..., -0.2627, -0.2627, -0.2627],
          [-0.2784, -0.2784, -0.2784,  ..., -0.2627, -0.2627, -0.2627],
          ...,
          [-0.9451, -0.9451, -0.9451,  ..., -0.7725, -0.7725, -0.7725],
          [-0.9451, -0.9451, -0.9451,  ..., -0.7725, -0.7725, -0.7725],
          [-0.9451, -0.9451, -0.9451,  ..., -0.7725, -0.7725, -0.7725]],
 
         [[-0.3255, -0.3255, -0.3255,  ..., -0.4824, -0.4824, -0.4824],
          [-0.3255, -0.3255, -0.3255,  ..., -0.4824, -0.4824, -0.4824],
          [-0.3255, -0.3255, -0.3255,  ..., -0.4824, -0.4824, -0.4824],
          ...,
          [-1.0000, -1.0000, -1.0000,  ..., -0.7569, -0.7569, -0.7569],
          [-1.0000, -1.0000, -1.0000,  ..., -0.7569, -0.7569, -0.7569],
          [-1.0000, -1.0000, -1.0000,  ..., -0.756

# Fine Tuning

## Load Model

In [23]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained ViT backbone with a classification head sized for our label set
model = AutoModelForImageClassification.from_pretrained(
    model_checkpt,
    label2id=label2id,
    id2label=id2label,
    # Only strictly needed when the checkpoint already has a classification head of a
    # different size (e.g. an ImageNet-1k fine-tuned ViT). For the in21k checkpoint the
    # head is newly initialized anyway, so this is a harmless safeguard if the checkpoint changes.
    ignore_mismatched_sizes=True,
    num_labels=len(id2label),  # derived from the dataset labels instead of hard-coding 100
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Set Hyperparameters

In [20]:
from logging import log  # NOTE(review): appears unused in this notebook; likely an editor auto-import
import torch

model_name = model_checkpt.split("/")[-1]  # drop the "google/" owner prefix for the fine-tuned model name

output_dir = f"{model_name}-finetuned-cifar100"
batch_size = 32  # per-device batch size; raise it if GPU memory allows
logging_steps = 10  # log every 10 training steps for more granular updates

training_args = TrainingArguments(
    output_dir=output_dir,
    remove_unused_columns=False,  # keep the raw "img" column; set_transform needs it to build pixel_values
    eval_strategy="epoch",
    save_strategy="epoch",  # save once per epoch, in step with eval_strategy (required by load_best_model_at_end)
    lr_scheduler_type="linear",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    # Accumulate gradients over 4 steps, giving an effective batch of batch_size * 4 per device
    # without the extra memory. This does NOT speed training up: it trades more forward/backward
    # passes per optimizer update for a larger effective batch size.
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=6,  # more epochs than CIFAR-10 since 100 classes is a harder task
    warmup_ratio=0.1,  # ramp the learning rate up over the first 10% of training steps
    logging_steps=logging_steps,
    disable_tqdm=False,
    report_to="none",  # disable external loggers such as WandB
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Mixed precision requires a CUDA GPU; fall back to fp32 instead of erroring on CPU-only runtimes
    fp16=torch.cuda.is_available(),
)

## Define Accuracy and Collation Functions

In [21]:
import numpy as np

def compute_metrics(eval_pred):
    """
    Score evaluation predictions with the accuracy metric.

    `eval_pred` unpacks to (logits, labels). Each row of logits is reduced to its
    arg-max class id, and those ids are compared against the reference labels.
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

In [22]:
import torch

def collate(examples: list):
    """
    Assemble a list of examples into a model-ready batch.

    Returns a dict with "pixel_values" (the per-example image tensors stacked along a
    new leading batch dimension) and "labels" (a tensor of the examples' fine labels).
    """
    images, targets = [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["fine_label"])

    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

## Train and Evaluate Model

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    # Passing the image processor means its JSON config is saved alongside the model
    processing_class=image_processor,
)

trainer.train()

# Upload model to Hugging Face Hub

In [28]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; needed before the trainer can push the model
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [29]:
# Metadata for the auto-generated model card. `dataset_tags` must be the Hub dataset id that
# was actually loaded (`load_dataset("cifar100")`) so the card links to the right dataset.
kwargs = {
    "dataset_tags": "cifar100",
    "dataset": "cifar100",
    "tasks": "image-classification",
}

trainer.push_to_hub(commit_message="Train cifar 100 model", **kwargs)

model.safetensors:   0%|          | 0.00/344M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.37k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/avanishd/vit-base-patch16-224-in21k-finetuned-cifar100/commit/88346020cc1dc384983fbbfc02f5a92576c10ab5', commit_message='Train cifar 100 model', commit_description='', oid='88346020cc1dc384983fbbfc02f5a92576c10ab5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/avanishd/vit-base-patch16-224-in21k-finetuned-cifar100', endpoint='https://huggingface.co', repo_type='model', repo_id='avanishd/vit-base-patch16-224-in21k-finetuned-cifar100'), pr_revision=None, pr_num=None)