In [None]:
from transformers import ViTConfig, ViTModel, ViTImageProcessor, ViTForImageClassification

In [None]:
# Side length (in pixels) shared by the image processor's `size` below and
# the ViT configuration's `image_size` further down.
IMAGE_SIZE = 64

# Image processor built from scratch (no pretrained preprocessing config is
# loaded); only `size` is overridden, everything else keeps library defaults.
processor = ViTImageProcessor(size=IMAGE_SIZE)
# Bare expression: display the resulting processor settings as cell output.
processor

In [None]:
from datasets import load_dataset

# Load the "uoft-cs/cifar100" dataset. Later cells read its "train" and "test"
# splits and the 'img' / 'fine_label' columns.
dataset = load_dataset("uoft-cs/cifar100")
# Bare expression: display the loaded splits as cell output.
dataset

In [None]:
def process_example(example):
    """Preprocess a single dataset row into model inputs.

    Args:
        example: One row of the dataset; must provide the image under
            ``'img'`` and its class id under ``'fine_label'``.

    Returns:
        The processor's output for the image (requested as PyTorch tensors)
        with the class id added under ``'labels'``.
    """
    image = example['img']
    features = processor(image, return_tensors='pt')
    features['labels'] = example['fine_label']
    return features

# Try the function on the first training row to inspect its output.
process_example(dataset["train"][0])

In [None]:
def transform(example_batch):
    """Preprocess a batch of dataset rows into model inputs.

    Args:
        example_batch: Column-oriented batch; ``'img'`` holds the images and
            ``'fine_label'`` the matching class ids.

    Returns:
        The processor's output for all images (requested as PyTorch tensors)
        with the class ids attached under ``'labels'``.
    """
    # Materialise the image column as a plain list before handing it over.
    images = list(example_batch['img'])
    batch = processor(images, return_tensors='pt')

    # The labels travel alongside the pixel values, untouched.
    batch['labels'] = example_batch['fine_label']
    return batch

In [None]:
# Attach `transform` to every split of the dataset.
# NOTE(review): `with_transform` is expected to apply it lazily, at access
# time, rather than rewriting the stored data - confirm against datasets docs.
prepared_ds = dataset.with_transform(transform)

In [None]:
# Sanity check: fetch the first two training rows from the prepared dataset.
prepared_ds["train"][0:2]

In [None]:
import torch

def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Args:
        batch: Sequence of examples, each with a ``'pixel_values'`` tensor
            and a ``'labels'`` value.

    Returns:
        dict with ``'pixel_values'`` stacked along a new leading dimension
        and ``'labels'`` gathered into a single tensor.
    """
    pixel_values = []
    labels = []
    for item in batch:
        pixel_values.append(item['pixel_values'])
        labels.append(item['labels'])

    return {
        'pixel_values': torch.stack(pixel_values),
        'labels': torch.tensor(labels),
    }

In [None]:
import numpy as np
from datasets import load_metric

def compute_metrics(p):
    """Compute top-1 accuracy for an evaluation pass.

    Accuracy is computed directly with NumPy rather than through
    ``datasets.load_metric("accuracy")``: ``load_metric`` is deprecated and
    was removed in datasets 3.0, and it needed to fetch a metric script when
    the cell ran. The returned dict keeps the same ``{"accuracy": float}``
    shape the accuracy metric produced.

    Args:
        p: Object exposing ``predictions`` (per-class scores, one row per
            sample, classes along axis 1) and ``label_ids`` (the reference
            class ids).

    Returns:
        dict: ``{"accuracy": fraction of samples whose arg-max class equals
        the reference label}``.
    """
    predicted = np.argmax(p.predictions, axis=1)
    # asarray lets plain Python lists of labels compare element-wise too.
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}


In [None]:
# Inspect the training split's column schema; the next cell reads the class
# names of the 'fine_label' column from it.
prepared_ds["train"].features

In [None]:
# Class names of the 'fine_label' column; their positions are the class ids.
labels = prepared_ds['train'].features['fine_label'].names
# Stringified class ids ("0", "1", ...), aligned index-for-index with `labels`.
label_ids = [str(index) for index in range(len(labels))]

# A small ViT: IMAGE_SIZE x IMAGE_SIZE inputs, 4x4 patches, 3 hidden layers
# of width 512 with 16 attention heads, and one output per class name.
configuration = ViTConfig(
    image_size=IMAGE_SIZE,
    patch_size=4,
    hidden_size=512,
    num_hidden_layers=3,
    num_attention_heads=16,
    num_labels=len(labels),
    id2label=dict(zip(label_ids, labels)),
    label2id=dict(zip(labels, label_ids)),
)

# Built from the configuration alone - no pretrained checkpoint is loaded.
model = ViTForImageClassification(configuration)

In [None]:
# Total number of trainable scalars: element counts summed over every model
# parameter that has requires_grad set.
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Output location and reporting.
    output_dir="./vit-cifar100",
    push_to_hub=False,
    report_to="tensorboard",
    logging_steps=10,
    # Optimisation schedule.
    num_train_epochs=100,
    learning_rate=2e-5,
    weight_decay=1e-4,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    # Evaluate and checkpoint once per epoch, keep at most two checkpoints,
    # and reload the best one when training finishes.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # The batch transform reads the raw 'img' column, so the dataset columns
    # must not be pruned before it runs.
    remove_unused_columns=False,
)

In [None]:
from transformers import Trainer

# Wire the model, hyperparameters, data and metric function together.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: train on the "train" split, evaluate on the "test" split.
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    data_collator=collate_fn,
    # Evaluation metric computed from predictions and reference labels.
    compute_metrics=compute_metrics,
    # The image processor is handed over through the `tokenizer` slot.
    tokenizer=processor,
)


In [None]:
# Run the full training loop; the returned object carries the run's metrics.
train_results = trainer.train()
# Persist the trained model.
# NOTE(review): with no argument this presumably writes to the configured
# output_dir ("./vit-cifar100") - confirm against the Trainer docs.
trainer.save_model()
# Report the training metrics, then write them out under the "train" split name.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# Finally persist the trainer's own state.
trainer.save_state()