<a href="https://colab.research.google.com/github/Lmalviya/machineTranslationTask/blob/main/Implement_LoRA_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scop of this notebook
1. implement LoRA from scratch using Pytorch
2. Train for MNIST digit classification model using linear layer
3. Fine tune above model with LoRA

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
_ = torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [3]:
config = {
    'batch_size': 10,
    'shuffle': True,
    'hiddenOne': 1000,
    'hiddenTwo': 2000,
    'epochs': 1,
    'lr': 1e-03
}


### Dataset

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3031))])

#load MNIST dataset
mnist_trainset = datasets.MNIST(root='/data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='/data', train=False, download=True, transform=transform)

#create dataloader
train_loader = DataLoader(mnist_trainset, batch_size=config['batch_size'], shuffle=config['shuffle'])
test_loader = DataLoader(mnist_test, batch_size=config['batch_size'], shuffle=config['shuffle'])



In [5]:
# Create an overly expensive neural network to classify MNIST digits
class RichBoyNet(nn.Module):
  def __init__(self, hiddenSizeOne, hiddenSizeTwo):
    super(RichBoyNet, self).__init__()
    self.layerOne = nn.Linear(28*28, hiddenSizeOne)
    self.layerTwo = nn.Linear(hiddenSizeOne, hiddenSizeTwo)
    self.layerThree = nn.Linear(hiddenSizeTwo, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.relu(self.layerOne(x))
    x = self.relu(self.layerTwo(x))
    x = self.layerThree(x)
    return x

netModel = RichBoyNet(config['hiddenOne'], config['hiddenTwo']).to(device)


### Train Loop

In [6]:
def oneStep(model, batch_iterator, optimizer, loss_fn, epoch, total_iteration, total_iteration_limit=None):
  model.train()
  loss_sum = 0
  num_iterations = 0
  for data in batch_iterator:
    num_iterations += 1
    total_iteration += 1
    x, y = data
    x = x.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    output = model(x.view(-1, 28*28))
    loss = loss_fn(output, y)
    loss_sum += loss.item()
    avg_loss = loss_sum/num_iterations
    batch_iterator.set_postfix(loss=avg_loss)
    loss.backward()
    optimizer.step()

    if total_iteration_limit is not None and total_iteration >= total_iteration_limit:
      return False
  return True


def fit(model, train_loader, config, total_iteration_limit=None):
  cross_en = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
  total_iteration = 0

  for epoch in range(config['epochs']):
    batch_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    if total_iteration_limit is not None:
      batch_iterator.total = total_iteration_limit

    isContinue = oneStep(model, batch_iterator, optimizer, cross_en, epoch, total_iteration, total_iteration_limit)
    if not isContinue:
      return

  return


In [7]:
fit(netModel, train_loader, config)

Epoch 1: 100%|██████████| 6000/6000 [02:54<00:00, 34.31it/s, loss=0.237]


### Keep copy of the original weights so later compaure that using LoRA base model weights are not changed

In [8]:
original_weights = {}
total_parameters_original  = 0
for name, param in netModel.named_parameters():
  original_weights[name] = param.clone().detach()


### Count Number of parameters present in the model

In [9]:
total_parameters_original = 0
for index, layer in enumerate([netModel.layerOne, netModel.layerTwo, netModel.layerThree]):
  total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
  print(f"Layer {index+1}: W: {layer.weight.shape}, B: {layer.bias.shape}")

print(f"Total Number of parameters in original Model: {total_parameters_original}")

Layer 1: W: torch.Size([1000, 784]), B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]), B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]), B: torch.Size([10])
Total Number of parameters in original Model: 2807010


### Test the model performance

In [10]:
def test(model, test_loader):
  correct = 0
  total = 0

  wrong_counts = [0 for i in range(10)]

  with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      output = model(x.view(-1, 28*28))
      for idx, i in enumerate(output):
        total += 1
        if torch.argmax(i) == y[idx]:
          correct += 1
        else:
          wrong_counts[y[idx]] += 1
  return correct, total, wrong_counts

