# GPT2 with LoRA implementation in PyTorch

Let's start by importing the necessary libraries

In [6]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig
from transformers import AutoTokenizer
import os

Make the model deterministic

In [13]:
# Make torch deterministic
_ = torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [15]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    fast_tokenizer=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
).to(device)



In [16]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

NameError: name 'net' is not defined

In [17]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ∆Wx by α/r , where α is a constant in r.
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
        #   As a result, we simply set α to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

In [31]:
target_names = []
for name, module in model.named_modules():
    if "attn.c_attn" in name:
        target_names.append(name)



In [34]:
target_names

['transformer.h.0.attn.c_attn',
 'transformer.h.1.attn.c_attn',
 'transformer.h.2.attn.c_attn',
 'transformer.h.3.attn.c_attn',
 'transformer.h.4.attn.c_attn',
 'transformer.h.5.attn.c_attn',
 'transformer.h.6.attn.c_attn',
 'transformer.h.7.attn.c_attn',
 'transformer.h.8.attn.c_attn',
 'transformer.h.9.attn.c_attn',
 'transformer.h.10.attn.c_attn',
 'transformer.h.11.attn.c_attn']

In [38]:
target_modules = [model.transformer.h[0].attn.c_attn,
 model.transformer.h[1].attn.c_attn,
 model.transformer.h[2].attn.c_attn,
 model.transformer.h[3].attn.c_attn,
 model.transformer.h[4].attn.c_attn,
 model.transformer.h[5].attn.c_attn,
 model.transformer.h[6].attn.c_attn,
 model.transformer.h[7].attn.c_attn,
 model.transformer.h[8].attn.c_attn,
 model.transformer.h[9].attn.c_attn,
 model.transformer.h[10].attn.c_attn,
 model.transformer.h[11].attn.c_attn]

In [39]:
target_modules

[Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D(),
 Conv1D()]

In [43]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate(target_modules):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 2: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 3: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 4: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 5: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 6: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 7: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 8: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 9: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 10: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 11: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Layer 12: W: torch.Size([768, 2304]) + B: torch.Size([2304])
Total number of parameters: 21,261,312


In [40]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.

    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

for target_module in target_modules:
    parametrize.register_parametrization(
        target_module, "weight", linear_layer_parameterization(target_module, device)
    )



def enable_disable_lora(enabled=True):
    for layer in target_modules:
        layer.parametrizations["weight"][0].enabled = enabled

In [44]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate(target_modules):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 2: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 3: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 4: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 5: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 6: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 7: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 8: W: torch.Size([768, 2304]) + B: torch.Size([2304]) + Lora_A: torch.Size([1, 2304]) + Lora_B: torch.Size([768, 1])
Layer 9: W: torc

In [41]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
  print(name)
    # if 'lora' not in name:
    #     print(f'Freezing non-LoRA parameter {name}')
    #     param.requires_grad = False

transformer.wte.weight
transformer.wpe.weight
transformer.h.0.ln_1.weight
transformer.h.0.ln_1.bias
transformer.h.0.attn.c_attn.bias
transformer.h.0.attn.c_attn.parametrizations.weight.original
transformer.h.0.attn.c_attn.parametrizations.weight.0.lora_A
transformer.h.0.attn.c_attn.parametrizations.weight.0.lora_B
transformer.h.0.attn.c_proj.weight
transformer.h.0.attn.c_proj.bias
transformer.h.0.ln_2.weight
transformer.h.0.ln_2.bias
transformer.h.0.mlp.c_fc.weight
transformer.h.0.mlp.c_fc.bias
transformer.h.0.mlp.c_proj.weight
transformer.h.0.mlp.c_proj.bias
transformer.h.1.ln_1.weight
transformer.h.1.ln_1.bias
transformer.h.1.attn.c_attn.bias
transformer.h.1.attn.c_attn.parametrizations.weight.original
transformer.h.1.attn.c_attn.parametrizations.weight.0.lora_A
transformer.h.1.attn.c_attn.parametrizations.weight.0.lora_B
transformer.h.1.attn.c_proj.weight
transformer.h.1.attn.c_proj.bias
transformer.h.1.ln_2.weight
transformer.h.1.ln_2.bias
transformer.h.1.mlp.c_fc.weight
transforme

Keep a copy of the original weights (cloning them) so later we can prove that a fine-tuning with LoRA doesn't alter the original weights

In [None]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()


Let's visualize how many parameters are in the original network, before introducing the LoRA matrices.

In [None]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


Define the LoRA parameterization as described in the paper.
The full detail on how PyTorch parameterizations work is here: https://pytorch.org/tutorials/intermediate/parametrizations.html

Add the parameterization to our network.

Display the number of parameters added by LoRA.

In [None]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


Freeze all the parameters of the original network and only fine tuning the ones introduced by LoRA. Then fine-tune the model on the digit 9 and only for 100 batches.

In [None]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:   0%|          | 0/100 [00:00<?, ?it/s, loss=0.188] 

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 238.78it/s, loss=0.089] 


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [None]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

Test the network with LoRA enabled (the digit 9 should be classified better)

Test the network with LoRA disabled (the accuracy and errors counts must be the same as the original network)