In [None]:
!pip install -q transformers datasets evaluate

In [None]:
from datasets import load_dataset

# Load a subset of CIFAR-10: the first 5000 training rows and the first
# 2000 test rows, returned as two datasets in the order requested.
train_ds, test_ds = load_dataset(
    "cifar10",
    split=["train[:5000]", "test[:2000]"],
)

In [None]:
# Display the training split to inspect what was loaded.
train_ds

In [None]:
# Display the feature schema of the training split.
train_ds.features

In [None]:
# Display the "img" field of the first training example.
train_ds[0]["img"]

In [None]:
# Display the "label" field of the first training example.
train_ds[0]["label"]

In [None]:
# Build the two lookup tables between integer class ids and class names:
# ids are the positions of the names in the "label" feature.
id2label = dict(enumerate(train_ds.features["label"].names))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Display both lookup tables to check them.
id2label, label2id

In [None]:
# Translate the first training example's label id into its class name.
id2label[train_ds[0]["label"]]

In [None]:
from transformers import ViTImageProcessor
import numpy as np


# Image processor loaded from the same checkpoint name that the model is
# later loaded from.
processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

def preprocess_images(data):
  """Run the image processor over a batch and attach its pixel values.

  Args:
    data: Mapping with an "img" entry holding the images to process
      (a batch of images when used with a batched `map`).

  Returns:
    The same mapping object, modified in place, with an added
    "pixel_values" entry taken from the processor output.
  """
  encoded = processor(images=data["img"], return_tensors="tf")
  data["pixel_values"] = encoded["pixel_values"]
  return data

# Apply the preprocessing to both splits in batched mode, so each call to
# preprocess_images receives a batch of rows; the resulting datasets carry
# the added "pixel_values" entry.
train_ds = train_ds.map(function=preprocess_images, batched=True)
test_ds = test_ds.map(function=preprocess_images, batched=True)

In [None]:
# Sanity check: run the processor on the first 10 images and show the shape
# of the first resulting "pixel_values" entry.
processor(train_ds[:10]['img'])["pixel_values"][0].shape

In [None]:
from transformers import TFViTForImageClassification

# Load the ViT checkpoint with a classification head sized to this dataset:
# one output per class, with the id <-> name mappings passed along.
model = TFViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Print the model summary.
model.summary()

In [None]:
from transformers import create_optimizer

# Training hyperparameters.
batch_size = 16
num_epochs = 5

# The schedule's total step count is (full batches per epoch) x (epochs);
# the optimizer starts at a learning rate of 2e-5 with no warmup steps.
steps_per_epoch = len(train_ds) // batch_size
total_train_steps = int(steps_per_epoch * num_epochs)
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_train_steps=total_train_steps,
    num_warmup_steps=0,
)

In [None]:
from transformers import default_data_collator

# Build the tf.data pipelines that feed the model.
#
# Both pipelines reuse `batch_size` instead of a separate hard-coded value:
# the optimizer's schedule length was computed from
# `len(train_ds) // batch_size`, so batching with a different size here would
# make the number of optimizer steps disagree with the schedule.
#
# `default_data_collator` returns PyTorch tensors unless told otherwise;
# asking for NumPy arrays via `collate_fn_args` keeps this TensorFlow
# pipeline independent of a PyTorch installation.
train_dataset = model.prepare_tf_dataset(
    train_ds,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    collate_fn_args={"return_tensors": "np"},
)

# Evaluation data keeps its original order (no shuffling).
test_dataset = model.prepare_tf_dataset(
    test_ds,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    collate_fn_args={"return_tensors": "np"},
)

In [None]:
# Compile with the optimizer created earlier. No `loss` argument is passed
# here -- presumably this relies on the model's built-in loss; confirm
# against the Transformers documentation.
model.compile(optimizer=optimizer)

In [None]:
import evaluate
import numpy as np

# Load the "accuracy" metric from the evaluate library; compute_metrics uses it.
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
  """Compute accuracy for a pair of model predictions and reference labels.

  Args:
    eval_pred: A two-item sequence `(predictions, labels)`. `predictions`
      is reduced with an argmax along axis 1, so it is assumed to hold one
      row of per-class scores per example -- TODO confirm against the caller.

  Returns:
    Whatever `accuracy.compute` returns for the predicted class ids and
    the reference labels.
  """
  scores, references = eval_pred
  predicted_ids = np.argmax(scores, axis=1)
  return accuracy.compute(predictions=predicted_ids, references=references)

In [None]:
from transformers.keras_callbacks import KerasMetricCallback

# Callback that evaluates compute_metrics on the test pipeline during training.
metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics,
    eval_dataset=test_dataset,
)

In [None]:
# Train for `num_epochs` epochs, validating on the test pipeline and running
# the metric callback alongside.
model.fit(
    x=train_dataset,
    validation_data=test_dataset,
    epochs=num_epochs,
    callbacks=[metric_callback],
)