correct, total, wrong_counts = test(netModel, test_loader)
print(f"\n\nAccuracy: {round(correct/total, 3)}")
for i in range(len(wrong_counts)):
  print(f"Wrong counts for the digits {i}: {wrong_counts[i]}")

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 204.97it/s]



Accuracy: 0.947
Wrong counts for the digits 0: 12
Wrong counts for the digits 1: 33
Wrong counts for the digits 2: 37
Wrong counts for the digits 3: 112
Wrong counts for the digits 4: 10
Wrong counts for the digits 5: 20
Wrong counts for the digits 6: 62
Wrong counts for the digits 7: 59
Wrong counts for the digits 8: 18
Wrong counts for the digits 9: 166





# Define LoRA parameters

In [11]:
import torch.nn.utils.parametrize as parametrize

In [40]:
class LoRAParametrization(nn.Module):
  def __init__(self, f_in, f_out, rank=1, alpha=1, device='cpu'):
    super(LoRAParametrization, self).__init__()
    self.lora_A = nn.Parameter(torch.zeros((rank, f_out))).to(device)
    self.lora_B = nn.Parameter(torch.zeros((f_in, rank))).to(device)

    nn.init.normal_(self.lora_A, mean=0, std=1)
    self.scaler = alpha/rank
    self.enabled = True

  def forward(self, original_weights):
    if self.enabled:
      # Return X + (B*A)*scaler
      return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scaler
    else:
      return original_weights


### Replace the linear layer weights with LoRA parameters

In [42]:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
  # only add the parameterization to the weight matrix, ignore the bias

  features_in, features_out = layer.weight.shape
  return LoRAParametrization(features_in, features_out, rank=rank, alpha=lora_alpha, device=device)

# working parametrize function
# 1. it will replace the weight matrix of any layer with the given class

parametrize.register_parametrization(
    netModel.layerOne, "weight", linear_layer_parameterization(netModel.layerOne, device)
)
parametrize.register_parametrization(
    netModel.layerTwo, "weight", linear_layer_parameterization(netModel.layerTwo, device)
)
parametrize.register_parametrization(
    netModel.layerThree, "weight", linear_layer_parameterization(netModel.layerThree, device)
)


ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0-2): 3 x LoRAParametrization()
    )
  )
)

In [43]:
def enable_disable_lora(enabled=True):
  for layer in [netModel.layerOne, netModel.layerTwo, netModel.layerThree]:
    layer.parameterization['weight'][0].enabled = True

### Display the number of parameters added by LoRA

In [44]:
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([netModel.layerOne, netModel.layerTwo, netModel.layerThree]):
  total_parameters_lora += layer.parametrizations.weight[0].lora_A.nelement() + layer.parametrizations.weight[0].lora_B.nelement()
  total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()

  print(f"Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + lora_A: {layer.parametrizations.weight[0].lora_A.shape} + lora_B: {layer.parametrizations.weight[0].lora_B.shape}")


assert total_parameters_non_lora == total_parameters_original

