In [1]:
from datasets import load_dataset
from huggingface_hub import login
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    TrainingArguments,
    Trainer,
)
import evaluate
import numpy as np

import os

In [3]:
login(os.getenv("HF_READ"))

In [4]:
dataset = load_dataset("ethz/food101")

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 75750
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 25250
    })
})

In [6]:
labels = dataset['train'].features['label'].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [7]:
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [8]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transfroms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [9]:
def transforms(examples):
    examples["pixel_values"] = [_transfroms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

In [10]:
dataset = dataset.with_transform(transforms)

In [11]:
data_collator = DefaultDataCollator()

In [12]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [13]:
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
login(os.getenv("HF_WRITE"))

In [16]:
training_args = TrainingArguments(
    output_dir="./food_models/",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/3549 [00:00<?, ?it/s]

{'loss': 4.618, 'grad_norm': 3.305068016052246, 'learning_rate': 1.4084507042253521e-06, 'epoch': 0.01}
{'loss': 4.6231, 'grad_norm': 3.2962327003479004, 'learning_rate': 2.8169014084507042e-06, 'epoch': 0.02}
{'loss': 4.622, 'grad_norm': 3.28930926322937, 'learning_rate': 4.225352112676056e-06, 'epoch': 0.03}
{'loss': 4.6132, 'grad_norm': 3.1374218463897705, 'learning_rate': 5.6338028169014084e-06, 'epoch': 0.03}
{'loss': 4.6044, 'grad_norm': 3.0575778484344482, 'learning_rate': 7.042253521126762e-06, 'epoch': 0.04}
{'loss': 4.5987, 'grad_norm': 3.3007946014404297, 'learning_rate': 8.450704225352112e-06, 'epoch': 0.05}
{'loss': 4.6071, 'grad_norm': 3.1758439540863037, 'learning_rate': 9.859154929577465e-06, 'epoch': 0.06}
{'loss': 4.5925, 'grad_norm': 3.3889970779418945, 'learning_rate': 1.1267605633802817e-05, 'epoch': 0.07}
{'loss': 4.5724, 'grad_norm': 3.249274253845215, 'learning_rate': 1.267605633802817e-05, 'epoch': 0.08}
{'loss': 4.5737, 'grad_norm': 3.2495477199554443, 'learni



{'loss': 4.3955, 'grad_norm': 3.9763519763946533, 'learning_rate': 2.535211267605634e-05, 'epoch': 0.15}
{'loss': 4.3471, 'grad_norm': 4.15232515335083, 'learning_rate': 2.676056338028169e-05, 'epoch': 0.16}
{'loss': 4.3433, 'grad_norm': 4.565159320831299, 'learning_rate': 2.8169014084507046e-05, 'epoch': 0.17}
{'loss': 4.3008, 'grad_norm': 4.455473899841309, 'learning_rate': 2.9577464788732395e-05, 'epoch': 0.18}
{'loss': 4.2614, 'grad_norm': 4.430142402648926, 'learning_rate': 3.0985915492957744e-05, 'epoch': 0.19}
{'loss': 4.2318, 'grad_norm': 4.329217910766602, 'learning_rate': 3.23943661971831e-05, 'epoch': 0.19}
{'loss': 4.2089, 'grad_norm': 5.201883792877197, 'learning_rate': 3.380281690140845e-05, 'epoch': 0.2}
{'loss': 4.152, 'grad_norm': 4.348280906677246, 'learning_rate': 3.5211267605633805e-05, 'epoch': 0.21}
{'loss': 4.1211, 'grad_norm': 4.8135528564453125, 'learning_rate': 3.661971830985916e-05, 'epoch': 0.22}
{'loss': 4.0596, 'grad_norm': 5.0671586990356445, 'learning_ra

  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 1.5359952449798584, 'eval_accuracy': 0.7944950495049505, 'eval_runtime': 231.2659, 'eval_samples_per_second': 109.182, 'eval_steps_per_second': 6.828, 'epoch': 1.0}
{'loss': 1.6879, 'grad_norm': 12.862735748291016, 'learning_rate': 3.692861615529117e-05, 'epoch': 1.01}
{'loss': 1.5854, 'grad_norm': 12.344841003417969, 'learning_rate': 3.6772072636192864e-05, 'epoch': 1.01}
{'loss': 1.6177, 'grad_norm': 10.77888298034668, 'learning_rate': 3.6615529117094554e-05, 'epoch': 1.02}
{'loss': 1.5838, 'grad_norm': 8.838829040527344, 'learning_rate': 3.6458985597996245e-05, 'epoch': 1.03}
{'loss': 1.6024, 'grad_norm': 17.48933982849121, 'learning_rate': 3.6302442078897936e-05, 'epoch': 1.04}
{'loss': 1.6184, 'grad_norm': 12.04076099395752, 'learning_rate': 3.6145898559799626e-05, 'epoch': 1.05}
{'loss': 1.577, 'grad_norm': 12.381103515625, 'learning_rate': 3.598935504070132e-05, 'epoch': 1.06}
{'loss': 1.523, 'grad_norm': 10.747654914855957, 'learning_rate': 3.583281152160301e-05, 



{'loss': 0.9633, 'grad_norm': 22.55658721923828, 'learning_rate': 1.9708829054477147e-05, 'epoch': 1.93}
{'loss': 1.0116, 'grad_norm': 18.3155574798584, 'learning_rate': 1.9552285535378838e-05, 'epoch': 1.94}
{'loss': 1.033, 'grad_norm': 16.4659481048584, 'learning_rate': 1.939574201628053e-05, 'epoch': 1.95}
{'loss': 0.9497, 'grad_norm': 23.966915130615234, 'learning_rate': 1.923919849718222e-05, 'epoch': 1.96}
{'loss': 1.037, 'grad_norm': 15.241747856140137, 'learning_rate': 1.9082654978083906e-05, 'epoch': 1.97}
{'loss': 0.9893, 'grad_norm': 16.155290603637695, 'learning_rate': 1.8926111458985597e-05, 'epoch': 1.98}
{'loss': 0.9586, 'grad_norm': 15.206948280334473, 'learning_rate': 1.8769567939887288e-05, 'epoch': 1.99}
{'loss': 1.0097, 'grad_norm': 19.43486976623535, 'learning_rate': 1.861302442078898e-05, 'epoch': 1.99}


  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.8811041116714478, 'eval_accuracy': 0.8401188118811881, 'eval_runtime': 221.1275, 'eval_samples_per_second': 114.187, 'eval_steps_per_second': 7.141, 'epoch': 2.0}
{'loss': 0.9559, 'grad_norm': 18.8619441986084, 'learning_rate': 1.845648090169067e-05, 'epoch': 2.0}
{'loss': 0.9337, 'grad_norm': 15.42538833618164, 'learning_rate': 1.829993738259236e-05, 'epoch': 2.01}
{'loss': 0.9804, 'grad_norm': 14.350504875183105, 'learning_rate': 1.814339386349405e-05, 'epoch': 2.02}
{'loss': 0.905, 'grad_norm': 15.630205154418945, 'learning_rate': 1.798685034439574e-05, 'epoch': 2.03}
{'loss': 0.933, 'grad_norm': 13.274575233459473, 'learning_rate': 1.7830306825297432e-05, 'epoch': 2.04}
{'loss': 0.9206, 'grad_norm': 13.106250762939453, 'learning_rate': 1.7673763306199123e-05, 'epoch': 2.04}
{'loss': 0.8928, 'grad_norm': 25.216529846191406, 'learning_rate': 1.7517219787100814e-05, 'epoch': 2.05}
{'loss': 1.0232, 'grad_norm': 16.792932510375977, 'learning_rate': 1.7360676268002504e-05



{'loss': 0.8516, 'grad_norm': 20.198699951171875, 'learning_rate': 8.9073262366938e-06, 'epoch': 2.52}
{'loss': 0.8679, 'grad_norm': 17.54818344116211, 'learning_rate': 8.750782717595491e-06, 'epoch': 2.53}
{'loss': 0.8437, 'grad_norm': 21.20258140563965, 'learning_rate': 8.594239198497182e-06, 'epoch': 2.53}
{'loss': 0.8322, 'grad_norm': 31.0272216796875, 'learning_rate': 8.437695679398873e-06, 'epoch': 2.54}
{'loss': 0.8693, 'grad_norm': 15.131173133850098, 'learning_rate': 8.281152160300563e-06, 'epoch': 2.55}
{'loss': 0.9272, 'grad_norm': 12.492947578430176, 'learning_rate': 8.124608641202254e-06, 'epoch': 2.56}
{'loss': 0.8495, 'grad_norm': 12.601189613342285, 'learning_rate': 7.968065122103947e-06, 'epoch': 2.57}
{'loss': 0.8372, 'grad_norm': 23.344053268432617, 'learning_rate': 7.811521603005637e-06, 'epoch': 2.58}
{'loss': 0.7858, 'grad_norm': 18.180421829223633, 'learning_rate': 7.654978083907326e-06, 'epoch': 2.59}
{'loss': 0.8779, 'grad_norm': 20.926069259643555, 'learning_r

  0%|          | 0/1579 [00:00<?, ?it/s]



{'eval_loss': 0.7385702133178711, 'eval_accuracy': 0.8545346534653465, 'eval_runtime': 224.1738, 'eval_samples_per_second': 112.636, 'eval_steps_per_second': 7.044, 'epoch': 3.0}
{'train_runtime': 5860.4726, 'train_samples_per_second': 38.777, 'train_steps_per_second': 0.606, 'train_loss': 1.6977294368050608, 'epoch': 3.0}


TrainOutput(global_step=3549, training_loss=1.6977294368050608, metrics={'train_runtime': 5860.4726, 'train_samples_per_second': 38.777, 'train_steps_per_second': 0.606, 'total_flos': 1.7616838216178074e+19, 'train_loss': 1.6977294368050608, 'epoch': 2.998521647307286})

In [17]:
trainer.push_to_hub(token=os.getenv("HF_WRITE"))

model.safetensors:   0%|          | 0.00/344M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/SABR22/food_models/commit/ee75f7a0f6d1225a209740763506bfd4482eece6', commit_message='End of training', commit_description='', oid='ee75f7a0f6d1225a209740763506bfd4482eece6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/SABR22/food_models', endpoint='https://huggingface.co', repo_type='model', repo_id='SABR22/food_models'), pr_revision=None, pr_num=None)