<a href="https://colab.research.google.com/github/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This example demonstrates how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [Viet Q. Vo](https://vietvo89.github.io/)'s valuable contribution.
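LoRA keeps a pre-trained weight W frozen and learns a low-rank update. A linear layer then computes W x + (lora_alpha / r) · B A x, where A has shape (r, in_features) and B has shape (out_features, r). B starts at zero, so the model is unchanged before training. The plain-PyTorch sketch below illustrates the idea on a single `nn.Linear`. The class `LoRALinearSketch` and its Kaiming/zero initialization are illustrative assumptions, not LoRA-Torch's actual implementation.

```python
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Hypothetical, minimal LoRA wrapper around a frozen nn.Linear (illustration only)."""

    def __init__(self, base: nn.Linear, r: int = 16, lora_alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # the pre-trained weight and bias stay frozen
        self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at start
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # W x + b  +  (alpha / r) * B (A x)
        return self.base(x) + self.scaling * (x @ self.lora_A.t() @ self.lora_B.t())

base = nn.Linear(768, 3072)
layer = LoRALinearSketch(base, r=16, lora_alpha=32)
x = torch.randn(4, 768)
same_at_init = torch.allclose(layer(x), base(x))
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(same_at_init, n_trainable)  # True 61440
```

The 61,440 trainable parameters equal r · (768 + 3072) for r = 16, which is the shape of an MLP `c_fc` layer in ViT-B-32.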

In [1]:
!pip install open-clip-torch
!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch

Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch
  Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-1a5zm6vx
  Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-1a5zm6vx
  Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 3b6f10a3bdebfb0da1abeb4c265f914ed06759e4
  Preparing metadata (setup.py) ... done


In [2]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import open_clip
import loratorch as lora
from tqdm import tqdm

### A. Load Pre-trained Model

In [3]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
model = model.cuda()
tokenizer = open_clip.get_tokenizer('ViT-B-32')



In [4]:
# Count total and trainable parameters (whole model, vision/text transformers, token embedding)

def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    vision_params = sum(p.numel() for p in model.visual.transformer.parameters() if p.requires_grad)
    text_params = sum(p.numel() for p in model.transformer.parameters() if p.requires_grad)
    embed_params = sum(p.numel() for p in model.token_embedding.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters - Full model: {trainable_params:,}")
    print(f"Trainable parameters - Vision: {vision_params:,}")
    print(f"Trainable parameters - Text: {text_params:,}")
    print(f"Trainable parameters - embedding: {embed_params:,}")

In [5]:
print('Original model before adding lora')
count_parameters(model)

Original model before adding lora
Total parameters: 151,277,313
Trainable parameters - Full model: 151,277,313
Trainable parameters - Vision: 85,054,464
Trainable parameters - Text: 37,828,608
Trainable parameters - embedding: 25,296,896


### B. Load CIFAR-10

In [6]:
# Load CIFAR-10 and apply the OpenCLIP preprocessing transform to every image

from torchvision.datasets import CIFAR10

train_dataset = CIFAR10(
    root="./data", train=True, download=True,
    transform=preprocess
)
test_dataset = CIFAR10(
    root="./data", train=False, download=True,
    transform=preprocess
)

batch_size_train = 256
batch_size_test = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4)
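CIFAR-10 has 50,000 training and 10,000 test images. With `DataLoader`'s default `drop_last=False`, the number of iterations per epoch is the ceiling of dataset size over batch size. This accounts for the `196/196` progress bars in the training log, whose last batch is smaller than the others. A small arithmetic sketch (variable names are illustrative):

```python
import math

n_train, n_test = 50_000, 10_000  # CIFAR-10 split sizes
train_bs = test_bs = 256          # same values as batch_size_train / batch_size_test

train_steps = math.ceil(n_train / train_bs)  # drop_last=False keeps the partial batch
last_train_batch = n_train - (train_steps - 1) * train_bs
test_steps = math.ceil(n_test / test_bs)
print(train_steps, last_train_batch, test_steps)  # 196 80 40
```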




### C. Fine-tune OpenCLIP with LoRA

_Note:_

Please make sure ``loratorch.MultiheadAttention`` is constructed with the same argument values as the [`nn.MultiheadAttention`](https://docs.pytorch.org/docs/2.6/generated/torch.nn.MultiheadAttention.html#multiheadattention) module it replaces.

For example, the default value of `batch_first` in `nn.MultiheadAttention` is `False`, but `open_clip` sets it to `True` in some `attn` layers. This is discussed [here](https://github.com/Baijiong-Lin/LoRA-Torch/issues/6#issuecomment-2954122864).
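The flag only changes which input dimension is read as the batch, so a mismatch does not necessarily raise an error. The sketch below uses toy sizes, plain `nn.MultiheadAttention`, and identical weights. It shows that a batch-first module and a sequence-first module agree once the input is transposed. It also shows that feeding a batch-first tensor to the sequence-first module runs without complaint but attends over the wrong axis.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2  # toy sizes for illustration

mha_bf = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha_sf = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
mha_sf.load_state_dict(mha_bf.state_dict())  # same weights, only the layout flag differs

x = torch.randn(3, 5, embed_dim)  # batch-first layout: (batch=3, seq=5, embed)

out_bf, _ = mha_bf(x, x, x)           # correct: module expects (batch, seq, embed)
x_sf = x.transpose(0, 1)              # (seq, batch, embed)
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)  # correct: module expects (seq, batch, embed)
consistent = torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5)

out_wrong, _ = mha_sf(x, x, x)        # no error, but dim 0 is treated as the sequence
silently_wrong = not torch.allclose(out_bf, out_wrong, atol=1e-5)
print(consistent, silently_wrong)  # True True
```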

The most reliable way to build `loratorch.MultiheadAttention` is to copy every setting from the module being replaced:
```python
lora_multihead = lora.MultiheadAttention(r=r,
                        lora_alpha=lora_alpha,
                        enable_lora=enable_lora,
                        embed_dim=multihead.embed_dim,
                        num_heads=multihead.num_heads,
                        dropout=multihead.dropout,
                        bias=multihead.in_proj_bias is not None,
                        add_bias_kv=multihead.bias_k is not None,
                        add_zero_attn=multihead.add_zero_attn,
                        kdim=multihead.kdim,
                        vdim=multihead.vdim,
                        batch_first=multihead.batch_first)
```
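A note on the bias checks: in PyTorch, `nn.MultiheadAttention` registers `in_proj_bias` as `None` when `bias=False`, and `nn.Linear` does the same for `bias`. A `hasattr(...)` test is therefore always true. Testing `is not None` distinguishes the two cases, and it is also the idiomatic form of a comparison such as `bias_k == None`. All OpenCLIP layers used here do have biases, so both forms happen to give the same result. A quick check:

```python
import torch.nn as nn

no_bias = nn.MultiheadAttention(embed_dim=8, num_heads=2, bias=False)

# The attribute always exists: it is registered as None when bias=False.
print(hasattr(no_bias, "in_proj_bias"))  # True, even though there is no bias
print(no_bias.in_proj_bias is not None)  # False -> the reliable test
print(no_bias.bias_k is not None)        # False, since add_bias_kv defaults to False

lin = nn.Linear(4, 4, bias=False)
print(hasattr(lin, "bias"), lin.bias is not None)  # True False
```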

#### Apply LoRA to `attn` and `mlp`

In [7]:
def apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True):
    if encoder_type == 'visual':
        encoder = model.visual.transformer
    elif encoder_type == 'text':
        encoder = model.transformer
    else:
        raise ValueError("Invalid encoder_type. Choose 'visual' or 'text'.")

    enable_lora = ['q', 'k', 'v', 'o']  # attention projections that receive LoRA
    for resblock in encoder.resblocks:
        if hasattr(resblock, 'attn') and attn:
            multihead = resblock.attn
            lora_multihead = lora.MultiheadAttention(r=rank,
                                    lora_alpha=lora_alpha,
                                    enable_lora=enable_lora,
                                    embed_dim=multihead.embed_dim,
                                    num_heads=multihead.num_heads,
                                    dropout=multihead.dropout,
                                    bias=multihead.in_proj_bias is not None,
                                    add_bias_kv=multihead.bias_k is not None,
                                    add_zero_attn=multihead.add_zero_attn,
                                    kdim=multihead.kdim,
                                    vdim=multihead.vdim,
                                    batch_first=multihead.batch_first)
            lora_multihead.load_state_dict(multihead.state_dict(), strict=False)
            resblock.attn = lora_multihead

        if hasattr(resblock, 'mlp') and mlp:
            old_mlp_fc = resblock.mlp.c_fc
            old_mlp_proj = resblock.mlp.c_proj
            new_mlp_fc = lora.Linear(
                old_mlp_fc.in_features,
                old_mlp_fc.out_features,
                bias=old_mlp_fc.bias is not None,
                r=rank,
                lora_alpha=lora_alpha,
            )
            new_mlp_proj = lora.Linear(
                old_mlp_proj.in_features,
                old_mlp_proj.out_features,
                bias=old_mlp_proj.bias is not None,
                r=rank,
                lora_alpha=lora_alpha,
            )
            # Copy the pre-trained weight and bias; the LoRA factors keep their initialization
            new_mlp_fc.load_state_dict(old_mlp_fc.state_dict(), strict=False)
            new_mlp_proj.load_state_dict(old_mlp_proj.state_dict(), strict=False)
            resblock.mlp.c_fc = new_mlp_fc
            resblock.mlp.c_proj = new_mlp_proj

    lora.mark_only_lora_as_trainable(model)
    return model

In [8]:
apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [9]:
for name, param in model.visual.transformer.resblocks[0].named_parameters():
    print(name, param.requires_grad)

# after adding lora
print("\nAfter adding LoRA to Attn+MLP:")
count_parameters(model)

ln_1.weight False
ln_1.bias False
attn.in_proj_weight False
attn.in_proj_bias False
attn.o_lora_A True
attn.o_lora_B True
attn.qkv_lora_A True
attn.qkv_lora_B True
attn.out_proj.weight False
attn.out_proj.bias False
ln_2.weight False
ln_2.bias False
mlp.c_fc.weight False
mlp.c_fc.bias False
mlp.c_fc.w_lora_A True
mlp.c_fc.w_lora_B True
mlp.c_proj.weight False
mlp.c_proj.bias False
mlp.c_proj.w_lora_A True
mlp.c_proj.w_lora_B True

After adding LoRA to Attn+MLP:
Total parameters: 153,636,609
Trainable parameters - Full model: 2,359,296
Trainable parameters - Vision: 2,359,296
Trainable parameters - Text: 0
Trainable parameters - embedding: 0
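As a cross-check on these counts: a rank-r LoRA pair on a linear layer with `in_features` inputs and `out_features` outputs adds r · (in + out) parameters. Assume ViT-B-32's visual tower has 12 blocks of width 768 and an MLP hidden size of 3072, which is consistent with the `c_fc`/`c_proj` layers listed above. Under that assumption, the MLP adapters account for 1,474,560 of the 2,359,296 trainable parameters. The remaining 884,736 belong to the attention adapters, and their exact split depends on how `loratorch.MultiheadAttention` shapes `qkv_lora_A`/`qkv_lora_B` and `o_lora_A`/`o_lora_B`. The helper `lora_params` is hypothetical.

```python
def lora_params(in_features: int, out_features: int, r: int) -> int:
    """A (r x in) plus B (out x r): r * (in + out) extra parameters."""
    return r * (in_features + out_features)

width, hidden, layers, r = 768, 3072, 12, 16  # ViT-B-32 visual tower, rank used above
mlp_per_block = lora_params(width, hidden, r) + lora_params(hidden, width, r)  # c_fc + c_proj
mlp_total = layers * mlp_per_block
attn_total = 2_359_296 - mlp_total  # remainder of the reported trainable count
print(mlp_per_block, mlp_total, attn_total)  # 122880 1474560 884736
```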


In [10]:
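# Move the model to the GPU again: the LoRA modules created above were initialized on the CPU
model.cuda()

# Tokenizer and text embeddings for the class prompts (the text encoder stays frozen)
tokenizer = open_clip.get_tokenizer("ViT-B-32")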
classnames = train_dataset.classes
text_inputs = tokenizer([f"a photo of a {label}" for label in classnames]).cuda()
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
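Passing all of `model.parameters()` works here because PyTorch optimizers skip parameters whose `.grad` is `None`. That is the case for everything frozen by `lora.mark_only_lora_as_trainable`. Handing the optimizer only the trainable tensors makes that intent explicit. A toy sketch follows; the two-layer model is a hypothetical stand-in, not OpenCLIP.

```python
import torch
import torch.nn as nn

# Toy stand-in (hypothetical): a frozen layer followed by a trainable one
toy_model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
for p in toy_model[0].parameters():
    p.requires_grad = False

trainable_params = [p for p in toy_model.parameters() if p.requires_grad]
toy_optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

n_optimized = sum(p.numel() for group in toy_optimizer.param_groups for p in group["params"])
print(n_optimized)  # 272 = 16 * 16 weights + 16 biases of the second layer
```

In the notebook, the equivalent would be `torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)`. Like the original cell, this assumes the LoRA parameter objects stay the same across training steps.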

In [11]:
# Train loop
model.train()
for epoch in range(3):
    total_loss = 0
    correct = 0
    total = 0
    for images, labels in tqdm(train_loader):
        images, labels = images.cuda(), labels.cuda()
        image_features = model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Raw cosine similarities in [-1, 1]; no CLIP logit_scale is applied here
        logits = image_features @ text_features.t()
        loss = nn.CrossEntropyLoss()(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        # (!!!) Re-register the model parameters so that they appear in model.state_dict() and model.parameters().
        # (!!!) Without this call, performance is not affected, but some weights will be missing from model.state_dict() and model.parameters().
        lora.register_model_param_after_backward(model)

    acc = correct / total
    print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}")

100%|██████████| 196/196 [06:43<00:00,  2.06s/it]


Epoch 1: Loss=397.7533, Accuracy=0.9529


100%|██████████| 196/196 [06:40<00:00,  2.04s/it]


Epoch 2: Loss=387.0708, Accuracy=0.9796


100%|██████████| 196/196 [06:39<00:00,  2.04s/it]

Epoch 3: Loss=386.2361, Accuracy=0.9901
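The printed loss is summed over the 196 batches of an epoch, so epoch 3 corresponds to roughly 386.24 / 196 ≈ 1.97 per batch. It stays far from zero even at 99% training accuracy because the logits are raw cosine similarities in [-1, 1] with no temperature, which limits how confident the softmax can become. The sketch below computes the resulting lower bound on the cross-entropy for 10 classes, using the best case of +1 for the correct class and -1 for every other class. For comparison it also uses a scale of 100, which is assumed here to approximate `model.logit_scale.exp()` of the OpenAI CLIP checkpoint. The helper `ce_floor` is hypothetical.

```python
import math

num_classes = 10

def ce_floor(scale: float) -> float:
    """Smallest cross-entropy reachable when logits are scale * cosine similarity."""
    correct, other = scale * 1.0, scale * -1.0
    # -log softmax(correct) = log(1 + (C - 1) * exp(other - correct))
    return math.log1p((num_classes - 1) * math.exp(other - correct))

print(ce_floor(1.0))    # ≈ 0.797: no temperature, as in the loop above
print(ce_floor(100.0))  # essentially 0 with a CLIP-style scale
print(386.2361 / 196)   # ≈ 1.97: epoch-3 mean loss per batch from the log
```

A CLIP-style alternative in the loop would be `logits = model.logit_scale.exp() * image_features @ text_features.t()`. This is only a suggestion; the original run did not use it, and it would change the logged loss values.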




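The notebook builds `test_loader` but does not evaluate on it. Below is a minimal evaluation sketch. It assumes the same setup as the training loop, where normalized image features are compared with the precomputed `text_features`. `evaluate` is a hypothetical helper, and the toy encoder at the end is only a CPU smoke test, not OpenCLIP.

```python
import torch

@torch.no_grad()
def evaluate(model, loader, text_features, device="cuda"):
    """Accuracy of picking the class whose text embedding is closest (cosine similarity)."""
    model.eval()
    correct = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        feats = model.encode_image(images)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        preds = (feats @ text_features.t()).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total

# Tiny CPU smoke test with a stand-in encoder (hypothetical, not OpenCLIP)
class _IdentityEncoder(torch.nn.Module):
    def encode_image(self, x):
        return x

toy_text = torch.eye(3)  # 3 "classes", already unit-norm
toy_loader = [(torch.eye(3), torch.tensor([0, 1, 2]))]
toy_acc = evaluate(_IdentityEncoder(), toy_loader, toy_text, device="cpu")
print(toy_acc)  # 1.0
```

In the notebook, the call would be `evaluate(model, test_loader, text_features)`. Call `model.train()` again before any further training.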