print(f"Total number of parameters (original): {total_parameters_non_lora}")
print(f"Total number of parameters (original + LoRA): {total_parameters_lora+total_parameters_non_lora}")
print(f"Parameters introduced by LoRA: {total_parameters_lora}")
parameters_incremment = (total_parameters_lora/total_parameters_non_lora)*100
print(f"Parameters increment: {parameters_incremment:.3f}%")

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + lora_A: torch.Size([1, 1000]) + lora_B: torch.Size([784, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + lora_A: torch.Size([1, 2000]) + lora_B: torch.Size([1000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + lora_A: torch.Size([1, 10]) + lora_B: torch.Size([2000, 1])
Total number of parameters (original): 2807010
Total number of parameters (original + LoRA): 2813804
Parameters introduced by LoRA: 6794
Parameters increment: 0.242%


Feeze all original model parameters

In [45]:
for name, param in netModel.named_parameters():
  if 'lora' not in name:
    print(f"Freezing non-LoRA parameters {name}")
    param.required_grad = False

Freezing non-LoRA parameters layerOne.bias
Freezing non-LoRA parameters layerOne.parametrizations.weight.original
Freezing non-LoRA parameters layerTwo.bias
Freezing non-LoRA parameters layerTwo.parametrizations.weight.original
Freezing non-LoRA parameters layerThree.bias
Freezing non-LoRA parameters layerThree.parametrizations.weight.original


### Fine tune the model with LoRA on digit 9 dataset


In [46]:
mnist_trainset = datasets.MNIST(root='/data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

train_loader_only_9 = DataLoader(mnist_trainset, batch_size=config['batch_size'], shuffle=config['shuffle'])

In [47]:
fit(netModel, train_loader_only_9, config, 100)

Epoch 1:  99%|█████████▉| 99/100 [00:05<00:00, 16.72it/s, loss=0]


In [48]:
# check that the frozen parameters are still unchanged by the fine tuning

assert torch.all(netModel.layerOne.parametrizations.weight.original == original_weights['layerOne.weight'])
assert torch.all(netModel.layerTwo.parametrizations.weight.original == original_weights['layerTwo.weight'])
assert torch.all(netModel.layerThree.parametrizations.weight.original == original_weights['layerThree.weight'])


# enable_disable_lora(enabled=True)
# assert torch.equal(netModel.layerOne.weight, netModel.linearOne.parametrizations.weight.original+(netModel.layerOne.parametrizations.weight[0].lora_B @ netModel.layerOne.parametrizations.weight[0].lora_A))


# enable_disable_lora(enabled=False)
# assert torch.equal(netModel.layerOne.weight, original_weights['layerOne.weight'])


AssertionError: 

In [49]:
enable_disable_lora(enabled=False)

correct, total, wrong_counts = test(netModel, test_loader)
print(f"\n\nAccuracy: {round(correct/total, 3)}")
for i in range(len(wrong_counts)):
  print(f"Wrong counts for the digits {i}: {wrong_counts[i]}")

AttributeError: 'ParametrizedLinear' object has no attribute 'parameterization'

In [50]:
netModel.layerOne.parametrizations.weight.original

Parameter containing:
tensor([[-0.0050,  0.0145, -0.0341,  ...,  0.0172, -0.0010, -0.0026],
        [ 0.0081,  0.0129,  0.0174,  ...,  0.0076,  0.0219, -0.0021],
        [ 0.0295,  0.0646,  0.0164,  ...,  0.0293,  0.0508,  0.0577],
        ...,
        [-0.0005,  0.0631,  0.0613,  ...,  0.0303,  0.0564,  0.0039],
        [ 0.0624,  0.0257,  0.0114,  ...,  0.0457,  0.0393,  0.0401],
        [ 0.0078, -0.0159,  0.0353,  ...,  0.0447,  0.0181,  0.0285]],
       requires_grad=True)

In [51]:
original_weights['layerOne.weight']

tensor([[ 0.0041,  0.0235, -0.0251,  ...,  0.0263,  0.0081,  0.0064],
        [ 0.0081,  0.0129,  0.0174,  ...,  0.0076,  0.0219, -0.0021],
        [ 0.0295,  0.0646,  0.0164,  ...,  0.0293,  0.0508,  0.0577],
        ...,
        [-0.0005,  0.0631,  0.0613,  ...,  0.0303,  0.0564,  0.0039],
        [ 0.0625,  0.0257,  0.0115,  ...,  0.0458,  0.0394,  0.0402],
        [ 0.0078, -0.0159,  0.0353,  ...,  0.0447,  0.0181,  0.0285]])

In [52]:
netModel.layerOne.parametrizations.weight.original

Parameter containing:
tensor([[-0.0050,  0.0145, -0.0341,  ...,  0.0172, -0.0010, -0.0026],
        [ 0.0081,  0.0129,  0.0174,  ...,  0.0076,  0.0219, -0.0021],
        [ 0.0295,  0.0646,  0.0164,  ...,  0.0293,  0.0508,  0.0577],
        ...,
        [-0.0005,  0.0631,  0.0613,  ...,  0.0303,  0.0564,  0.0039],
        [ 0.0624,  0.0257,  0.0114,  ...,  0.0457,  0.0393,  0.0401],
        [ 0.0078, -0.0159,  0.0353,  ...,  0.0447,  0.0181,  0.0285]],
       requires_grad=True)