# LoRA Implementation with PyTorch

In [1]:
# imports
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm

In [2]:
# Making torch deterministic
_ = torch.manual_seed(10)

We will be training a model to classify MNIST digits and then fine-tune the model on a particular digit on which it doesn't perform well

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

# Loading the MNIST Dataset
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
# Dataloader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Loading the MNIST test set
mnist_testset =  datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Setting Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Creating a NN to classify the digits. Will be making it overly complicated with very high parameters to show the power of LoRA

In [4]:
class BahutDangorNet(nn.Module):
  def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
    super(BahutDangorNet, self).__init__()
    self.linear1 = nn.Linear(28*28, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)

    return x

model = BahutDangorNet().to(device)

Training for only 1 Epoch to simulate a general pre-training on the data

In [5]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
  cross_el = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  total_iterations = 0

  for epoch in range(epochs):
    model.train()

    loss_sum = 0
    num_iterations = 0

    data_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    if total_iterations_limit is not None:
      data_iter.total = total_iterations_limit

    for data in data_iter:
      num_iterations += 1
      total_iterations += 1
      x, y = data
      x = x.to(device)
      y = y.to(device)

      optimizer.zero_grad()

      out = model(x.view(-1, 28*28))
      # print(out, y)
      loss = cross_el(out, y)

      loss_sum += loss.item()
      avg_loss = loss_sum / num_iterations
      data_iter.set_postfix(loss=avg_loss)

      loss.backward()
      optimizer.step()

      if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
        return

In [6]:
train(train_loader, model, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:53<00:00, 112.51it/s, loss=0.246]


Testing the performance of the pretrained model. Observe that the model performs poorly on the digit 8, so we will fine tune it on digit 8

In [7]:
def test():
  correct = 0
  total = 0

  wrong_counts = [0 for i in range(10)]

  with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      out = model(x.view(-1, 28*28))

      for idx, i in enumerate(out):
        if torch.argmax(i) == y[idx]:
          correct += 1
        else:
          wrong_counts[y[idx]] += 1
        total += 1

    print(f'\nAccuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
      print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 256.97it/s]


Accuracy: 0.963
wrong counts for the digit 0: 19
wrong counts for the digit 1: 14
wrong counts for the digit 2: 52
wrong counts for the digit 3: 32
wrong counts for the digit 4: 32
wrong counts for the digit 5: 47
wrong counts for the digit 6: 20
wrong counts for the digit 7: 32
wrong counts for the digit 8: 69
wrong counts for the digit 9: 52





Let's see how many parameters are in the original model, before introducing the LoRA matrices

In [8]:
total_parameters_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


Defining the LoRA parametrization as described in the paper

In [9]:
class LoRAParametrization(nn.Module):
  def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
    super().__init__()
    # Section 4.1 of the paper:
    # We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
    self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
    self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
    nn.init.normal_(self.lora_A, mean=0, std=1)

    # Section 4.1 of the paper:
    # We then scale ∆Wx by α/r , where α is a constant in r.
    # When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately
    # As a result, we simply set α to the first r we try and do not tune it.
    # This scaling helps to reduce the need to retune hyperparameters when we vary r.
    self.scale = alpha / rank
    self.enable = True

  def forward(self, original_weights):
    if self.enable:
      # Return W + (B*A)*scale
      return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
    else:
      return original_weights

Adding the parameterization to our model.

https://pytorch.org/tutorials/intermediate/parametrizations.html

In [10]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)


def enable_disable_lora(enable=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enable = enable

Display the number of parameters added by LoRA

In [11]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


Freeze all the parameters of the original model and only fine tuning the ones introduced by LoRA. Then fine-tune the model on the digit 8 and only for 100 batches.

In [14]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
  if 'lora' not in name:
    print(f"Freezing non-LoRA parameters {name}")
    param.require_grad = False

# Load the MNIST dataset again, by keeping only the digit 8
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 8
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 8 and only for 100 batches
train(train_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameters linear1.bias
Freezing non-LoRA parameters linear1.parametrizations.weight.original
Freezing non-LoRA parameters linear2.bias
Freezing non-LoRA parameters linear2.parametrizations.weight.original
Freezing non-LoRA parameters linear3.bias
Freezing non-LoRA parameters linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:01<00:00, 61.12it/s, loss=0.174]


Test the model with LoRA enabled (the digit 8 should be classified better)

In [15]:
# Test with LoRA enabled
enable_disable_lora(enable=True)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 217.46it/s]


Accuracy: 0.381
wrong counts for the digit 0: 239
wrong counts for the digit 1: 1135
wrong counts for the digit 2: 776
wrong counts for the digit 3: 401
wrong counts for the digit 4: 767
wrong counts for the digit 5: 781
wrong counts for the digit 6: 335
wrong counts for the digit 7: 755
wrong counts for the digit 8: 0
wrong counts for the digit 9: 1005



