In [1]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub from inside the notebook. The recorded
# output shows the token being validated and saved locally; later cells push
# the trained model to the Hub (push_to_hub), which relies on this login.
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /home/admin.alahmid@lending.local/.cache/huggingface/token
Login successful


In [2]:
from datasets import load_dataset

# Load only the first 1,000 rows of the "train" split of food101 (slice syntax
# in the split string), i.e. a small subset rather than the full dataset.
food = load_dataset("food101", split="train[:1000]")

Found cached dataset food101 (/home/admin.alahmid@lending.local/.cache/huggingface/datasets/food101/default/0.0.0/7cebe41a80fb2da3f08fcbef769c8874073a86346f7fb96dc0847d4dfc318295)


In [3]:
# Hold out 20% of the loaded examples for evaluation; the result is indexed by
# "train" and "test" in the cells that follow.
# A fixed seed pins the shuffle so the train/test membership (and therefore the
# reported metrics) is the same on every Restart & Run All.
food = food.train_test_split(test_size=0.2, seed=42)

In [4]:
# Peek at the first training example; the recorded output shows a record with
# an "image" field (a PIL image) and an integer "label" field.
food["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
 'label': 6}

In [5]:
# Class names, in index order, as declared by the dataset's "label" feature.
labels = food["train"].features["label"].names

# Two lookup tables built from the same enumeration. Note that the class index
# is stored in its *string* form on both sides:
#   label2id: class name      -> "<index>"
#   id2label: "<index>" (str) -> class name
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [6]:
# Spot-check the mapping. id2label is keyed by strings, so the integer index
# has to be stringified before the lookup.
id2label[str(79)]

'prime_rib'

In [7]:
from transformers import AutoImageProcessor

# Identifier of the pretrained checkpoint; the same name is reused later when
# the classification model itself is loaded.
checkpoint = "google/vit-base-patch16-224-in21k"
# Image processor loaded from that checkpoint. Its image_mean, image_std and
# size attributes are read when the torchvision transforms are built.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [8]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Per-channel normalisation using the mean/std stored on the image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Target crop size. The processor's `size` mapping carries either a single
# "shortest_edge" value or an explicit "height"/"width" pair.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Pipeline applied to each PIL image: random resized crop -> tensor -> normalise.
_transforms = Compose(
    [
        RandomResizedCrop(size),
        ToTensor(),
        normalize,
    ]
)

In [9]:
def transforms(examples):
    """Turn a batch of images into model inputs, in place.

    Each entry of ``examples["image"]`` is converted to RGB and passed through
    the module-level ``_transforms`` pipeline. The results are stored under
    ``"pixel_values"`` and the ``"image"`` entry is removed.

    Args:
        examples: Mutable mapping for one batch; must contain an ``"image"``
            entry holding an iterable of objects with a ``convert`` method.

    Returns:
        The same ``examples`` object, mutated as described above.
    """
    pixel_values = []
    for picture in examples["image"]:
        rgb_picture = picture.convert("RGB")
        pixel_values.append(_transforms(rgb_picture))
    examples["pixel_values"] = pixel_values
    # Remove the source images only after every one has been transformed.
    del examples["image"]
    return examples

In [10]:
# Attach `transforms` to the dataset through with_transform and rebind `food`
# to the result.
# NOTE(review): the "test" split appears to receive the same pipeline as
# "train", including the random crop — confirm that is intended for evaluation.
food = food.with_transform(transforms)

In [11]:
from transformers import DefaultDataCollator

# Collator instance with default settings; it is handed to the Trainer as
# `data_collator` when the trainer is constructed.
data_collator = DefaultDataCollator()

In [12]:
import evaluate

# Load the "accuracy" metric; compute_metrics calls accuracy.compute(...) on
# this object with the predicted class ids and the reference labels.
accuracy = evaluate.load("accuracy")

