<a href="https://colab.research.google.com/github/alif-munim/language-models/blob/main/lora/lora_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LoRA
Using LoRA to fine-tune on a single digit after training an MNIST classifier.

In [1]:
import itertools

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [2]:
# Seed PyTorch's global RNG so repeated runs are reproducible.
# Assigning to `_` keeps the returned generator from being echoed as cell output.
_ = torch.manual_seed(0)

In [None]:
# Load the MNIST train/test splits and wrap each one in a DataLoader
normalize = transforms.Normalize((0.1307,), (0.3081,))  # presumably the MNIST mean/std - confirm
transform = transforms.Compose([transforms.ToTensor(), normalize])

dataset_kwargs = dict(root='./data', download=True, transform=transform)
mnist_trainset = datasets.MNIST(train=True, **dataset_kwargs)
mnist_testset = datasets.MNIST(train=False, **dataset_kwargs)

# Both loaders use small shuffled batches of 10 samples
loader_kwargs = dict(batch_size=10, shuffle=True)
train_loader = torch.utils.data.DataLoader(mnist_trainset, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(mnist_testset, **loader_kwargs)

In [27]:
# Create an over-parameterized, inefficient neural network for classification
class BigNet(nn.Module):
  """Deliberately oversized three-layer MLP classifier for 28x28 inputs.

  The network flattens each input to 784 values, applies two hidden linear
  layers with ReLU activations and ends with a 10-way linear output layer
  that returns raw (un-normalised) scores.

  Args:
    hidden_dim1: Width of the first hidden layer.
    hidden_dim2: Width of the second hidden layer.
  """

  def __init__(self, hidden_dim1=1000, hidden_dim2=2000):
    super().__init__()
    # Layers are created in forward order; the attribute names are referenced
    # elsewhere in the notebook, so they must stay linear1/linear2/linear3.
    self.linear1 = nn.Linear(28*28, hidden_dim1)
    self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
    self.linear3 = nn.Linear(hidden_dim2, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    """Return a (batch, 10) tensor of class scores for `img`.

    `img` may have any shape whose elements can be viewed as rows of 784.
    """
    flat = img.view(-1, 28*28)
    hidden1 = self.relu(self.linear1(flat))
    hidden2 = self.relu(self.linear2(hidden1))
    return self.linear3(hidden2)

In [28]:
# Define training loop for model on MNIST

def train(model, train_loader, num_epochs, iter_limit=None):
  """Train `model` with Adam (lr=0.001) and cross-entropy loss.

  A tqdm progress bar is shown per epoch with the running average loss.

  Args:
    model: Module that maps a (batch, 784) tensor to class scores. Only its
      parameters with `requires_grad=True` are optimised.
    train_loader: Iterable of `(x, y)` batches; `x` is viewed as (-1, 784)
      before being passed to the model, `y` is the target for
      `nn.CrossEntropyLoss`.
    num_epochs: Number of passes over `train_loader`.
    iter_limit: Optional cap on the total number of optimisation steps across
      all epochs. `None` means no cap; a value <= 0 performs no steps.

  Returns:
    None.

  Raises:
    ValueError: From the optimizer if the model has no trainable parameters.
  """
  # Follow the model's own device rather than a notebook-level `device`
  # global, so the function does not depend on cell execution order.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  cross_el = nn.CrossEntropyLoss()
  # Frozen parameters are kept out of the optimizer entirely.
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  optimizer = torch.optim.Adam(trainable_params, lr=0.001)

  total_iters = 0
  for epoch in range(num_epochs):
    batches = train_loader
    try:
      bar_total = len(train_loader)
    except TypeError:
      bar_total = None  # loader without a length: tqdm shows a plain counter

    if iter_limit is not None:
      remaining = iter_limit - total_iters
      if remaining <= 0:
        return
      # islice stops after `remaining` batches without fetching an extra one,
      # so the bar finishes (and closes) at exactly the number of steps taken.
      batches = itertools.islice(train_loader, remaining)
      bar_total = remaining if bar_total is None else min(bar_total, remaining)

    model.train()
    loss_sum = 0
    num_iters = 0

    data_iterator = tqdm(batches, total=bar_total, desc=f'Epoch {epoch+1}')
    for x, y in data_iterator:
      num_iters += 1
      total_iters += 1

      x = x.to(device)
      y = y.to(device)

      optimizer.zero_grad()
      output = model(x.view(-1, 28*28))
      loss = cross_el(output, y)
      loss_sum += loss.item()
      avg_loss = loss_sum / num_iters
      data_iterator.set_postfix(loss=avg_loss)

      loss.backward()
      optimizer.step()

In [29]:
# Train model for one epoch to simulate large-scale pre-training

# Prefer the first CUDA device when one is available, otherwise stay on CPU
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

model = BigNet().to(device)
train(model, train_loader, num_epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:45<00:00, 131.84it/s, loss=0.238]


In [35]:
# Snapshot every parameter of the pretrained model (detached copies),
# keyed by parameter name, for comparison after fine-tuning

original_weights = {
  name: param.clone().detach() for name, param in model.named_parameters()
}

In [36]:
# Test the performance of the pretrained model on the test dataset

def test(model, test_loader):
  correct = 0
  total = 0
  wrong_counts = [0 for i in range(10)]

  with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      output = model(x.view(-1, 28*28))

      for idx, i in enumerate(output):
        if torch.argmax(i) == y[idx]:
          correct += 1
        else:
          wrong_counts[y[idx]] += 1
        total += 1

    print(f'\nAccuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
      print(f'Wrong counts for digit {i}: {wrong_counts[i]}')

# Baseline: evaluate the pretrained model before any LoRA fine-tuning
test(model, test_loader)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 317.07it/s]


Accuracy: 0.956
Wrong counts for digit 0: 9
Wrong counts for digit 1: 7
Wrong counts for digit 2: 32
Wrong counts for digit 3: 55
Wrong counts for digit 4: 74
Wrong counts for digit 5: 38
Wrong counts for digit 6: 53
Wrong counts for digit 7: 63
Wrong counts for digit 8: 62
Wrong counts for digit 9: 49





In [37]:
# Print the shape of each layer's weight and bias, then the total parameter count
original_params = 0
linear_layers = (model.linear1, model.linear2, model.linear3)
for layer_number, layer in enumerate(linear_layers, start=1):
  print(f'Layer {layer_number}: W: {layer.weight.shape} + B: {layer.bias.shape}')
  original_params += layer.weight.nelement() + layer.bias.nelement()
print(f'Total number of parameters: {original_params:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


In [38]:
# Inspect the first layer before any parametrization is attached to it
model.linear1

Linear(in_features=784, out_features=1000, bias=True)

In [39]:
# Define the LoRA parametrization as described in the LoRA paper (Hu et al., 2021)

class LoraParametrization(nn.Module):
  """Low-rank additive update for a weight tensor: W + (B @ A) * (alpha / rank).

  `lora_B` has shape (features_in, rank) and `lora_A` has shape
  (rank, features_out), so `B @ A` has shape (features_in, features_out).
  Note that `features_in` is therefore the FIRST dimension of the weight being
  parametrized and `features_out` the second.

  Args:
    features_in: First dimension of the parametrized weight.
    features_out: Second dimension of the parametrized weight.
    rank: Rank of the update.
    alpha: Numerator of the scaling factor `alpha / rank`.
    device: Device on which the LoRA matrices are created.
  """

  def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
    super().__init__()

    # B stays at zero while A is drawn from N(0, 1), so the initial update
    # B @ A is exactly zero. B is registered before A.
    self.lora_B = nn.Parameter(torch.zeros(features_in, rank, device=device))
    self.lora_A = nn.Parameter(torch.zeros(rank, features_out, device=device))
    nn.init.normal_(self.lora_A, mean=0, std=1)

    # Scaling term: set alpha to the first rank tried (1 here); it then needs
    # no further tuning when other ranks are tried.
    self.scale = alpha / rank
    # When False, forward() returns the original weights unchanged.
    self.enabled = True

  def forward(self, original_weights):
    """Return `original_weights` plus the scaled low-rank update, if enabled."""
    if not self.enabled:
      return original_weights
    delta = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
    return original_weights + delta * self.scale


In [40]:
# Add parametrization to the linear layers

import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
  """Build a LoraParametrization sized for `layer.weight`.

  Only the weight matrix is parametrized; the bias is left alone. The two
  dimensions of `layer.weight.shape` are passed on in order as `features_in`
  and `features_out`.

  Args:
    layer: Module with a 2-D `weight` attribute.
    device: Device for the LoRA matrices.
    rank: Rank of the low-rank update.
    lora_alpha: Scaling numerator forwarded as `alpha`.
  """
  weight_rows, weight_cols = layer.weight.shape
  lora_kwargs = dict(rank=rank, alpha=lora_alpha, device=device)
  return LoraParametrization(weight_rows, weight_cols, **lora_kwargs)

# Attach a LoRA parametrization to the weight of each linear layer, in order
for linear_layer in (model.linear1, model.linear2, model.linear3):
  parametrize.register_parametrization(
      linear_layer, "weight", linear_layer_parametrization(linear_layer, device)
  )

def enable_disable_lora(enabled=True):
  """Set the `enabled` flag of the first weight parametrization on
  linear1, linear2 and linear3 of the notebook-level `model`."""
  lora_modules = [
    layer.parametrizations["weight"][0]
    for layer in (model.linear1, model.linear2, model.linear3)
  ]
  for lora in lora_modules:
    lora.enabled = enabled

In [24]:
# Inspect the same layer again now that a parametrization is registered on its weight
model.linear1

ParametrizedLinear(
  in_features=784, out_features=1000, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoraParametrization()
    )
  )
)

In [47]:
# Compare the total number of parameters added to the model by LoRA
lora_params = 0
non_lora_params = 0

for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
  lora = layer.parametrizations["weight"][0]
  # LoRA adds only the two low-rank factors per layer
  lora_params += lora.lora_A.nelement() + lora.lora_B.nelement()
  non_lora_params += layer.weight.nelement() + layer.bias.nelement()
  print(
      f'Layer {index + 1}: \n\tW: {layer.weight.shape} \n\t+ B: {layer.bias.shape} \n\t+ LoRA_A: {lora.lora_A.shape} \n\t+ LoRA_B: {lora.lora_B.shape}'
  )

Layer 1: 
	W: torch.Size([1000, 784]) 
	+ B: torch.Size([1000]) 
	+ LoRA_A: torch.Size([1, 784]) 
	+ LoRA_B: torch.Size([1000, 1])
Layer 2: 
	W: torch.Size([2000, 1000]) 
	+ B: torch.Size([2000]) 
	+ LoRA_A: torch.Size([1, 1000]) 
	+ LoRA_B: torch.Size([2000, 1])
Layer 3: 
	W: torch.Size([10, 2000]) 
	+ B: torch.Size([10]) 
	+ LoRA_A: torch.Size([1, 2000]) 
	+ LoRA_B: torch.Size([10, 1])


In [49]:
# Attaching LoRA must not change the count of the original weights and biases
assert non_lora_params == original_params

total_with_lora = lora_params + non_lora_params
print(f'Total number of parameters (original): {non_lora_params:,}')
print(f'Total number of parameters (original + LoRA): {total_with_lora:,}')
print(f'Parameters introduced by lora: {lora_params:,}')

# Relative size of the LoRA factors, as a percentage of the original model
param_increase = (lora_params / non_lora_params) * 100
print(f'Increase in parameters: {param_increase:.3f}%')

Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by lora: 6,794
Increase in parameters: 0.242%


In [57]:
# List the name of every parameter currently registered on the model
for param_name, _param in model.named_parameters():
  print(param_name)

linear1.bias
linear1.parametrizations.weight.original
linear1.parametrizations.weight.0.lora_B
linear1.parametrizations.weight.0.lora_A
linear2.bias
linear2.parametrizations.weight.original
linear2.parametrizations.weight.0.lora_B
linear2.parametrizations.weight.0.lora_A
linear3.bias
linear3.parametrizations.weight.original
linear3.parametrizations.weight.0.lora_B
linear3.parametrizations.weight.0.lora_A


In [59]:
# Freeze non-LoRA network params

# Every parameter whose name does not contain 'lora' is frozen; the names are
# collected in `frozen_params`.
frozen_params = []
for name, param in model.named_parameters():
  if 'lora' in name:
    continue  # LoRA factors stay trainable
  print(f'Freezing non-LoRA parameter {name}...')
  param.requires_grad = False
  frozen_params.append(name)

Freezing non-LoRA parameter linear1.bias...
Freezing non-LoRA parameter linear1.parametrizations.weight.original...
Freezing non-LoRA parameter linear2.bias...
Freezing non-LoRA parameter linear2.parametrizations.weight.original...
Freezing non-LoRA parameter linear3.bias...
Freezing non-LoRA parameter linear3.parametrizations.weight.original...


In [60]:
# Fine-tune only the LoRA params on a single digit for a fixed number of batches

FINETUNE_DIGIT = 4    # digit whose accuracy the fine-tuning should improve
FINETUNE_ITERS = 100  # number of batches (optimisation steps) to fine-tune for

# Reload the full training split, then keep only the samples of that digit
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Boolean mask of the samples to KEEP (not to exclude)
digit_mask = mnist_trainset.targets == FINETUNE_DIGIT
mnist_trainset.data = mnist_trainset.data[digit_mask]
mnist_trainset.targets = mnist_trainset.targets[digit_mask]
assert len(mnist_trainset) > 0, f'no training samples for digit {FINETUNE_DIGIT}'

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
train(model, train_loader, num_epochs=1, iter_limit=FINETUNE_ITERS)

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 128.43it/s, loss=0.0594]


