#### Basic LoRA layer

In [1]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, base_layer, rank):
        super(LoRALayer, self).__init__()
        
        self.base_layer = base_layer
        self.rank = rank

        # initialize the low-rank parameters
        self.A = nn.Parameter(torch.randn(base_layer.out_features, rank)) #random gauss as per paper
        self.B = nn.Parameter(torch.zeros(rank, base_layer.in_features))  #zeros as per paper

        # freeze the original layer
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        #original output
        out = self.base_layer(x)

        #lora adaptation
        lora_output = self.B.to(x.device) @ x.transpose(1, 2) # B * x accounting for batch dimension
        lora_output = self.A.to(x.device) @ lora_output # A * B * x
        lora_output = lora_output.transpose(1, 2) # back to original shape

        return out + lora_output

#### LoRA Bert class for demo

In [2]:
class LoRABert(nn.Module):
    def __init__(self, model, rank, num_classes):
        super(LoRABert, self).__init__()

        self.model = model
        
        #classifier head for SST-2
        self.classifier = nn.Linear(model.config.hidden_size, num_classes)

        #adapt layers
        for layer in self.model.encoder.layer:
            # adapt attention layers
            layer.attention.self.query = LoRALayer(layer.attention.self.query, rank)
            layer.attention.self.key = LoRALayer(layer.attention.self.key, rank)
            layer.attention.self.value = LoRALayer(layer.attention.self.value, rank)
            layer.attention.output.dense = LoRALayer(layer.attention.output.dense, rank)

            # adapt feedforward layers
            layer.intermediate.dense = LoRALayer(layer.intermediate.dense, rank)
            layer.output.dense = LoRALayer(layer.output.dense, rank)

    def forward(self, input_ids, attention_mask):
        outs = self.model(input_ids, attention_mask)
        return self.classifier(outs[1])

#### Demo
We'll use LoRA to finetune a pretrained BERT on the SST-2 sentiment dataset.

In [3]:
from datasets import load_dataset

dataset = load_dataset("glue", "sst2")
train_dataset = dataset["train"]
val_dataset = dataset["validation"]

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
lora_model = LoRABert(model, 4, 2)




In [5]:
def tokenize(batch):
    return tokenizer(batch['sentence'], padding="max_length", truncation=True)

train_dataset = train_dataset.map(tokenize, batched=True)
val_dataset = val_dataset.map(tokenize, batched=True)

In [6]:
from torch.utils.data import DataLoader

train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32)

#### Train loop
Despite using LoRA, this does take 12 minutes per epoch on my 4090. I'm not sure if that's just to expected since the dataset has ~65k elements and BERT is a moderately large sized model, or if there is some optimization I'm missing.

In [7]:
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
lora_model.to(device)

NUM_EPOCHS = 10
optimizer = torch.optim.Adam(lora_model.parameters(), lr=5e-5)
loss_function = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCHS):
    lora_model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = lora_model(input_ids, attention_mask)
        loss = loss_function(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch} train loss: {avg_train_loss}")




cuda


100%|██████████| 2105/2105 [11:17<00:00,  3.11it/s]


Epoch 0 train loss: 0.22254993065864273


100%|██████████| 2105/2105 [11:17<00:00,  3.11it/s]


Epoch 1 train loss: 0.12443592937977326


  1%|▏         | 27/2105 [00:09<11:40,  2.97it/s]


KeyboardInterrupt: 