In [None]:
# @title Configuration
# Hugging Face Hub id of the image-classification checkpoint to evaluate. The cells below
# download its config/weights from the Hub and assume a 1000-class head (`num_labels=1000`),
# evaluated against the `mrm8488/ImageNet1K-val` dataset.
model_id = "facebook/data2vec-vision-base-ft1k"  # @param {type:"string"}

## 📦 Packages and Basic Setup
---

In [None]:
# Show which GPU (if any) this runtime has before starting the evaluation run.
!nvidia-smi

In [None]:
%%capture
!pip install -U datasets evaluate

In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
    ToTensor,
)
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

## 💿 Load Dataset
---

In [None]:
import datasets

# Raise the 🤗 Datasets log threshold to ERROR so INFO/WARNING messages don't clutter the cells below.
datasets.logging.set_verbosity(datasets.logging.ERROR)

In [None]:
%%capture
test_ds = load_dataset("mrm8488/ImageNet1K-val", split="train")

In [None]:
!wget https://huggingface.co/{model_id}/resolve/main/config.json

In [None]:
import json

# Parse the checkpoint's config.json downloaded above. JSON text is UTF-8, so decode it
# explicitly rather than relying on the platform's default locale encoding.
with open("config.json", encoding="utf-8") as config_file:
    model_config = json.load(config_file)

In [None]:
# Class-name <-> index mappings taken from the checkpoint's own config.json.
# JSON object keys are always strings, so `id2label` is keyed "0", "1", ... at this point.
label2id, id2label = model_config["label2id"], model_config["id2label"]

In [None]:
# Top-1 accuracy from the 🤗 `evaluate` library; `compute_metrics` feeds it argmax predictions.
metric = evaluate.load(path="accuracy")


def compute_metrics(p):
    """Compute top-1 accuracy for a `Trainer` evaluation pass.

    Args:
        p: an `EvalPrediction`-like object exposing `.predictions` and `.label_ids`.
            `.predictions` is either the logits array or a tuple whose first element
            is the logits (the Trainer returns a tuple when the model emits extra outputs).

    Returns:
        The dict produced by the module-level `metric` (e.g. ``{"accuracy": 0.83}``).
    """
    predictions = p.predictions
    # Unwrap tuple outputs; np.argmax over a tuple of differently-shaped arrays would fail.
    logits = predictions[0] if isinstance(predictions, tuple) else predictions
    # The class axis is last, whatever the leading dimensions are.
    return metric.compute(
        predictions=np.argmax(logits, axis=-1), references=p.label_ids
    )

In [None]:
%%capture
config = AutoConfig.from_pretrained(
    model_id,
    num_labels=1000,
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
)

model = AutoModelForImageClassification.from_pretrained(model_id, config=config)

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

In [None]:
def _square_edge(size) -> int:
    """Return the edge length (in pixels) to resize/crop to, taken from a processor's `size`.

    Depending on the processor class, `size` is a plain int, a ``{"height": h, "width": w}``
    dict, or a ``{"shortest_edge": s}`` dict. For height/width dicts the height is used,
    since the pipeline below produces a square crop.

    Raises:
        ValueError: if `size` has none of the supported forms.
    """
    if isinstance(size, int):
        return size
    if "height" in size:
        return size["height"]
    if "shortest_edge" in size:
        return size["shortest_edge"]
    raise ValueError(f"Unsupported image processor `size`: {size!r}")


# Normalize with the checkpoint's own per-channel mean/std.
normalize = Normalize(
    mean=feature_extractor.image_mean, std=feature_extractor.image_std
)

image_edge = _square_edge(feature_extractor.size)

# Deterministic eval pipeline: resize the shorter side to `image_edge`, take a centred square
# crop of that size, convert the PIL image to a tensor, then normalize.
_val_transforms = Compose(
    [
        Resize(image_edge),
        CenterCrop(image_edge),
        ToTensor(),
        normalize,
    ]
)

In [None]:
def val_transforms(example_batch):
    """Attach model-ready tensors to a batch of dataset rows (used as a `set_transform` hook).

    Each PIL image in ``example_batch["image"]`` is converted to RGB, so every image has
    three channels, and then passed through `_val_transforms`. The results are stored under
    ``"pixel_values"``. The input dict is updated in place and also returned.
    """
    tensors = []
    for image in example_batch["image"]:
        rgb_image = image.convert("RGB")
        tensors.append(_val_transforms(rgb_image))
    example_batch["pixel_values"] = tensors
    return example_batch


# Register `val_transforms` so it is applied on the fly whenever rows of `test_ds` are accessed.
test_ds.set_transform(transform=val_transforms)

In [None]:
def collate_fn(examples):
    """Batch transformed dataset rows into the inputs the Trainer passes to the model.

    Args:
        examples: rows that each carry a ``"pixel_values"`` tensor (added by `val_transforms`)
            and an integer ``"label"``.

    Returns:
        ``{"pixel_values": stacked image tensors, "labels": tensor of labels}``.
    """
    images = [row["pixel_values"] for row in examples]
    batched_images = torch.stack(images)
    targets = [row["label"] for row in examples]
    return {"pixel_values": batched_images, "labels": torch.tensor(targets)}

## 👨‍⚖️ Evaluation

In [None]:
# Evaluation-only configuration: this notebook only calls `trainer.evaluate()`. Training-specific
# options (epochs, train batch size, eval/save strategies, checkpoint limits, best-model
# reloading, step logging) do not affect a standalone evaluate call, so they are omitted.
training_args = TrainingArguments(
    output_dir="/content/imagenet1k-val",
    # Keep the raw "image"/"label" columns: the set_transform hook reads "image" and
    # `collate_fn` reads "label", neither of which is a model forward() argument.
    remove_unused_columns=False,
    do_train=False,
    do_eval=True,
    per_device_eval_batch_size=32,
    report_to="none",
    seed=42,
)

In [None]:
# A Trainer used purely for evaluation: it has an eval dataset but no train dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=test_ds,
    data_collator=collate_fn,  # builds {"pixel_values", "labels"} batches from transformed rows
    compute_metrics=compute_metrics,
    processing_class=feature_extractor,
)

In [None]:
# One pass over `test_ds`. The Trainer prefixes metric names with "eval_" (e.g. "eval_accuracy").
eval_metrics = trainer.evaluate()
eval_metrics