In [1]:
%%capture
pip install -q peft transformers datasets

In [2]:
from datasets import load_dataset

# Load the `food101` dataset; as the repr in the next cell shows, it comes back
# as a DatasetDict with `train` and `validation` splits.
ds = load_dataset("food101")

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/490M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/464M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/472M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/464M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/475M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/470M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/478M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/486M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/423M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/413M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/426M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/75750 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25250 [00:00<?, ? examples/s]

In [3]:
# Display the DatasetDict to check the available splits, features and row counts.
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 75750
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 25250
    })
})

In [4]:
# Class names of the training split, in label-id order.
labels = ds["train"].features["label"].names

# Lookup tables in both directions: class name -> integer id and back.
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}

In [5]:
# Spot-check the mapping: look up the class name stored for label id 5.
id2label[5]

'beet_salad'

In [6]:
from transformers import AutoImageProcessor

# Load the image processor published with the ViT checkpoint. Its `image_mean`,
# `image_std` and `size` attributes are read further down to build the
# torchvision transforms.
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")



preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [17]:
from torchvision.transforms import (
    CenterCrop,
    Compose, # Chains a sequence of image transformations
    Normalize,
    RandomHorizontalFlip, # Randomly flips the image (used as training augmentation)
    RandomResizedCrop, # Randomly crops the image (used as training augmentation)
    Resize,
    ToTensor
)

# Per-channel normalization driven by the statistics stored on the image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training pipeline, applied to every training image: random crop and random
# horizontal flip for augmentation, then tensor conversion and normalization.
train_transforms = Compose([
    RandomResizedCrop(size=image_processor.size["height"]),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline without random augmentation: resize, centre-crop to the
# processor's configured height, convert to a tensor and normalize.
val_transforms = Compose([
    Resize(size=image_processor.size["height"]),
    CenterCrop(size=image_processor.size["height"]),
    ToTensor(),
    normalize,
])

def preprocess_train(example_batch):
    """Attach training-time `pixel_values` to a batch of examples.

    Every image under `example_batch["image"]` is converted to RGB and passed
    through `train_transforms`. The results are stored under the
    `"pixel_values"` key of the same batch object, which is then returned.
    """
    rgb_images = (picture.convert("RGB") for picture in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch

def preprocess_val(example_batch):
    """Attach evaluation-time `pixel_values` to a batch of examples.

    Every image under `example_batch["image"]` is converted to RGB and passed
    through `val_transforms`. The results are stored under the
    `"pixel_values"` key of the same batch object, which is then returned.
    """
    rgb_images = (picture.convert("RGB") for picture in example_batch["image"])
    example_batch["pixel_values"] = list(map(val_transforms, rgb_images))
    return example_batch

In [18]:
# Short handles for the two splits used for training and evaluation.
train_ds, val_ds = ds["train"], ds["validation"]

In [19]:
# Display the training split to confirm its features and row count.
train_ds

Dataset({
    features: ['image', 'label'],
    num_rows: 75750
})

In [20]:
# set_transform only registers the functions: preprocess_train / preprocess_val are
# not called until examples are actually accessed (e.g. by indexing the dataset).
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [21]:
# A data collator is needed to assemble batches of training and evaluation examples and to convert the labels into a torch.Tensor.
import torch

def collate_fn(examples):
    """Collate a list of preprocessed examples into one batch.

    Args:
        examples: sequence of mappings, each holding a `"pixel_values"` tensor
            and a `"label"` value.

    Returns:
        A dict with `"pixel_values"` (the per-example tensors stacked along a
        new leading dimension) and `"labels"` (a tensor built from the
        per-example labels).
    """
    images, targets = [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

In [22]:
train_ds[:3]  # Accessing rows applies preprocess_train on the fly: the returned batch now carries `pixel_values`

{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=384x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x383>],
 'label': [6, 6, 6],
 'pixel_values': [tensor([[[-1.0000, -1.0000, -0.9843,  ..., -0.6235, -0.6157, -0.6314],
           [-1.0000, -1.0000, -1.0000,  ..., -0.6000, -0.6235, -0.6314],
           [-1.0000, -0.9922, -0.9922,  ..., -0.6157, -0.6314, -0.6471],
           ...,
           [ 0.0275,  0.0745,  0.0980,  ..., -0.1137, -0.1294, -0.1451],
           [ 0.0824,  0.0745,  0.0588,  ..., -0.1765, -0.1137, -0.0980],
           [ 0.0588,  0.0980,  0.0980,  ..., -0.1529, -0.0824, -0.0902]],
  
          [[-0.7255, -0.7176, -0.7020,  ..., -0.5529, -0.5451, -0.5765],
           [-0.7255, -0.7176, -0.7020,  ..., -0.5294, -0.5529, -0.5686],
           [-0.7333, -0.7176, -0.7020,  ..., -0.5451, -0.5686, -0.5843],
           ...,
           [-0.0275,  0.0196,  0.0431,  ..., -0.1451, -0.17

In [23]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the ViT checkpoint with an image-classification head and attach the
# dataset's label maps. The checkpoint provides no weights for this head, so the
# classifier is newly initialised (see the warning printed below) and still has
# to be trained.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
# Every PEFT method requires a configuration that holds all the parameters specifying how the PEFT method should be applied. 
# Once the configuration is setup, pass it to the get_peft_model() function along with the base model to create a trainable PeftModel.
from peft import LoraConfig, get_peft_model

# LoRA setup: the adapter targets the modules named "query" and "value", and the
# "classifier" head is listed in modules_to_save.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],  # presumably keeps the newly initialised head trainable and saved with the adapter — confirm against the PEFT docs
)
# NOTE: `model` is rebound to the PEFT-wrapped model here, so re-running this
# cell without reloading the base model would wrap the wrapped model again.
model = get_peft_model(model, config)
# Print the trainable vs. total parameter counts (see the summary line below).
model.print_trainable_parameters()


trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.7713


In [25]:
from transformers import TrainingArguments, Trainer


import inspect

# Output directory for checkpoints (plain string: there is nothing to interpolate).
peft_model_id = "google/vit-base-patch16-224-in21k-lora"
batch_size = 128

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases and the old spelling is no longer accepted there. Pick whichever
# keyword the installed version understands so this cell runs on both.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir=peft_model_id,
    # Keep the raw `image` column: the on-the-fly transforms read it to build
    # `pixel_values`.
    remove_unused_columns=False,
    save_strategy="epoch",  # must match the evaluation strategy for load_best_model_at_end
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    # Effective train batch per device = batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    # Mixed precision needs a CUDA device; fall back to full precision instead
    # of raising when no GPU is available.
    fp16=torch.cuda.is_available(),
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    label_names=["labels"],  # key produced by collate_fn
    **{_eval_strategy_kwarg: "epoch"},
)

In [None]:
import inspect

# Newer transformers releases expect the image processor under
# `processing_class` and deprecate the `tokenizer` argument; older releases
# only know `tokenizer`. Choose the keyword supported by the installed version.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

# Wire the PEFT-wrapped model, the training arguments, both dataset splits and
# the batch collator into a Trainer. The image processor is passed along so it
# is saved together with the checkpoints.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    **{_processor_kwarg: image_processor},
)
# Start fine-tuning; as the last expression, the training result is displayed.
trainer.train()

Epoch,Training Loss,Validation Loss
