In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from LoRA import LoRA_Linear
from transformers import DistilBertTokenizer, DistilBertModel

**Load pretrained model**

In [57]:
# Load the pretrained `distilbert-base-uncased` checkpoint; LoRA is applied to this model in the cells below.
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

In [58]:
# Display the module tree to locate the attention projections (q_lin, k_lin, v_lin, out_lin).
bert

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [60]:
def count_trainable_parameters(model: nn.Module) -> int:
    """Count the scalar entries of every parameter that requires a gradient.

    Args:
        model: Any module; its parameters are walked via ``model.parameters()``.

    Returns:
        The sum of ``numel()`` over parameters whose ``requires_grad`` is True
        (0 when the module has no such parameters).
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Baseline count before any layer is replaced.
trainable_parameters = count_trainable_parameters(bert)
print("Trainable parameters:", trainable_parameters)

Trainable parameters: 66362880


**Store an output**

In [61]:
# Batch size and sequence length of the synthetic probe batch.
B, T = (10, 256)
# Read the vocabulary size from the loaded model instead of hard-coding it,
# so sampled token ids stay in range if the checkpoint is changed.
Vocsize = bert.config.vocab_size
# A dedicated seeded generator makes the probe batch reproducible after a
# kernel restart without touching the global RNG state.
_probe_generator = torch.Generator().manual_seed(0)
random_inputs = torch.randint(0, Vocsize, size=(B, T), generator=_probe_generator)
# Reference forward pass of the unmodified model; no autograd graph is needed
# for a baseline that is only compared against later.
with torch.no_grad():
    output = bert(random_inputs)

In [62]:
# Shape of the final hidden states; the leading two dimensions should match (B, T) from the cell above.
output["last_hidden_state"].shape

torch.Size([10, 256, 768])

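We will apply LoRA to the query projection (`q_lin`) of the attention module in the first transformer block. This single 768×768 matrix produces the queries for all attention heads of that block, not just one head.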

In [63]:
# Inspect the weight shape of the query projection in the first transformer block.
bert.transformer.layer[0].attention.q_lin.weight.shape

torch.Size([768, 768])

Create a reference to this module and call it `replaced_layer`

In [64]:
# Resolve the query projection of the first transformer block. If this cell is
# re-run after q_lin has already been swapped for a LoRA_Linear, unwrap it via
# its `base_layer` submodule so a later LoRA_Linear(replaced_layer, ...) call
# never stacks one LoRA wrapper on top of another.
_current_q_lin = bert.transformer.layer[0].attention.q_lin
replaced_layer = (
    _current_q_lin.base_layer
    if isinstance(_current_q_lin, LoRA_Linear)
    else _current_q_lin
)

Create our LoRA layer built on top of the replaced layer

In [65]:
# Wrap the original projection in a LoRA layer; r=4 is passed as the LoRA rank.
lora_layer = LoRA_Linear(replaced_layer, r=4)

Then replace it in the pretrained model

In [66]:
# Swap the wrapper into the model: this mutates `bert`, so every later forward pass goes through `lora_layer`.
bert.transformer.layer[0].attention.q_lin = lora_layer

In [67]:
# Confirm the swap: the attribute should now display as a LoRA_Linear rather than a plain Linear.
bert.transformer.layer[0].attention.q_lin

LoRA_Linear(
  (base_layer): Linear(in_features=768, out_features=768, bias=True)
  (B): Linear(in_features=768, out_features=4, bias=False)
  (A): Linear(in_features=4, out_features=768, bias=False)
)

Let's check that the model still works

In [68]:
# Run the same batch through the patched model, again without building an autograd graph,
# so the result can be compared against the baseline `output`.
with torch.no_grad():
    output2 = bert(random_inputs)

It should also produce the same output as the pretrained model: since our B is initialized with zeros, the low-rank update is zero at initialization.

In [69]:
# The LoRA branch is expected to be a no-op at initialization, so enforce that
# instead of relying on a visual check of the printed norm.
hidden_before = output["last_hidden_state"]
hidden_after = output2["last_hidden_state"]
assert torch.allclose(hidden_before, hidden_after), (
    "LoRA-wrapped model diverges from the pretrained model at initialization"
)
# torch.norm is deprecated; vector_norm with default arguments flattens the
# tensor and returns the same 2-norm.
torch.linalg.vector_norm(hidden_before - hidden_after)

tensor(0.)

In [70]:
# Recount after the swap, keeping only tensors that still require a gradient.
trainable_parameters = sum(
    param.numel() for param in bert.parameters() if param.requires_grad
)
print("Trainable parameters:", trainable_parameters)

Trainable parameters: 65779200


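Our trainable parameters decreased by 583,680 (about 0.58 million), from 66,362,880 to 65,779,200. This matches freezing the 768 × 768 = 589,824 entries of the base weight while adding 2 × 768 × 4 = 6,144 LoRA parameters.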