This quick notebook illustrates how to use LoRA+ with the Hugging Face `Trainer`. Note that this is a very simple fine-tuning task for the ViT model, so there is little room for LoRA+ to improve over LoRA, and `loraplus_lr_ratio` should be set close to one.
- See README for more info about setting loraplus_lr_ratio.
- See the glue folder for examples where LoRA+ leads to real performance gains.

In [1]:
import transformers
import accelerate
import peft

In [2]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login (renders a widget in the notebook).
# NOTE(review): both training-argument cells in this notebook set push_to_hub=False,
# so this login may only matter for gated/private assets — confirm before removing.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# Hub id of the checkpoint used throughout the notebook: the image processor, both model
# loads, and the two output-directory names are all derived from it.
model_checkpoint = "google/vit-base-patch16-224-in21k"  # pre-trained model from which to fine-tune

In [None]:
from datasets import load_dataset

# Only a slice of the "train" split is requested (split="train[:5000]") to keep the demo small.
dataset = load_dataset("food101", split="train[:5000]")

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


train-00000-of-00008.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# Build both directions of the class-name <-> integer-id mapping from the dataset's label feature.
labels = dataset.features["label"].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# Spot-check a single entry of the mapping.
id2label[2]

'baklava'

In [None]:
from transformers import AutoImageProcessor

# Load the preprocessing configuration that ships with the checkpoint; its size, mean and std
# are read again when the torchvision transforms are built.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
image_processor  # last expression of the cell: displays the processor configuration

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Reuse the checkpoint's own normalisation statistics and target resolution so the inputs
# match what the image processor is configured for.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
crop_size = image_processor.size["height"]

# Training pipeline: random crop + horizontal flip for augmentation, then tensor conversion.
train_augmentations = [
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
]
train_transforms = Compose(train_augmentations)