In [None]:
# Check frozen params are the same after fine-tuning.
# Biases were frozen together with the weights, so both are compared against
# the snapshot taken before the LoRA parametrizations were attached (which is
# why the snapshot keys are still '<layer>.weight' / '<layer>.bias').

for layer_name in ('linear1', 'linear2', 'linear3'):
  layer = getattr(model, layer_name)
  assert torch.equal(
      layer.parametrizations.weight.original, original_weights[f'{layer_name}.weight']
  ), f'{layer_name}.weight changed during fine-tuning'
  assert torch.equal(
      layer.bias, original_weights[f'{layer_name}.bias']
  ), f'{layer_name}.bias changed during fine-tuning'

In [61]:
# Check that pytorch replaces weight access by LoRA parametrization:
# reading `linear1.weight` must give original + (B @ A) * scale

enable_disable_lora(enabled=True)
old_weights = model.linear1.weight

# Recompute the same quantity by hand from the stored pieces
linear1_param = model.linear1.parametrizations.weight
linear1_lora = linear1_param[0]
scale_term = linear1_lora.scale
lora_term = torch.matmul(linear1_lora.lora_B, linear1_lora.lora_A) * scale_term
new_weights = linear1_param.original + lora_term

assert torch.equal(old_weights, new_weights)

In [62]:
# If we disable lora, linear1.weight should be the original
# (a disabled parametrization returns the stored weight unchanged)