In [13]:
import numpy as np


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``accuracy`` metric.

    Args:
        eval_pred: A two-item iterable that unpacks into the model's raw
            predictions and the reference labels.

    Returns:
        Whatever ``accuracy.compute`` returns for the predicted class ids
        (argmax over axis 1 of the predictions) versus the references.
    """
    scores, references = eval_pred
    return accuracy.compute(
        predictions=np.argmax(scores, axis=1),
        references=references,
    )

In [14]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained checkpoint with an image-classification head sized to the
# dataset: one output per class name, plus the name<->index lookup tables.
# The recorded warning below reports 'classifier.weight' / 'classifier.bias' as
# newly initialised, i.e. the head still has to be trained.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Training configuration for the fine-tuning run.
training_args = TrainingArguments(
    # Directory that receives checkpoints (checkpoint-50/100/150 in the log).
    output_dir="my_awesome_food_model",
    # Keep every dataset column: `transforms` reads examples["image"] each time
    # a batch is accessed.
    # NOTE(review): presumably the Trainer would otherwise drop "image" as an
    # unused column — confirm against the transformers documentation.
    remove_unused_columns=False,
    # Evaluate and save a checkpoint once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    # 8 examples x 2 accumulation steps = 16 examples per optimizer step on a
    # single device (matches "Total train batch size ... = 16" in the log).
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    # Select the best checkpoint by the "accuracy" evaluation metric and load
    # it back when training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Upload results to the Hub; relies on the notebook_login() performed in
    # the first cell.
    push_to_hub=True,
)

# Wire the model, configuration, data splits, collator and metric together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    # The image processor is passed through the `tokenizer` argument; the log
    # shows it being saved alongside each checkpoint.
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

# Run fine-tuning. As the cell's last expression, the returned TrainOutput is
# displayed as the cell result.
trainer.train()

/home/admin.alahmid@lending.local/my_awesome_food_model is already a clone of https://huggingface.co/Mustafa21/my_awesome_food_model. Make sure you pull the latest changes with `repo.git_pull()`.
***** Running training *****
  Num examples = 800
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 150
  Number of trainable parameters = 85876325


Epoch,Training Loss,Validation Loss,Accuracy
1,2.0523,1.922598,0.935
2,1.3718,1.342208,0.995
3,1.2298,1.233451,0.985


***** Running Evaluation *****
  Num examples = 200
  Batch size = 8
Saving model checkpoint to my_awesome_food_model/checkpoint-50
Configuration saved in my_awesome_food_model/checkpoint-50/config.json
Model weights saved in my_awesome_food_model/checkpoint-50/pytorch_model.bin
Image processor saved in my_awesome_food_model/checkpoint-50/preprocessor_config.json
Image processor saved in my_awesome_food_model/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 200
  Batch size = 8
Saving model checkpoint to my_awesome_food_model/checkpoint-100
Configuration saved in my_awesome_food_model/checkpoint-100/config.json
Model weights saved in my_awesome_food_model/checkpoint-100/pytorch_model.bin
Image processor saved in my_awesome_food_model/checkpoint-100/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 200
  Batch size = 8
Saving model checkpoint to my_awesome_food_model/checkpoint-150
Configuration saved in my_awesome_food_model/checkpoint-150

TrainOutput(global_step=150, training_loss=2.019646905263265, metrics={'train_runtime': 358.8465, 'train_samples_per_second': 6.688, 'train_steps_per_second': 0.418, 'total_flos': 1.8614579687424e+17, 'train_loss': 2.019646905263265, 'epoch': 3.0})

In [17]:
# Save the final model, configuration and image processor to the output
# directory and push them to the Hub repository (see the recorded upload log).
# The displayed value is the URL of the resulting commit.
trainer.push_to_hub()

Saving model checkpoint to my_awesome_food_model
Configuration saved in my_awesome_food_model/config.json
Model weights saved in my_awesome_food_model/pytorch_model.bin
Image processor saved in my_awesome_food_model/preprocessor_config.json
Several commits (2) will be pushed upstream.
The progress bars may be unreliable.


Upload file pytorch_model.bin:   0%|          | 32.0k/328M [00:00<?, ?B/s]

remote: Scanning LFS files for validity...        
remote: LFS file scan complete.        
To https://huggingface.co/Mustafa21/my_awesome_food_model
   9ef3d49..75cdd34  main -> main

To https://huggingface.co/Mustafa21/my_awesome_food_model
   75cdd34..9e13155  main -> main



'https://huggingface.co/Mustafa21/my_awesome_food_model/commit/75cdd34ed529083ac997aa849ff4f169d4593602'