In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from LoRA import LoRA_Linear
from transformers import DistilBertTokenizer, DistilBertModel

**Load the pretrained model**

In [2]:
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

In [3]:
bert

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [4]:
# Count scalar parameters that would receive gradients (everything, for a freshly loaded model)
trainable_parameters = sum(
    param.numel() for param in bert.parameters() if param.requires_grad
)
print("Trainable parameters:", trainable_parameters)

Trainable parameters: 66362880


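**Store a reference output**

Run a batch of random token ids through the unmodified model and keep the result. Later we can check that inserting LoRA leaves the output unchanged.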

In [5]:
# Reference batch: B sequences of T random token ids.
# T must not exceed the 512 position embeddings of the model.
B, T = (10, 256)
# Read the vocabulary size from the model config instead of hard-coding it (30522 here).
Vocsize = bert.config.vocab_size
# Fixed seed so the reference inputs, and therefore the stored output, are reproducible.
generator = torch.Generator().manual_seed(0)
random_inputs = torch.randint(0, Vocsize, size=(B, T), generator=generator)
with torch.no_grad():
    output = bert(random_inputs)

In [6]:
output["last_hidden_state"].shape

torch.Size([10, 256, 768])

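We will apply LoRA to the query projection $W_q$ of the first attention layer (transformer block 0). This single `q_lin` matrix produces the queries for all attention heads of that layer, not just one head.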

In [7]:
bert.transformer.layer[0].attention.q_lin.weight.shape

torch.Size([768, 768])

Keep a reference to this module and call it `replaced_layer`.

In [8]:
# A reference (not a copy) to the query projection of transformer block 0
replaced_layer = bert.transformer.layer[0].attention.q_lin

Create a LoRA layer that wraps the layer being replaced.

In [9]:
# Wrap the original q_lin in a rank-4 LoRA adapter; the original layer is kept inside as `base_layer`
lora_layer = LoRA_Linear(replaced_layer, r=4)

Then put it in place of the original layer inside the pretrained model.

In [10]:
# Assigning a module attribute re-registers the submodule, so q_lin now refers to the LoRA wrapper
bert.transformer.layer[0].attention.q_lin = lora_layer

In [11]:
bert.transformer.layer[0].attention.q_lin

LoRA_Linear(
  (base_layer): Linear(in_features=768, out_features=768, bias=True)
  (B): Linear(in_features=768, out_features=4, bias=False)
  (A): Linear(in_features=4, out_features=768, bias=False)
)

Let's check that the model still works

In [12]:
# Same inputs as the reference run, now through the model whose block-0 q_lin is the LoRA wrapper
with torch.no_grad():
    output2 = bert(random_inputs)

It should also produce exactly the same output as the original pretrained model. Because B is initialized with zeros, the low-rank update contributes nothing at initialization.

In [13]:
# Norm of the difference between the two runs; 0 means inserting LoRA left the output unchanged
torch.norm(output["last_hidden_state"] - output2["last_hidden_state"])

tensor(0.)

In [14]:
# Recount after the swap: the LoRA wrapper changes which parameters require gradients
trainable_parameters = sum(
    param.numel() for param in bert.parameters() if param.requires_grad
)
print("Trainable parameters:", trainable_parameters)

Trainable parameters: 65779200


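The trainable parameter count dropped by 583,680 (from 66,362,880 to 65,779,200), i.e. about 0.58 million rather than 1 million. This matches freezing the 768×768 base weight (589,824 parameters) and adding the two trainable rank-4 factors (2 × 768 × 4 = 6,144 parameters). The base layer's bias apparently stays trainable. The rest of the model is still trainable at this point.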

Let's reload the model and apply LoRA to the attention weight matrices (such as $W_q$ and $W_v$) of every transformer block.

In [109]:
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

In [110]:
bert

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [45]:
bert.get_submodule("transformer.layer.0")

TransformerBlock(
  (attention): MultiHeadSelfAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_lin): Linear(in_features=768, out_features=768, bias=True)
    (k_lin): Linear(in_features=768, out_features=768, bias=True)
    (v_lin): Linear(in_features=768, out_features=768, bias=True)
    (out_lin): Linear(in_features=768, out_features=768, bias=True)
  )
  (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (ffn): FFN(
    (dropout): Dropout(p=0.1, inplace=False)
    (lin1): Linear(in_features=768, out_features=3072, bias=True)
    (lin2): Linear(in_features=3072, out_features=768, bias=True)
    (activation): GELUActivation()
  )
  (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)

In [59]:
transformer_block = bert.get_submodule("transformer.layer.0")
# (name, module) pairs for the block itself and every nested submodule
list(transformer_block.named_modules())

[('',
  TransformerBlock(
    (attention): MultiHeadSelfAttention(
      (dropout): Dropout(p=0.1, inplace=False)
      (q_lin): Linear(in_features=768, out_features=768, bias=True)
      (k_lin): Linear(in_features=768, out_features=768, bias=True)
      (v_lin): Linear(in_features=768, out_features=768, bias=True)
      (out_lin): Linear(in_features=768, out_features=768, bias=True)
    )
    (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (ffn): FFN(
      (dropout): Dropout(p=0.1, inplace=False)
      (lin1): Linear(in_features=768, out_features=3072, bias=True)
      (lin2): Linear(in_features=3072, out_features=768, bias=True)
      (activation): GELUActivation()
    )
    (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  )),
 ('attention',
  MultiHeadSelfAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_lin): Linear(in_features=768, out_features=768, bias=True)
    (k_lin): Linear(in_features=768, out_features=76

In [52]:
# Collect the four attention projections of every transformer block,
# iterating over the actual blocks instead of assuming there are 6.
attention_modules = [block.attention for block in bert.transformer.layer]
q_lin = [attention.q_lin for attention in attention_modules]
v_lin = [attention.v_lin for attention in attention_modules]
k_lin = [attention.k_lin for attention in attention_modules]
out_lin = [attention.out_lin for attention in attention_modules]

In [128]:
# Direct children of the embeddings module as (name, module) pairs
list(bert.embeddings.named_children())

[('word_embeddings', Embedding(30522, 768, padding_idx=0)),
 ('position_embeddings', Embedding(512, 768)),
 ('LayerNorm', LayerNorm((768,), eps=1e-12, elementwise_affine=True)),
 ('dropout', Dropout(p=0.1, inplace=False))]

In [143]:
bert.transformer.layer[0].attention.q_lin

Linear(in_features=768, out_features=768, bias=True)

In [145]:
def apply_LoRA(layer_types: list[str], r, model_name: str = "distilbert-base-uncased"):
    """Load a fresh DistilBERT, freeze it, and wrap selected attention projections with LoRA.

    Args:
        layer_types: names of the attention projections to adapt in every
            transformer block; any subset of ["q_lin", "k_lin", "v_lin", "out_lin"].
            Duplicate names are ignored so a layer is never wrapped twice.
        r: rank passed to ``LoRA_Linear``.
        model_name: checkpoint passed to ``DistilBertModel.from_pretrained``
            (defaults to the checkpoint used throughout this notebook).

    Returns:
        The modified model. Every parameter of the loaded model is frozen before
        wrapping, so any trainable parameters come from the ``LoRA_Linear`` wrappers.

    Raises:
        ValueError: if ``layer_types`` contains a name that is not a supported projection.
    """
    supported = ("q_lin", "k_lin", "v_lin", "out_lin")
    unknown = [name for name in layer_types if name not in supported]
    if unknown:
        # Previously unknown names were silently ignored; fail loudly on typos instead.
        raise ValueError(f"Unsupported layer types {unknown}; expected a subset of {supported}")

    bert = DistilBertModel.from_pretrained(model_name)
    # Freeze everything first so only parameters created by the LoRA wrappers can train.
    for param in bert.parameters():
        param.requires_grad = False

    # Keep the original order (layer type outer, block inner) so wrappers are built in the same sequence.
    for layer_type in dict.fromkeys(layer_types):
        # Iterate over the real blocks rather than assuming 6 layers.
        for block in bert.transformer.layer:
            attention = block.attention
            base_layer = getattr(attention, layer_type)
            # setattr re-registers the submodule under the same name.
            setattr(attention, layer_type, LoRA_Linear(base_layer, r=r))

    return bert

def get_trainable_parameters(module:nn.Module):
    """Print the number of scalar parameters in ``module`` whose ``requires_grad`` is True.

    Args:
        module: any ``nn.Module``; all of its parameters (recursively) are inspected.

    Returns:
        None. The count is only printed.
    """
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    print("Trainable parameters: %d" % trainable)

In [146]:
bert_LoRA = apply_LoRA(layer_types=["q_lin"], r=2)

In [147]:
get_trainable_parameters(bert_LoRA)

Trainable parameters: 18432


In [150]:
bert_LoRA.transformer.layer[5]

TransformerBlock(
  (attention): MultiHeadSelfAttention(
    (dropout): Dropout(p=0.1, inplace=False)
    (q_lin): LoRA_Linear(
      (base_layer): Linear(in_features=768, out_features=768, bias=True)
      (B): Linear(in_features=768, out_features=2, bias=False)
      (A): Linear(in_features=2, out_features=768, bias=False)
    )
    (k_lin): Linear(in_features=768, out_features=768, bias=True)
    (v_lin): Linear(in_features=768, out_features=768, bias=True)
    (out_lin): Linear(in_features=768, out_features=768, bias=True)
  )
  (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (ffn): FFN(
    (dropout): Dropout(p=0.1, inplace=False)
    (lin1): Linear(in_features=768, out_features=3072, bias=True)
    (lin2): Linear(in_features=3072, out_features=768, bias=True)
    (activation): GELUActivation()
  )
  (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)

In [107]:
bert_LoRA.transformer.layer[0].attention.q_lin

Linear(in_features=768, out_features=768, bias=True)