enable_disable_lora(enabled=False)
assert torch.equal(model.linear1.weight, original_weights['linear1.weight'])

In [63]:
# Test model with LoRA enabled (weight access uses parametrization)
# If everything worked correctly, performance on digit 4 should increase.
# In the recorded run below, digit-4 errors drop while overall accuracy falls
# and errors on other digits (digit 9 in particular) rise, since the
# fine-tuning data contained only 4s.

enable_disable_lora(enabled=True)
test(model, test_loader)

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 225.37it/s]


Accuracy: 0.853
Wrong counts for digit 0: 10
Wrong counts for digit 1: 44
Wrong counts for digit 2: 51
Wrong counts for digit 3: 97
Wrong counts for digit 4: 10
Wrong counts for digit 5: 64
Wrong counts for digit 6: 83
Wrong counts for digit 7: 131
Wrong counts for digit 8: 146
Wrong counts for digit 9: 833





In [64]:
# Test the model with LoRA disabled
# Should return to original performance: with the parametrization disabled,
# each layer falls back to its frozen pretrained weight

enable_disable_lora(enabled=False)
test(model, test_loader)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 294.28it/s]


Accuracy: 0.956
Wrong counts for digit 0: 9
Wrong counts for digit 1: 7
Wrong counts for digit 2: 32
Wrong counts for digit 3: 55
Wrong counts for digit 4: 74
Wrong counts for digit 5: 38
Wrong counts for digit 6: 53
Wrong counts for digit 7: 63
Wrong counts for digit 8: 62
Wrong counts for digit 9: 49



