# LoRA Implementation with PyTorch

In [1]:
import torch
import torchvision.datasets as datasets
from torchvision.transforms import transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Seed PyTorch's global random number generator with a fixed value (0);
# the returned Generator object is discarded.
_ = torch.manual_seed(0)

In [3]:
# We will be training a network to classify MNIST digits and then fine-tune the network on a particular digit on which it doesn't perform well.

# Convert each image to a tensor, then normalize its single channel with mean 0.1307 and std 0.3081
# (presumably the MNIST pixel statistics - TODO confirm)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

# Load the MNIST training set (stored under ./data, downloaded when requested via download=True)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create the DataLoader for training: shuffled mini-batches of 10 samples
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# NOTE(review): shuffle=True is not required for evaluation; it only changes the order in which batches are visited
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Create the Neural Network to classify the digits, making it overly complicated to better show the power of LoRA

# create an overly expensive neural network to classify MNIST digits

class RichBoyNet(nn.Module):
  """Deliberately oversized fully connected classifier for 28x28 images.

  Three linear layers (784 -> hidden_size_1 -> hidden_size_2 -> 10) with a
  ReLU after each of the first two. The output is the raw 10-way logits.
  """

  def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
    super().__init__()
    # Layers are created in the same order they are applied in forward()
    self.linear_1 = nn.Linear(28*28, hidden_size_1)
    self.linear_2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear_3 = nn.Linear(hidden_size_2, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    """Flatten `img` to rows of 784 values and return the logits of the last layer."""
    flat = img.view(-1, 28*28)
    hidden = self.relu(self.linear_1(flat))
    hidden = self.relu(self.linear_2(hidden))
    return self.linear_3(hidden)

# Instantiate the network with its default hidden sizes and move it to the selected device
net = RichBoyNet().to(device) 

In [5]:
# Train the network only for 1 epoch to simulate a complete general pre-training on the data

def train(train_loader, net, epochs=5, total_iterations_limit=None):
  """Optimize `net` with Adam (lr=0.001) and cross-entropy loss.

  Args:
    train_loader: iterable yielding (x, y) batches; each x is flattened to rows of 784 values.
    net: the module to train; every parameter returned by net.parameters() is handed to the optimizer.
    epochs: number of passes over `train_loader`.
    total_iterations_limit: when not None, training returns as soon as this many
      batches (counted across all epochs) have been processed.

  Batches are moved to the module-level `device`. The running mean loss of the
  current epoch is shown in the tqdm progress bar. Returns None.
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

  has_limit = total_iterations_limit is not None
  batches_seen = 0

  for epoch in range(epochs):
    net.train()
    running_loss = 0

    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    if has_limit:
      # Make the bar's length reflect the iteration cap rather than the loader length
      progress.total = total_iterations_limit

    for step, (x, y) in enumerate(progress, start=1):
      batches_seen += 1
      x, y = x.to(device), y.to(device)

      optimizer.zero_grad()
      loss = criterion(net(x.view(-1, 28*28)), y)

      # Report the mean loss over the batches seen so far in this epoch
      running_loss += loss.item()
      progress.set_postfix(loss=running_loss / step)

      loss.backward()
      optimizer.step()

      if has_limit and batches_seen >= total_iterations_limit:
        return

# Pre-train on the full training loader for a single epoch, with no iteration limit
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:32<00:00, 184.61it/s, loss=0.237]


In [6]:
# keep a copy of the original weights (cloning them) so later we can prove that a fine-tuning with LoRA doesn't alter the original weights

# Snapshot every named parameter as an independent tensor that is cut off from autograd
original_weights = dict()
for name, param in net.named_parameters():
  snapshot = param.detach().clone()
  original_weights[name] = snapshot

In [7]:
# Evaluate the pretrained network. As the output below shows, it performs poorly on the digit 9, so we will fine-tune it on that digit.

def test():
  """Evaluate the module-level `net` on the module-level `test_loader`.

  Takes no arguments and returns None. Batches are moved to the module-level
  `device` and flattened to rows of 784 values. Prints the overall accuracy
  (rounded to 3 decimals) followed by, for each digit 0-9, how many samples
  with that true label were misclassified.

  Labels are expected to be integers in 0..9, since they index a 10-slot list.
  """
  correct = 0
  total = 0

  wrong_counts = [0] * 10

  with torch.no_grad():
    for x, y in tqdm(test_loader, desc='Testing'):
      x = x.to(device)
      y = y.to(device)
      # Compare the whole batch at once instead of one sample at a time,
      # so there is a single device-to-host transfer per batch
      predictions = net(x.view(-1, 784)).argmax(dim=1)
      misclassified = predictions != y
      total += y.numel()
      correct += int((~misclassified).sum())
      # Tally errors by true label; tolist() yields plain ints usable as list indices
      for label in y[misclassified].tolist():
        wrong_counts[label] += 1

  print(f'Accuracy: {round(correct/total, 3)}')
  for i, wrong in enumerate(wrong_counts):
    print(f'wrong counts for digit {i}: {wrong}')

# Evaluate the network as it stands after pre-training (baseline per-digit error counts)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 254.98it/s]

Accuracy: 0.959
wrong counts for digit 0: 11
wrong counts for digit 1: 11
wrong counts for digit 2: 21
wrong counts for digit 3: 76
wrong counts for digit 4: 25
wrong counts for digit 5: 46
wrong counts for digit 6: 31
wrong counts for digit 7: 33
wrong counts for digit 8: 24
wrong counts for digit 9: 137





In [8]:
# Let's visualize how many parameters are in the original network, before introducing LoRA matrices.
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters

# Walk the three linear layers, printing each one's shapes and accumulating the parameter count
total_parameters_original = 0
for index, layer in enumerate([net.linear_1, net.linear_2, net.linear_3]):
  print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
  total_parameters_original += layer.weight.numel()
  total_parameters_original += layer.bias.numel()
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


In [9]:
# Define the LoRA parameterization as described in the paper. The full detail on how PyTorch parameterizations work is here: 
# https://pytorch.org/tutorials/intermediate/parametrizations.html

class LoRAParametrization(nn.Module):
  """Low-rank additive update for a weight tensor (LoRA).

  When enabled, forward() returns W + (lora_B @ lora_A) * (alpha / rank);
  when disabled it returns W untouched.

  Args:
    features_in: size of the first dimension of the adapted weight (rows of lora_B).
      NOTE: despite the name this is dim 0 of the weight; the helper in this file
      passes `layer.weight.shape` in order, which for nn.Linear is (out, in).
    features_out: size of the second dimension of the adapted weight (columns of lora_A).
    rank: inner dimension r of the low-rank product; must be non-zero because
      the scale is computed as alpha / rank.
    alpha: numerator of the scaling factor.
    device: device on which the two LoRA matrices are allocated.
  """

  def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
    super().__init__()
    # Section 4.1 of the paper:
    #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
    # Allocate directly on the target device instead of creating on CPU and copying.
    self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
    self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
    nn.init.normal_(self.lora_A, mean=0, std=1)

    # Section 4.1 of the paper:
    #   We then scale ∆Wx by α/r , where α is a constant in r.
    #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
    #   As a result, we simply set α to the first r we try and do not tune it.
    #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
    self.scale = alpha / rank
    # Plain Python flag (not a parameter/buffer) used to switch the update on and off
    self.enabled = True

  def forward(self, original_weights):
    """Return the effective weight given the stored (original) weight tensor."""
    if not self.enabled:
      return original_weights
    # B @ A has shape (features_in, features_out); view() matches it to the weight's own shape
    delta = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
    return original_weights + delta * self.scale

In [10]:
# Add the parameterization to our network.

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
  """Build a LoRAParametrization sized for `layer.weight`.

  Only the weight matrix is adapted; the bias is left alone. Section 4.2 of the
  LoRA paper likewise restricts its study to weight matrices and leaves the
  adaptation of biases to future work.

  Args:
    layer: module whose 2-D `weight` gives the two dimensions of the LoRA matrices.
    device: device on which the LoRA matrices are created.
    rank: rank of the low-rank update.
    lora_alpha: forwarded as `alpha` (the scaling numerator).
  """
  # The weight's two dimensions are forwarded in the order they appear in its shape
  dim_0, dim_1 = layer.weight.shape
  return LoRAParametrization(dim_0, dim_1, rank=rank, alpha=lora_alpha, device=device)

# Attach a LoRA parametrization (default rank and alpha) to the weight of each linear layer, in order
for layer in (net.linear_1, net.linear_2, net.linear_3):
    parametrize.register_parametrization(
        layer, "weight", linear_layer_parameterization(layer, device)
    )


def enable_disable_lora(enabled=True):
    """Set the `enabled` flag of the first weight parametrization on each of the
    three linear layers of the module-level `net`."""
    adapted_layers = (net.linear_1, net.linear_2, net.linear_3)
    for module in adapted_layers:
        lora = module.parametrizations["weight"][0]
        lora.enabled = enabled

In [11]:
# Display the number of parameters added by LoRA.

# Running totals: parameters held by the LoRA matrices vs. the weights and biases of the layers
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear_1, net.linear_2, net.linear_3]):
    # parametrizations["weight"][0] is the LoRA module registered on this layer's weight
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    # layer.weight has the same shape as the original weight, so the count is comparable to the earlier one
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Relative size of the LoRA parameters, as a percentage of the original parameter count
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [12]:
# Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA.
# Then fine-tune the model on the digit 9, for 100 batches only.

# Freeze the non-LoRA parameters (every parameter whose name does not contain 'lora')
for name, param in net.named_parameters():
  if 'lora' not in name:
    print(f'Freezing non-LoRA parameter {name}')
    param.requires_grad = False

# Load the MNIST dataset again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE: despite its name, `exclude_indices` is a boolean mask that is True for the samples we KEEP (label == 9)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training (this rebinds the module-level train_loader)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear_1.bias
Freezing non-LoRA parameter linear_1.parametrizations.weight.original
Freezing non-LoRA parameter linear_2.bias
Freezing non-LoRA parameter linear_2.parametrizations.weight.original
Freezing non-LoRA parameter linear_3.bias
Freezing non-LoRA parameter linear_3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 126.35it/s, loss=0.0803]


In [13]:
# Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

# Check that the frozen parameters are still unchanged by the finetuning
# Exact element-wise comparison of each stored original weight against the clone taken before LoRA was registered
assert torch.all(net.linear_1.parametrizations.weight.original == original_weights['linear_1.weight'])
assert torch.all(net.linear_2.parametrizations.weight.original == original_weights['linear_2.weight'])
assert torch.all(net.linear_3.parametrizations.weight.original == original_weights['linear_3.weight'])

enable_disable_lora(enabled=True)
# With LoRA enabled, linear_1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear_1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear_1.weight, net.linear_1.parametrizations.weight.original + (net.linear_1.parametrizations.weight[0].lora_B @ net.linear_1.parametrizations.weight[0].lora_A) * net.linear_1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear_1.weight is the original one
# NOTE: LoRA is left disabled after this check; re-enable it explicitly before using the fine-tuned weights
assert torch.equal(net.linear_1.weight, original_weights['linear_1.weight'])

In [14]:
# Test the network with LoRA enabled (the digit 9 should be classified better)
# LoRA must be switched on explicitly before evaluating
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 206.74it/s]

Accuracy: 0.88
wrong counts for digit 0: 12
wrong counts for digit 1: 14
wrong counts for digit 2: 37
wrong counts for digit 3: 249
wrong counts for digit 4: 229
wrong counts for digit 5: 97
wrong counts for digit 6: 28
wrong counts for digit 7: 341
wrong counts for digit 8: 180
wrong counts for digit 9: 8





In [15]:
# Test the network with LoRA disabled (the accuracy and error counts must be the same as for the original network)

# With the flag off, each parametrization returns the stored original weight unchanged
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 214.70it/s]

Accuracy: 0.959
wrong counts for digit 0: 11
wrong counts for digit 1: 11
wrong counts for digit 2: 21
wrong counts for digit 3: 76
wrong counts for digit 4: 25
wrong counts for digit 5: 46
wrong counts for digit 6: 31
wrong counts for digit 7: 33
wrong counts for digit 8: 24
wrong counts for digit 9: 137



