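Today we are going to implement LoRA (Low-Rank Adaptation) from scratch: we pre-train a small MLP on MNIST, freeze it, and then fine-tune only small low-rank update matrices.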

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import tqdm as tqdm

### Make the model deterministic

In [2]:
# Fix the global PyTorch RNG seed so weight initialisation and data shuffling are reproducible.
_ = torch.random.manual_seed(0)

In [3]:
# Convert images to tensors and normalise them with the widely used MNIST mean/std.
# Bound to its own name so the `torchvision.transforms` module is not shadowed:
# re-running this cell would otherwise fail on `transforms.Compose`.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Load the MNIST training split (downloaded into ./data on first run).
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=10, shuffle=True)

# Load the MNIST test split; evaluation does not need shuffling.
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 498kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.51MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.88MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Create a neural network to classify the digits.

In [4]:
class RichBoyNet(nn.Module):
  """Three-layer MLP: flattened 28x28 image -> hidden_size_1 -> hidden_size_2 -> 10 logits.

  ReLU is applied after the first two linear layers; the last layer returns raw logits.
  """

  def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
    super().__init__()
    in_dim, n_classes = 28 * 28, 10
    self.linear1 = nn.Linear(in_dim, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, n_classes)
    self.relu = nn.ReLU()

  def forward(self, x):
    # Flatten everything except the batch dimension.
    flat = x.view(x.shape[0], -1)
    hidden = self.relu(self.linear1(flat))
    hidden = self.relu(self.linear2(hidden))
    return self.linear3(hidden)

In [5]:
# Prefer the GPU when one is available and place the model there.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
net = RichBoyNet().to(device)

In [6]:
def train(train_loader, net, epochs=5, total_iterations_limit=None, lr=1e-3, device=None):
  """Train ``net`` in place with Adam and cross-entropy loss.

  Only parameters with ``requires_grad=True`` are given to the optimizer, so
  frozen parameters are never stepped.

  Args:
    train_loader: iterable of (images, labels) batches; images are flattened to 28*28 features.
    net: model to train (updated in place).
    epochs: maximum number of passes over ``train_loader``.
    total_iterations_limit: if given, stop after exactly this many optimizer steps
      in total, across all epochs.
    lr: Adam learning rate.
    device: where each batch is moved; defaults to the device of ``net``'s first parameter.
  """
  if total_iterations_limit is not None and total_iterations_limit <= 0:
    return
  if device is None:
    device = next(net.parameters()).device

  cross_el = nn.CrossEntropyLoss()
  trainable_params = [p for p in net.parameters() if p.requires_grad]
  optimizer = torch.optim.Adam(trainable_params, lr=lr)

  total_iterations = 0

  for epoch in range(epochs):
    net.train()

    loss_sum = 0.0
    num_iterations = 0

    data_iterator = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    if total_iterations_limit is not None:
      data_iterator.total = total_iterations_limit

    for img, label in data_iterator:
      img, label = img.to(device), label.to(device)

      optimizer.zero_grad()

      # forward pass
      outputs = net(img.view(-1, 28*28))
      loss = cross_el(outputs, label)

      loss.backward()
      optimizer.step()

      total_iterations += 1
      num_iterations += 1
      # Accumulate a Python float: adding the loss tensor itself would keep every
      # iteration's autograd graph alive until the end of the epoch.
      loss_value = loss.item()
      loss_sum += loss_value
      data_iterator.set_postfix(loss=loss_value, avg_loss=loss_sum / num_iterations)

      # Stop after exactly `total_iterations_limit` steps and do not start further epochs.
      if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
        data_iterator.close()
        return


# Pre-train the whole network (every parameter trainable) using train()'s defaults.
train(train_loader=train_loader, net=net)


Epoch 1: 100%|██████████| 6000/6000 [00:51<00:00, 115.73it/s, loss=0.449]
Epoch 2: 100%|██████████| 6000/6000 [00:49<00:00, 120.82it/s, loss=0.00302]
Epoch 3: 100%|██████████| 6000/6000 [00:49<00:00, 120.64it/s, loss=0.29]
Epoch 4: 100%|██████████| 6000/6000 [00:51<00:00, 117.62it/s, loss=6.56e-5]
Epoch 5: 100%|██████████| 6000/6000 [00:50<00:00, 118.65it/s, loss=0.0127]


In [7]:
# Snapshot every parameter of the pre-trained network (detached copies) so that
# later cells can check the frozen weights were not modified by fine-tuning.
original_weights = {
    name: param.detach().clone() for name, param in net.named_parameters()
}

In [8]:
def test():
  correct = 0
  total = 0

  wrong_counts = [0 for i in range(10)]

  with torch.no_grad():
    for images, labels in tqdm.tqdm(test_loader, desc="Testing"):
      images, labels = images.to(device), labels.to(device)
      outputs = net(images.view(-1, 28*28))

      for idx, i in enumerate(outputs):
        if torch.argmax(i) == labels[idx]:
          correct += 1
        else:
          wrong_counts[labels[idx]] += 1
        total += 1
  print(f"Accuracy: {round(correct / total, 3)}")

  for i in range(len(wrong_counts)):
    print(f"Worng count of digits {i}: {wrong_counts[i]}")

In [9]:
# Baseline: evaluate the pre-trained network before any LoRA parametrization is added.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 257.31it/s]

Accuracy: 0.97
Worng count of digits 0: 5
Worng count of digits 1: 7
Worng count of digits 2: 67
Worng count of digits 3: 44
Worng count of digits 4: 23
Worng count of digits 5: 29
Worng count of digits 6: 56
Worng count of digits 7: 13
Worng count of digits 8: 25
Worng count of digits 9: 31





Let's count how many parameters the original network has, before adding the LoRA matrices.

In [10]:
# Count weights and biases of the three linear layers of the original network.
layers_to_count = (net.linear1, net.linear2, net.linear3)
total_parameters_orignal = 0
for layer_no, layer in enumerate(layers_to_count, start=1):
  n_weight = layer.weight.nelement()
  n_bias = layer.bias.nelement()
  total_parameters_orignal += n_weight + n_bias
  print(f"Layer-{layer_no}  W: {n_weight}  B: {n_bias}")
print(f"Total parameters is {total_parameters_orignal}")

Layer-1  W: 784000  B: 1000
Layer-2  W: 2000000  B: 2000
Layer-3  W: 20000  B: 10
Total parameters is 2807010


In [11]:
### From the LoRA paper, section 4.1
# https://arxiv.org/pdf/2106.09685
#
# "We illustrate our reparametrization in Figure 1. We use a random Gaussian initialization for A and
# zero for B, so ΔW = BA is zero at the beginning of training. We then scale ΔWx by α/r, where α
# is a constant in r. When optimizing with Adam, tuning α is roughly the same as tuning the learning
# rate if we scale the initialization appropriately. As a result, we simply set α to the first r we try
# and do not tune it. This scaling helps to reduce the need to retune hyperparameters when we vary r."

class LoRAParamterization(nn.Module):
  """Additive low-rank reparametrization of a weight matrix (LoRA).

  When enabled, ``forward(W)`` returns ``W + (lora_B @ lora_A) * alpha / rank``;
  when disabled it returns ``W`` unchanged.

  Args:
    feature_in: number of ROWS of the weight being adapted (``weight.shape[0]``;
      for ``nn.Linear`` that is ``out_features`` despite the parameter name).
    feature_out: number of COLUMNS of the weight (``weight.shape[1]``;
      ``in_features`` for ``nn.Linear``).
    rank: inner dimension r of the low-rank update; must be >= 1.
    alpha: scaling constant; the update is multiplied by ``alpha / rank``.
    device: device on which ``lora_A`` and ``lora_B`` are allocated.

  Raises:
    ValueError: if ``rank < 1``, or if ``forward`` receives a weight whose shape
      differs from ``(feature_in, feature_out)``.
  """

  def __init__(self, feature_in, feature_out, rank=1, alpha=1, device='cpu'):
    super().__init__()
    if rank < 1:
      raise ValueError(f"rank must be >= 1, got {rank}")
    # Random Gaussian init for A and zeros for B, so B @ A == 0 and the wrapped
    # layer initially behaves exactly like the original one.
    self.lora_A = nn.Parameter(torch.empty((rank, feature_out), device=device))
    self.lora_B = nn.Parameter(torch.zeros((feature_in, rank), device=device))
    nn.init.normal_(self.lora_A, mean=0, std=1)

    self.scale = alpha / rank
    self.enabled = True

  def forward(self, original_weights):
    if not self.enabled:
      return original_weights
    delta = torch.matmul(self.lora_B, self.lora_A) * self.scale
    # Fail loudly instead of reshaping: a .view() would silently reinterpret a
    # transposed update (same number of elements, wrong layout).
    if delta.shape != original_weights.shape:
      raise ValueError(
          f"LoRA update shape {tuple(delta.shape)} does not match weight shape "
          f"{tuple(original_weights.shape)}"
      )
    return original_weights + delta


In [12]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
  """Build a LoRAParamterization sized to ``layer.weight``.

  Only the weight matrix is reparametrized; the bias is left untouched.
  The returned module's ``lora_B @ lora_A`` has the same (rows, cols) shape as
  ``layer.weight``, which for ``nn.Linear`` is (out_features, in_features).
  """
  n_rows, n_cols = layer.weight.shape
  # LoRAParamterization takes the weight's (rows, cols) as its first two arguments.
  return LoRAParamterization(n_rows, n_cols, rank=rank, alpha=lora_alpha, device=device)

# Attach a rank-1 LoRA parametrization to the weight of each linear layer (in order
# linear1, linear2, linear3); biases are not parametrized.
for lora_target in (net.linear1, net.linear2, net.linear3):
  parametrize.register_parametrization(
      lora_target, "weight", linear_layer_parameterization(lora_target, device)
  )

In [19]:
def enable_disable_lora(enabled=True):
  """Switch the LoRA update on or off for the three linear layers of ``net``.

  When disabled, each layer's parametrization returns the original weight unchanged.
  """
  for lora_layer in (net.linear1, net.linear2, net.linear3):
    lora_param = lora_layer.parametrizations["weight"][0]
    lora_param.enabled = enabled

Display the number of parameters added by LoRA

In [13]:
# Count the original (non-LoRA) parameters and the extra LoRA parameters per layer.
total_parameters_lora = 0
total_parameters_non_lora = 0

for layer_no, layer in enumerate((net.linear1, net.linear2, net.linear3), start=1):
  lora = layer.parametrizations["weight"][0]
  total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
  total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
  print(f"layer-{layer_no} W: {layer.weight.shape} B: {layer.bias.shape} lora_A : {lora.lora_A.shape} lora_B: {lora.lora_B.shape}")

# Adding LoRA must not change the number of original parameters.
assert total_parameters_non_lora == total_parameters_orignal

print(f"Total number of parameters original: {total_parameters_non_lora}")
print(f"Total number of parameters original + lora: {total_parameters_non_lora + total_parameters_lora}")
print(f"Total number of parameters add by lora: {total_parameters_lora}")

layer-1 W: torch.Size([1000, 784]) B: torch.Size([1000]) lora_A : torch.Size([1, 784]) lora_B: torch.Size([1000, 1])
layer-2 W: torch.Size([2000, 1000]) B: torch.Size([2000]) lora_A : torch.Size([1, 1000]) lora_B: torch.Size([2000, 1])
layer-3 W: torch.Size([10, 2000]) B: torch.Size([10]) lora_A : torch.Size([1, 2000]) lora_B: torch.Size([10, 1])
Total number of parameters original: 2807010
Total number of parameters original + lora: 2813804
Total number of parameters add by lora: 6794


In [14]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
  if 'lora' not in name:
    print(f"Freezing non-lora parameters {name}")
    param.b=False

Freezing non-lora parameters linear1.bias
Freezing non-lora parameters linear1.parametrizations.weight.original
Freezing non-lora parameters linear2.bias
Freezing non-lora parameters linear2.parametrizations.weight.original
Freezing non-lora parameters linear3.bias
Freezing non-lora parameters linear3.parametrizations.weight.original


In [15]:
# Fine-tune: with the non-LoRA parameters frozen in the previous cell, only lora_A / lora_B should change.
train(train_loader=train_loader, net=net)

Epoch 1: 100%|██████████| 6000/6000 [00:50<00:00, 117.95it/s, loss=0.00429]
Epoch 2: 100%|██████████| 6000/6000 [00:51<00:00, 117.32it/s, loss=9.82e-6]
Epoch 3: 100%|██████████| 6000/6000 [00:51<00:00, 115.63it/s, loss=1.28e-5]
Epoch 4: 100%|██████████| 6000/6000 [00:51<00:00, 117.48it/s, loss=0.0528]
Epoch 5: 100%|██████████| 6000/6000 [00:50<00:00, 117.94it/s, loss=1.25e-6]


In [20]:
# Test the LoRA enaled
# Evaluate with the LoRA update switched on (effective weight = W + scale * B @ A).
enable_disable_lora(True)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 286.35it/s]

Accuracy: 0.981
Worng count of digits 0: 10
Worng count of digits 1: 10
Worng count of digits 2: 28
Worng count of digits 3: 20
Worng count of digits 4: 16
Worng count of digits 5: 19
Worng count of digits 6: 24
Worng count of digits 7: 22
Worng count of digits 8: 18
Worng count of digits 9: 27





In [21]:
# Test the lora disable result must be same as original
# Evaluate with LoRA switched off: the layers use only the frozen original weights,
# so the results should match the pre-LoRA baseline.
enable_disable_lora(False)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 305.27it/s]

Accuracy: 0.97
Worng count of digits 0: 5
Worng count of digits 1: 7
Worng count of digits 2: 67
Worng count of digits 3: 44
Worng count of digits 4: 23
Worng count of digits 5: 29
Worng count of digits 6: 56
Worng count of digits 7: 13
Worng count of digits 8: 25
Worng count of digits 9: 31





In [22]:
# Check that the frozen parameters are still unchanged by the finetuning

# Verify that fine-tuning left the frozen parameters untouched by comparing them with
# the snapshot in `original_weights`, which was taken before LoRA was registered.
for layer_name in ("linear1", "linear2", "linear3"):
  layer = getattr(net, layer_name)
  assert torch.equal(layer.parametrizations.weight.original, original_weights[f"{layer_name}.weight"]), \
      f"{layer_name}.weight changed during fine-tuning"
  assert torch.equal(layer.bias, original_weights[f"{layer_name}.bias"]), \
      f"{layer_name}.bias changed during fine-tuning"

# With LoRA explicitly disabled (not relying on an earlier cell), the effective
# weight must be exactly the frozen original.
enable_disable_lora(enabled=False)
for layer in (net.linear1, net.linear2, net.linear3):
  assert torch.equal(layer.parametrizations.weight.original, layer.weight)