In [20]:
#pip install transformers

In [21]:
#pip install transformers[torch]

In [22]:
#pip install torch

In [23]:
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score

In [24]:
# Load CIFAR-10 dataset from the datasets library
dataset = load_dataset("cifar10")

In [25]:
# Define the number of images you want to use for training and testing
num_train_images = 2000
num_test_images = 1000

# Select a subset of the dataset
train_dataset = dataset['train'].select(range(num_train_images))
test_dataset = dataset['test'].select(range(num_test_images))

In [26]:
# Define a feature extractor for ViT model
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

In [27]:
# Preprocess the dataset
def preprocess_dataset(examples):
    inputs = feature_extractor(examples['img'], return_tensors='pt')
    return {**inputs, 'label': examples['label']}


In [28]:
train_dataset = train_dataset.map(preprocess_dataset)

In [29]:
test_dataset = test_dataset.map(preprocess_dataset)

In [30]:
# Define a ViT model for image classification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", num_labels=10, ignore_mismatched_sizes=True)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [31]:
# Define training arguments
training_args = TrainingArguments(
    output_dir="./vit_cifar10",
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    learning_rate=2e-4,
    save_steps=500,
    save_total_limit=3,
    push_to_hub=False,
)

In [32]:
# Define a Trainer for training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [33]:
# Train the model
trainer.train()

  0%|          | 0/96 [03:23<?, ?it/s]
  0%|          | 0/96 [00:00<?, ?it/s]

ValueError: too many values to unpack (expected 4)

In [None]:
# Evaluate the model
results = trainer.evaluate()

In [None]:
# Calculate and print accuracy
accuracy = results['eval_accuracy']
print(f"Accuracy: {accuracy:.2f}")

In [None]:
# Save the model
model.save_pretrained("./vit_cifar10_model")

In [None]:
# Load the model
model = ViTForImageClassification.from_pretrained("./vit_cifar10_model")

In [None]:
# Load test data and preprocess
test_data = dataset['test']['img']
test_input = feature_extractor(test_data, return_tensors='pt')

In [None]:
# Make predictions
with torch.no_grad():
    logits = model(**test_input).logits

In [None]:
# Convert logits to predictions
predictions = np.argmax(logits, axis=1)

In [None]:
# Calculate accuracy
true_labels = dataset['test']['label']
accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.2f}")