# Validation pipeline: resize followed by a center crop, no random augmentation.
val_steps = [
    Resize(crop_size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize,
]
val_transforms = Compose(val_steps)


def preprocess_train(example_batch):
    """Add a ``pixel_values`` entry holding the train-transformed version of each image.

    Every image in ``example_batch["image"]`` is converted to RGB and passed through
    ``train_transforms``; the same batch object is returned with the new key set.
    """
    transformed = []
    for image in example_batch["image"]:
        rgb_image = image.convert("RGB")
        transformed.append(train_transforms(rgb_image))
    example_batch["pixel_values"] = transformed
    return example_batch


def preprocess_val(example_batch):
    """Add a ``pixel_values`` entry holding the validation-transformed version of each image.

    Every image in ``example_batch["image"]`` is converted to RGB and passed through
    ``val_transforms``; the same batch object is returned with the new key set.
    """
    transformed = []
    for image in example_batch["image"]:
        rgb_image = image.convert("RGB")
        transformed.append(val_transforms(rgb_image))
    example_batch["pixel_values"] = transformed
    return example_batch

In [None]:
# split up training into training + validation
# Split the loaded subset into training (90%) and validation (10%) sets.
# The split is random; a fixed seed (the same value used for `seed` in the training
# arguments) keeps the partition identical across kernel restarts, so reported accuracies
# are reproducible.
splits = dataset.train_test_split(test_size=0.1, seed=0)
train_ds = splits["train"]
val_ds = splits["test"]

In [None]:
# Attach the preprocessing functions to their splits: the augmenting pipeline for training,
# the resize + center-crop pipeline for validation.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Iterates over ``model.named_parameters()``, sums the element counts of all parameters
    and of those with ``requires_grad`` set, and prints both totals together with the
    trainable share as a percentage (two decimals).
    """
    total = 0
    trainable = 0
    for _name, tensor in model.named_parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count
    percentage = 100 * trainable / total
    print(f"trainable params: {trainable} || all params: {total} || trainable%: {percentage:.2f}")

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pre-trained checkpoint with the label mappings built earlier in the notebook.
# The warning captured in this cell's output reports the classifier weights as newly initialized.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
# Baseline count before any adapter is attached.
print_trainable_parameters(model)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 85876325 || all params: 85876325 || trainable%: 100.00


In [None]:
from peft import LoraConfig, get_peft_model

# Adapter configuration for the plain-LoRA baseline run.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],  # module names the adapters are applied to
    lora_dropout=0.1,
    use_original_init=False,  # NOTE(review): semantics not visible here — confirm the installed peft accepts this argument
    bias="none",
    modules_to_save=["classifier"],  # the classification head is listed alongside the adapters
)
lora_model = get_peft_model(model, config)
# Compare with the count printed for the unwrapped model above.
print_trainable_parameters(lora_model)

trainable params: 667493 || all params: 86543818 || trainable%: 0.77


In [None]:
from transformers import TrainingArguments, Trainer


# Short checkpoint name (the part after the last "/"), used to build the output directory.
model_name = model_checkpoint.split("/")[-1]
batch_size = 128

# Training arguments for the plain-LoRA baseline run.
args = TrainingArguments(
    f"{model_name}-finetuned-lora-food101",
    remove_unused_columns=False,  # presumably keeps the raw "image" column available to the transforms — confirm
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    lr_scheduler_type="linear",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # with batch_size=128: 128 x 4 = 512 samples per optimizer step per device
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # presumably the key produced by compute_metrics — confirm
    report_to="none",
    seed=0,
    overwrite_output_dir=True,
    push_to_hub=False,
    label_names=["labels"],  # matches the "labels" key emitted by collate_fn
)

In [None]:
import numpy as np
import evaluate

# Accuracy metric, loaded once and shared by every evaluation call.
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions.

    ``eval_pred`` is a named tuple carrying ``predictions`` (the model's logits as NumPy
    arrays) and ``label_ids`` (the ground-truth labels as NumPy arrays). The predicted
    class is the arg-max over axis 1 of the logits.
    """
    logits = eval_pred.predictions
    references = eval_pred.label_ids
    predicted_classes = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [None]:
import torch


def collate_fn(examples):
    """Assemble a list of dataset examples into a single batch dictionary.

    Each example must provide a ``pixel_values`` tensor and a ``label``. The tensors are
    stacked along a new leading dimension and the labels are gathered into one tensor,
    returned under the keys ``pixel_values`` and ``labels``.
    """
    images = []
    targets = []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
# Baseline LoRA trainer.
# Pass the PEFT-wrapped `lora_model` returned by get_peft_model, not the bare `model`:
# the wrapper is the object that represents the adapter setup, and the original code left
# `lora_model` unused and handed the Trainer the underlying base model instead.
# `label_names=["labels"]` in `args` tells the Trainer which batch key holds the labels.
trainer = Trainer(
    lora_model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Run the baseline LoRA fine-tuning; the value returned by train() is stored in train_results.
train_results = trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
0,No log,3.161141,0.826
1,No log,1.912777,0.808
2,No log,1.236503,0.852
4,2.150400,0.570221,0.872
5,2.150400,0.4807,0.884
6,2.150400,0.440412,0.904
7,2.150400,0.432524,0.908




In [None]:
# Rebind `model` to a freshly loaded copy of the checkpoint so the LoRA+ run below does not
# reuse the model object that was trained in the baseline run.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
print_trainable_parameters(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 85876325 || all params: 85876325 || trainable%: 100.00


In [None]:
# Adapter configuration for the LoRA+ run; the values are the same as in the baseline config.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    use_original_init=False,  # NOTE(review): semantics not visible here — confirm the installed peft accepts this argument
    bias="none",
    modules_to_save=["classifier"],
)
# Rebinds `lora_model` to the wrapper around the freshly reloaded `model`.
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

trainable params: 667493 || all params: 86543818 || trainable%: 0.77


In [None]:
from lora_plus import LoraPlusTrainingArguments, LoraPlusTrainer


# Short checkpoint name (the part after the last "/"), used to build the output directory.
model_name = model_checkpoint.split("/")[-1]
batch_size = 128

# Training arguments for the LoRA+ run. They mirror the baseline arguments except for the
# output directory, the learning rate (4e-3 here vs 5e-3 for the baseline) and the two
# loraplus_* options.
lp_args = LoraPlusTrainingArguments(
    f"{model_name}-finetuned-loraplus-food101",
    remove_unused_columns=False,  # presumably keeps the raw "image" column available to the transforms — confirm
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=4e-3,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # with batch_size=128: 128 x 4 = 512 samples per optimizer step per device
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="none",
    seed=0,
    overwrite_output_dir=True,
    push_to_hub=False,
    label_names=["labels"],  # matches the "labels" key emitted by collate_fn
    lr_scheduler_type="linear",
    loraplus_lr_embedding=1e-06,  # NOTE(review): exact meaning is defined in lora_plus — see the README
    loraplus_lr_ratio=1.25,  # kept near one, as the notebook's introduction recommends for this task
)

In [None]:
# LoRA+ trainer.
# Pass the PEFT-wrapped `lora_model` returned by get_peft_model, not the bare `model`:
# the wrapper is the object that represents the adapter setup, and the original code left
# `lora_model` unused and handed the trainer the underlying base model instead.
# `label_names=["labels"]` in `lp_args` tells the trainer which batch key holds the labels.
lp_trainer = LoraPlusTrainer(
    lora_model,
    lp_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Run the LoRA+ fine-tuning; the value returned by train() is stored in lp_train_results.
lp_train_results = lp_trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
0,No log,1.062158,0.836
1,No log,0.732831,0.85
2,No log,0.501488,0.862
4,0.803100,0.302434,0.898
5,0.803100,0.273427,0.908
6,0.803100,0.260725,0.908
7,0.803100,0.257989,0.912


Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-2 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-4 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-6 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-9 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory vit-base-patch16-224-in21k-finetuned-loraplus-food101/checkpoint-11 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory vit-base-patch16-224-in21k-finetuned