## Add LoRA to ViT models

In [2]:
import timm
from lora import LoRA_ViT_timm

weightInfo={
    # "small":"WinKawaks/vit-small-patch16-224",
    "base":"vit_base_patch16_224",
    "base_dino":"vit_base_patch16_224.dino", # 21k -> 1k
    "base_sam":"vit_base_patch16_224.sam", # 1k
    "base_mill":"vit_base_patch16_224_miil.in21k_ft_in1k", # 1k
    "base_beit":"beitv2_base_patch16_224.in1k_ft_in22k_in1k",
    "base_clip":"vit_base_patch16_clip_224.laion2b_ft_in1k", # 1k
    "base_deit":"deit_base_distilled_patch16_224", # 1k
    "large":"google/vit-large-patch16-224",
    "large_clip":"vit_large_patch14_clip_224.laion2b_ft_in1k", # laion-> 1k
    "large_beit":"beitv2_large_patch16_224.in1k_ft_in22k_in1k", 
    "huge_clip":"vit_huge_patch14_clip_224.laion2b_ft_in1k", # laion-> 1k
    "giant_eva":"eva_giant_patch14_224.clip_ft_in1k", # laion-> 1k
    "giant_clip":"vit_giant_patch14_clip_224.laion2b",
    "giga_clip":"vit_gigantic_patch14_clip_224.laion2b"
    }

rank = 4
alpha = 8
num_classes = 2

model = timm.create_model(weightInfo["base"], pretrained=True)
melo = LoRA_ViT_timm(model, r=rank, alpha=alpha, num_classes=num_classes)

num_params = sum(p.numel() for p in melo.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params/2**20:.3f}M")
num_params = sum(p.numel() for p in melo.parameters())
print(f"total parameters: {num_params/2**20:.3f}M")

# Save melo
melo.save_lora_parameters("melo.safetensors")

# Load melo
melo.load_lora_parameters("melo.safetensors")

trainable parameters: 0.142M
total parameters: 81.966M


## Use pretrained melo in paper

In [None]:
import timm
from base_vit import ViT
from lora import LoRA_ViT_timm, LoRA_ViT

melo_path = "path to melo weight"
melo_info = melo_path.split("/")[-1].split("_")

if melo_info[0] == "base":
    model = timm.create_model("vit_base_patch16_224", pretrained=True)
    melo = LoRA_ViT_timm(model, r=int(melo_info[3]), alpha=int(melo_info[4]), num_classes=int(melo_info[4]))
    melo.load_lora_parameters(melo_path)
else:
    model = ViT('B_16_imagenet1k')
    melo = LoRA_ViT(model, r=int(melo_info[3]), alpha=int(melo_info[4]), num_classes=int(melo_info[4]))
    melo.load_lora_parameters(melo_path)

## Use multi-melo on ViT

In [None]:
import timm
import torch
from lora import LoRA_ViT_timm_x

# we only support multi-melo with timm ViT models
melo_paths = ["path to melo weight 1", "path to melo weight 2"]

model = timm.create_model(weightInfo["base"], pretrained=True)
melo = LoRA_ViT_timm_x(model, melo_paths)

task_index = 0
img = torch.randn(1, 3, 224, 224)
melo.swith_lora(task_index)
melos_out = melo(img)