# Using PEFT with timm

`peft` allows us to train any model with LoRA as long as the layer type is supported. Since `Conv2d` is one of the supported layer types, it makes sense to test it on image models.

In this short notebook, we will demonstrate this with an image classification task using [`timm`](https://huggingface.co/docs/timm/index).
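LoRA support for `Conv2d` means that a trainable low-rank update is added alongside a frozen convolution. The class `LoRAConv2d` below is a hypothetical, minimal sketch written for illustration, not PEFT's implementation. It assumes a plain convolution (`groups=1`): an `r`-channel convolution "A" with the base layer's kernel geometry, followed by a 1x1 convolution "B" back to the output channels.

```python
import torch
import torch.nn as nn


class LoRAConv2d(nn.Module):
    """Hypothetical LoRA wrapper around nn.Conv2d (assumes groups=1), for illustration only."""

    def __init__(self, base: nn.Conv2d, r: int = 8, alpha: int = 8):
        super().__init__()
        self.base = base
        # Freeze the pretrained convolution; only the LoRA factors are trained.
        for p in self.base.parameters():
            p.requires_grad_(False)
        # "A": maps the input to r channels, reusing the base layer's kernel geometry.
        self.lora_A = nn.Conv2d(
            base.in_channels, r,
            kernel_size=base.kernel_size, stride=base.stride,
            padding=base.padding, dilation=base.dilation, bias=False,
        )
        # "B": a 1x1 conv back to the base layer's output channels.
        self.lora_B = nn.Conv2d(r, base.out_channels, kernel_size=1, bias=False)
        # Zero-init B so the wrapped layer initially behaves exactly like the base layer.
        nn.init.zeros_(self.lora_B.weight)
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scaling


conv = nn.Conv2d(16, 64, kernel_size=1)
lora_conv = LoRAConv2d(conv, r=4)
inputs = torch.randn(2, 16, 8, 8)
outputs = lora_conv(inputs)
# Trainable parameters: A has 4*16 weights, B has 64*4 weights.
n_trainable = sum(p.numel() for p in lora_conv.parameters() if p.requires_grad)
```

Because B starts at zero, the wrapped layer reproduces the base layer's output until training moves the LoRA factors.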

## Imports

Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:
    
    python -m pip install --upgrade peft
    
Also, ensure that `timm` and `datasets` are installed:

    python -m pip install --upgrade timm datasets

In [None]:
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [None]:
import peft
from datasets import load_dataset

In [None]:
torch.manual_seed(0)

## Loading the pre-trained base model

We use a small pretrained `timm` model, `PoolFormer`. Find more info on its [model card](https://huggingface.co/timm/poolformer_m36.sail_in1k).

In [None]:
model_id_timm = "timm/poolformer_m36.sail_in1k"

We tell `timm` that we are dealing with 3 classes so that the classification layer has the correct size.

In [None]:
model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)

These are the transformation steps necessary to process the images.

In [None]:
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

## Data

For this exercise, we use the "beans" dataset. More details on the dataset can be found on [its datasets page](https://huggingface.co/datasets/beans). For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image.

In [None]:
ds = load_dataset("beans")

In [None]:
ds_train = ds["train"]
ds_valid = ds["validation"]

In [None]:
ds_train[0]["image"]

We define a small processing function which is responsible for loading and transforming the images, as well as extracting the labels.

In [None]:
def process(batch):
    x = torch.cat([transform(img).unsqueeze(0) for img in batch["image"]])
    y = torch.tensor(batch["labels"])
    return {"x": x, "y": y}
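With `set_transform`, `process` is applied on the fly whenever rows are accessed. The sketch below checks the shapes without downloading anything. It uses a hypothetical stand-in for `transform` and a toy batch dict shaped like the one `datasets` passes in. Each image becomes one row of a 4-D `(batch, channels, height, width)` tensor.

```python
import torch


def toy_transform(img):
    # Hypothetical stand-in for `transform`: maps an "image" to a 3x4x4 tensor.
    return torch.full((3, 4, 4), float(img))


def toy_process(batch):
    # Same logic as `process`: add a batch dimension to each image, then concatenate.
    x = torch.cat([toy_transform(img).unsqueeze(0) for img in batch["image"]])
    y = torch.tensor(batch["labels"])
    return {"x": x, "y": y}


# A batch dict shaped like the one `datasets` passes to the transform.
toy_batch = {"image": [0, 1, 2], "labels": [2, 0, 1]}
toy_out = toy_process(toy_batch)

# torch.stack builds the same tensor without the explicit unsqueeze.
stacked = torch.stack([toy_transform(img) for img in toy_batch["image"]])
```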

In [None]:
ds_train.set_transform(process)
ds_valid.set_transform(process)

In [None]:
train_loader = torch.utils.data.DataLoader(ds_train, batch_size=32)
valid_loader = torch.utils.data.DataLoader(ds_valid, batch_size=32)

## Training

This is a plain training loop followed by a validation pass; nothing fancy is happening.

In [None]:
def train(model, optimizer, criterion, train_dataloader, valid_dataloader, epochs):
    # Run on whatever device the model's parameters live on
    device = next(model.parameters()).device
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        for batch in train_dataloader:
            xb, yb = batch["x"].to(device), batch["y"].to(device)
            outputs = model(xb)
            # CrossEntropyLoss expects raw logits and applies log_softmax itself
            loss = criterion(outputs, yb)
            train_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Validation phase: no gradients needed
        model.eval()
        valid_loss = 0
        correct = 0
        n_total = 0
        for batch in valid_dataloader:
            xb, yb = batch["x"].to(device), batch["y"].to(device)
            with torch.no_grad():
                outputs = model(xb)
            loss = criterion(outputs, yb)
            valid_loss += loss.detach().float()
            correct += (outputs.argmax(-1) == yb).sum().item()
            n_total += len(yb)

        train_loss_total = (train_loss / len(train_dataloader)).item()
        valid_loss_total = (valid_loss / len(valid_dataloader)).item()
        valid_acc_total = correct / n_total
        print(f"{epoch=:<2}  {train_loss_total=:.4f}  {valid_loss_total=:.4f}  {valid_acc_total=:.4f}")
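A note on the loss: `torch.nn.CrossEntropyLoss` expects raw logits and applies `log_softmax` internally. Passing log-probabilities instead yields the same value, because `log_softmax` is idempotent: subtracting the log-sum-exp a second time subtracts zero. The check below uses random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)  # 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])

ce = torch.nn.CrossEntropyLoss()
loss_on_logits = ce(logits, targets)
# Applying log_softmax first gives the same loss, since log_softmax is idempotent.
loss_on_log_probs = ce(F.log_softmax(logits, dim=-1), targets)
# CrossEntropyLoss is equivalent to NLLLoss on log-probabilities.
loss_nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
```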

### Selecting which layers to fine-tune with LoRA

Let's take a look at the layers of our model. We only print the first 30, since there are quite a few:

In [None]:
[(n, type(m)) for n, m in model.named_modules()][:30]

Most of these layers are not good targets for LoRA, but we see a couple that should interest us. Their names are `'stages.0.blocks.0.mlp.fc1'`, etc. With a bit of regex, we can match them easily.
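To check which module names a regex selects before building the config, you can test it with Python's `re` module. This sketch assumes that PEFT matches a string `target_modules` against the full module name, as `re.fullmatch` does. The list of names is illustrative: the `mlp.fc` entries follow the pattern seen in the listing, and the others are hypothetical non-target layers.

```python
import re

# Illustrative module names (some hypothetical).
names = [
    "stages.0.blocks.0.norm1",
    "stages.0.blocks.0.mlp.fc1",
    "stages.0.blocks.0.mlp.fc2",
    "stages.3.blocks.5.mlp.fc2",
    "head.fc",
]

pattern = r".*\.mlp\.fc\d"
# Assumption: the whole name must match, not just a substring.
matched = [n for n in names if re.fullmatch(pattern, n)]
```

With full matching, an alternation such as `r".*\.mlp\.fc\d|head\.fc"` selects a name that fully matches either branch.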

Also, we should inspect the name of the classification layer, since we want to train that one too!

In [None]:
[(n, type(m)) for n, m in model.named_modules()][-5:]

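One way to apply LoRA to both the MLP layers and the classification layer would be to extend the regex:

    config = peft.LoraConfig(
        r=8,
        target_modules=r".*\.mlp\.fc\d|head\.fc",
    )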

Okay, this gives us all the information we need to fine-tune this model. With the regex, we match the convolutional MLP layers that should be targeted for LoRA. We also want to train the classification layer `'head.fc'` (without LoRA), so we add it to `modules_to_save` instead.

In [None]:
config = peft.LoraConfig(r=8, target_modules=r".*\.mlp\.fc\d", modules_to_save=["head.fc"])
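With `r=8`, each targeted layer gains only a few trainable parameters. The arithmetic below is a sketch. It assumes the LoRA update for a convolution consists of an `r`-channel convolution followed by a 1x1 convolution, and it uses hypothetical channel sizes (an MLP layer with a 4x hidden ratio) rather than PoolFormer's exact ones.

```python
def conv_params(c_in, c_out, k=1, bias=True):
    """Number of parameters of a plain Conv2d layer (groups=1)."""
    return c_out * c_in * k * k + (c_out if bias else 0)


def lora_conv_params(c_in, c_out, r, k=1):
    """Parameters added by LoRA: an r-channel k x k conv, then a 1x1 conv back to c_out."""
    return r * c_in * k * k + c_out * r


# Hypothetical sizes for a 1x1 MLP expansion layer.
c_in, c_out, r = 96, 384, 8
full = conv_params(c_in, c_out)          # 384*96 + 384
lora = lora_conv_params(c_in, c_out, r)  # 8*96 + 384*8
ratio = lora / full
```

Since the ratio scales roughly like `r * (c_in + c_out) / (c_in * c_out)`, the relative saving grows as the layers get wider.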

Finally, let's create the `peft` model, the optimizer and criterion, and we can get started. As shown below, less than 2% of the model's total parameters are updated thanks to `peft`.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
peft_model = peft.get_peft_model(model, config).to(device)
optimizer = torch.optim.Adam(peft_model.parameters(), lr=2e-4)
criterion = torch.nn.CrossEntropyLoss()
peft_model.print_trainable_parameters()
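`print_trainable_parameters` reports how many parameters require gradients relative to the total. The same kind of count can be reproduced for any PyTorch model. The toy model below is an assumption used only to illustrate the computation, and PEFT's exact accounting may differ in details.

```python
import torch.nn as nn

# A toy model standing in for the real network (hypothetical).
toy_model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 3))
# Freeze the first layer, similar to how PEFT freezes the base weights.
for p in toy_model[0].parameters():
    p.requires_grad_(False)

trainable = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in toy_model.parameters())
pct = 100 * trainable / total
```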

In [None]:
%time train(peft_model, optimizer, criterion, train_loader, valid_dataloader=valid_loader, epochs=10)

We get an accuracy of ~0.97, despite training only a tiny number of parameters. That's a really nice result.

## Sharing the model through Hugging Face Hub

### Pushing the model to Hugging Face Hub

If we want to share the fine-tuned weights with the world, we can upload them to Hugging Face Hub like this:

In [None]:
user = "BenjaminB"  # put your user name here
model_name = "peft-lora-with-timm-model"
model_id = f"{user}/{model_name}"

In [None]:
peft_model.push_to_hub(model_id);

As we can see, the adapter size is only 4.3 MB. The original model was 225 MB. That's a very big saving.
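A quick back-of-the-envelope check puts these numbers in perspective. Assume float32 weights (4 bytes per parameter) and 1 MB = 10^6 bytes. Under those assumptions, the adapter is under 2% of the full checkpoint, and 225 MB corresponds to roughly 56 million parameters:

```python
BYTES_PER_FP32 = 4


def fp32_size_mb(n_params):
    """Approximate checkpoint size in MB (10**6 bytes), assuming float32 weights."""
    return n_params * BYTES_PER_FP32 / 1e6


adapter_mb, full_mb = 4.3, 225  # sizes quoted in the text
ratio = adapter_mb / full_mb
# Inverting the size formula gives a rough parameter count for the full model.
approx_params = full_mb * 1e6 / BYTES_PER_FP32
```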

### Loading the model from HF Hub

Now, it only takes one step to load the model from HF Hub. To do this, we can use `PeftModel.from_pretrained`, passing our base model and the model ID:

In [None]:
base_model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)
loaded = peft.PeftModel.from_pretrained(base_model, model_id)

In [None]:
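x = ds_train[:1]["x"].to(device)
# Put both models on the same device and in eval mode before comparing
loaded = loaded.to(device).eval()
with torch.no_grad():
    y_peft = peft_model.eval()(x)
    y_loaded = loaded(x)
torch.allclose(y_peft, y_loaded)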

### Clean up

Finally, as a clean-up step, you may want to delete the repo.

In [None]:
from huggingface_hub import delete_repo

In [None]:
delete_repo(model_id)