### Example - ViT for image classification

Why resize to 224×224? 

ViT was pretrained at 224; matching size & normalization speeds convergence.

In [12]:
from datasets import load_dataset
from transformers import AutoImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, ToTensor, Normalize
from PIL import Image
import numpy as np, evaluate, torch

import warnings
warnings.filterwarnings('ignore')

In [13]:
# Data + labels
ds = load_dataset("cifar10")
id2label = {i: n for i, n in enumerate(ds["train"].features["label"].names)}
label2id = {v: k for k, v in id2label.items()}

In [14]:
# Image processor (does resize to 224 + normalization automatically for ViT)
proc = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Preprocess with batched map
def preprocess(examples):
    # Ensure everything is PIL or arrays; processor can handle both
    images = [
        img if isinstance(img, Image.Image) else Image.fromarray(img)
        for img in examples["img"]
    ]
    out = proc(images=images)  # don't set return_tensors; datasets will tensorize later
    return {"pixel_values": out["pixel_values"], "labels": examples["label"]}

train_ds = ds["train"].map(preprocess, batched=True, remove_columns=ds["train"].column_names)
test_ds  = ds["test"].map(preprocess,  batched=True, remove_columns=ds["test"].column_names)

# Convert to torch tensors for Trainer
train_ds.set_format(type="torch", columns=["pixel_values", "labels"])
test_ds.set_format(type="torch", columns=["pixel_values", "labels"])


Fetching 1 files: 100%|██████████████████████████████████| 1/1 [00:00<?, ?it/s]
Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Fetching 1 files: 100%|█████████████████████████| 1/1 [00:00<00:00, 999.83it/s]
Map: 100%|███████████████████████| 50000/50000 [06:05<00:00, 136.78 examples/s]
Map: 100%|███████████████████████| 10000/10000 [01:22<00:00, 121.86 examples/s]


In [15]:
# Model (pretrained ViT, new 10-class head)
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=10, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True
)

# Baseline: freeze the whole ViT encoder, train only the classification head
for p in model.vit.parameters():
    p.requires_grad = False

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Compute metrics
metric = evaluate.load("accuracy")
def compute_metrics(p):  # tiny helper
    preds = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=preds, references=p.label_ids)

In [18]:
# Training arguments inputs
args = TrainingArguments(
    output_dir="vit-cifar10",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_ratio=0.1,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    remove_unused_columns=False,              # If there are any()
    fp16=torch.cuda.is_available(),
)

In [19]:
# TRainer class instantiation and train
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=DefaultDataCollator(),
    compute_metrics=compute_metrics,
)

trainer.train()
print(trainer.evaluate())

Epoch,Training Loss,Validation Loss,Accuracy
1,1.6396,1.649166,0.9307


{'eval_loss': 1.6491659879684448, 'eval_accuracy': 0.9307, 'eval_runtime': 188.9428, 'eval_samples_per_second': 52.926, 'eval_steps_per_second': 0.831, 'epoch': 1.0}


In [20]:
# One-image inference demo
ex = test_ds[0]  # already tensors
with torch.no_grad():
    out = model(ex["pixel_values"].unsqueeze(0).to(model.device))
    print("Pred:", id2label[out.logits.argmax(-1).item()], "| True:", id2label[ex["labels"].item()])

# 7) Save
trainer.save_model("vit-cifar10/best")

Pred: cat | True: cat


In [23]:
# If you initialized the image processr as part of training it would get saved else
# save it seapartely

trainer.save_model("vit-cifar10/best")
proc.save_pretrained("vit-cifar10/best")  # writes preprocessor_config.json

['vit-cifar10/best\\preprocessor_config.json']

In [24]:
# Predict function
def predict(image_path, model_dir="vit-cifar10/best"):
    # 1) Load model + processor
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = ViTForImageClassification.from_pretrained(model_dir)
    model.eval()

    # 2) Load + preprocess image
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    # 3) Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
        pred_id = outputs.logits.argmax(-1).item()

    # 4) Map id → label
    return model.config.id2label[pred_id]

# Call function
print(predict("plane.jpg"))


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


airplane
