In [None]:
# Import Libraries
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

# Seed PyTorch's global random number generator with a fixed value so that runs are repeatable.
# The returned generator object is assigned to `_` so the cell does not display it.
_ = torch.manual_seed(0)

We will use a simple MNIST network to demonstrate a LoRA implementation in PyTorch. The idea is to train an initial model for just one epoch, thereby creating a base model that does not work very well on certain digits (an analogy to a huge pre-trained LLM). Then we will use LoRA to fine-tune the model on one specific digit (digit 2 in this notebook) and see how it performs on that digit. We will also see how the model performs on the other digits after training with LoRA.

In [None]:
# Download the MNIST dataset and build the train/test loaders

# Preprocessing shared by both splits: ToTensor followed by a single-channel
# Normalize with mean 0.1307 and std 0.3081
# (presumably the MNIST pixel statistics -- TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split and held-out test split, both stored under ./data.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Both loaders serve mini-batches of 10 samples with shuffling enabled.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Define the network
class RichBoyNet(nn.Module):
    """Three-layer fully connected classifier for 28x28 inputs.

    Layout: Linear(784 -> hidden_size_1) -> ReLU ->
    Linear(hidden_size_1 -> hidden_size_2) -> ReLU -> Linear(hidden_size_2 -> 10).

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten `img` to rows of 784 values and return the 10 raw class scores per row."""
        hidden = img.view(-1, 28 * 28)
        # The two hidden layers share the same "linear then ReLU" pattern.
        for hidden_layer in (self.linear1, self.linear2):
            hidden = self.relu(hidden_layer(hidden))
        # No activation on the output layer: raw scores are returned.
        return self.linear3(hidden)

# Build the network with its default hidden sizes and move it to the selected device.
net = RichBoyNet().to(device)



In [None]:
# train function
def train(train_loader,net,epochs=5,total_iterations_limit=None):
    """Train `net` with Adam (lr=0.001) and cross-entropy loss.

    Relies on the notebook-level `device` and `tqdm` names.

    Args:
        train_loader: iterable yielding `(x, y)` batches; `x` is flattened to
            rows of 28*28 values before the forward pass.
        net: the module to optimise (updated in place).
        epochs: maximum number of passes over `train_loader`.
        total_iterations_limit: optional cap on the total number of optimiser
            steps, counted across ALL epochs. `None` means no cap; a value of
            zero or less performs no steps at all.

    Returns:
        None.
    """
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # Nothing to do when the caller explicitly asks for zero (or fewer) steps.
    if total_iterations_limit is not None and total_iterations_limit <= 0:
        return

    # Counted across epochs so the limit really is a *total* budget.
    total_iterations = 0

    for epoch in range(epochs):
        net.train()
        total_loss = 0.0
        total_samples = 0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        for x, y in pbar:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss.backward()
            optimizer.step()

            # Weight each batch loss by its size so the running average stays
            # correct even when the last batch is smaller.
            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            avg_loss = total_loss / total_samples
            pbar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})

            total_iterations += 1
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                # Stop training entirely, not just the current epoch.
                pbar.close()
                return

# Train the base model for a single epoch over the training loader.
train(train_loader,net,epochs=1)





Epoch 1: 100%|██████████| 6000/6000 [03:52<00:00, 25.85it/s, avg_loss=0.2416]


In [None]:
# Save a copy of original weights for comparison
# detach() must actually be CALLED: without the parentheses the dict would hold
# bound methods instead of tensors. Detaching before cloning yields an
# independent copy that is cut off from the autograd graph.
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.detach().clone()

In [None]:
# Define test function
def test():
    """Evaluate the notebook-level `net` on `test_loader` and print the results.

    Prints the overall accuracy (rounded to 3 decimals) followed by, for each
    digit 0-9, how many samples with that true label were misclassified.
    Relies on the notebook-level `net`, `test_loader`, `device` and `tqdm`.

    The network is switched to evaluation mode for the duration of the pass
    and its previous training/eval mode is restored afterwards.
    """
    correct = 0
    total = 0
    # Per-digit misclassification counters, indexed by the true label.
    wrong_counts = torch.zeros(10, dtype=torch.long)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(test_loader, desc='Testing'):
                x = x.to(device)
                y = y.to(device)
                predictions = net(x.view(-1, 784)).argmax(dim=1)
                missed = predictions != y
                total += y.size(0)
                correct += int((~missed).sum())
                # Count the true labels of the misclassified samples in one go
                # instead of looping over the batch in Python.
                wrong_counts += torch.bincount(y[missed], minlength=10).cpu()
    finally:
        net.train(was_training)

    # Avoid a ZeroDivisionError when the loader yields no samples.
    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {round(accuracy, 3)}')
    for i, count in enumerate(wrong_counts.tolist()):
        print(f'wrong counts for the digit {i}: {count}')

# Evaluate the freshly trained base model to get per-digit error counts.
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 350.56it/s]

Accuracy: 0.954
wrong counts for the digit 0: 25
wrong counts for the digit 1: 16
wrong counts for the digit 2: 107
wrong counts for the digit 3: 38
wrong counts for the digit 4: 41
wrong counts for the digit 5: 35
wrong counts for the digit 6: 19
wrong counts for the digit 7: 62
wrong counts for the digit 8: 84
wrong counts for the digit 9: 31





As shown above, after one epoch of training the MNIST network performs worst on the digit 2 (107 misclassified test samples). We will therefore use LoRA to fine-tune the model on the digit 2 and see how its performance on that digit changes.

In [None]:
# Find the total number of trainable parameters in the network
# (weights plus biases of the three linear layers).
linear_layers = [net.linear1, net.linear2, net.linear3]
for index, layer in enumerate(linear_layers):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
total_parameters_original = sum(
    layer.weight.nelement() + layer.bias.nelement() for layer in linear_layers
)
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


We introduce the LoRA layers into the PyTorch model using PyTorch's parametrization mechanism. Refer to the PyTorch tutorial for more details: https://pytorch.org/tutorials/intermediate/parametrizations.html

In [None]:
class LoRAParametrization(nn.Module):
    """Parametrization that adds a scaled low-rank product to a weight tensor.

    Holds two trainable matrices:
      * `lora_A` with shape (rank, features_out), initialised from N(0, 1)
      * `lora_B` with shape (features_in, rank), initialised to zero

    so `lora_B @ lora_A` has shape (features_in, features_out). That product
    is viewed as the shape of the incoming weight, so `features_in` and
    `features_out` must be chosen to be compatible with that weight.

    Args:
        features_in: number of rows of `lora_B`.
        features_out: number of columns of `lora_A`.
        rank: inner dimension shared by the two matrices.
        alpha: numerator of the scaling factor `alpha / rank`.
        device: device the two matrices are created on.
    """

    def __init__(self,features_in,features_out,rank,alpha=1,device='cpu'):
        super().__init__()
        # The LoRA authors recommend a random Gaussian initialisation for A and
        # zero for B, so that the update B @ A is zero when training starts.
        a_init = torch.zeros((rank, features_out)).to(device)
        b_init = torch.zeros((features_in, rank)).to(device)
        self.lora_A = nn.Parameter(a_init)
        self.lora_B = nn.Parameter(b_init)
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # From the paper: when optimising with Adam, tuning alpha is roughly the
        # same as tuning the learning rate if the initialisation is scaled
        # appropriately, so alpha is set to the first rank tried and not tuned.
        # The alpha / rank scaling reduces the need to retune hyperparameters
        # when the rank changes.
        self.scale = alpha / rank

        # Toggle checked in forward(): when False the weight passes through untouched.
        self.enabled = True

    def forward(self,original_weights):
        """Return `original_weights` plus the scaled low-rank update (or unchanged if disabled)."""
        if not self.enabled:
            return original_weights
        low_rank_update = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [None]:
# Transform linear layer to parameterized layer
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer,device,rank=1,lora_alpha=1):
    """Build a LoRAParametrization sized for `layer.weight`.

    The two dimensions of `layer.weight.shape` are forwarded, in order, as the
    `features_in` and `features_out` arguments of LoRAParametrization. For the
    `nn.Linear` layers in this notebook the weight is stored as
    (out_features, in_features), e.g. Linear(784, 1000) has a (1000, 784)
    weight, so the first value passed is the layer's output size.

    Args:
        layer: module exposing a 2-D `weight`.
        device: device on which the LoRA matrices are created.
        rank: rank of the low-rank update.
        lora_alpha: scaling numerator forwarded as `alpha`.
    """
    weight_rows, weight_cols = layer.weight.shape
    return LoRAParametrization(
        weight_rows,
        weight_cols,
        rank=rank,
        alpha=lora_alpha,
        device=device,
    )

# Attach one LoRA parametrization to the weight of each linear layer.
for layer in (net.linear1, net.linear2, net.linear3):
    # Re-running this cell must not stack a second LoRA adapter on a weight
    # that already has one, so only register when the weight is not yet parametrized.
    if not parametrize.is_parametrized(layer, "weight"):
        parametrize.register_parametrization(
            layer, "weight", linear_layer_parameterization(layer, device)
        )

# Function to enable and disable the LoRA layers
def enable_disable_lora(enabled=True):
    """Set the `enabled` flag of the first weight parametrization on each linear layer of `net`.

    Args:
        enabled: True to apply the LoRA update, False to use the original weights.
    """
    for linear in (net.linear1, net.linear2, net.linear3):
        lora_adapter = linear.parametrizations["weight"][0]
        lora_adapter.enabled = enabled

In [None]:
# Compare the total number of trainable parameters
# (original weights/biases versus the matrices added by LoRA).
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora_adapter = layer.parametrizations["weight"][0]
    lora_count = lora_adapter.lora_A.nelement() + lora_adapter.lora_B.nelement()
    base_count = layer.weight.nelement() + layer.bias.nelement()
    total_parameters_lora += lora_count
    total_parameters_non_lora += base_count
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# Sanity check: adding LoRA must not change the size of the original network.
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Relative size of the LoRA matrices, as a percentage of the original parameter count.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [None]:
# Display the network: each linear layer is now a ParametrizedLinear whose
# weight carries a LoRAParametrization.
net

RichBoyNet(
  (linear1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (relu): ReLU()
)

In [None]:
# Freeze the original weights: every parameter whose name does not contain
# 'lora' (the original weights and the biases) stops receiving gradient
# updates, so only the LoRA matrices are trained below.
for name,param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, keeping only the digit 2
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE: despite its name, `exclude_indices` is a boolean mask of the samples
# to KEEP (those labelled 2); it is used below to select, not to drop.
exclude_indices = mnist_trainset.targets == 2
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 2 for one full epoch over the
# filtered set (no iteration limit is passed), hoping that it improves the
# performance on the digit 2
train(train_loader, net, epochs=1)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1: 100%|██████████| 596/596 [00:05<00:00, 107.84it/s, avg_loss=0.0551]


In [23]:
# Test with LoRA enabled: the linear layers use their original weights plus
# the trained low-rank update.
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 195.07it/s]

Accuracy: 0.366
wrong counts for the digit 0: 738
wrong counts for the digit 1: 951
wrong counts for the digit 2: 2
wrong counts for the digit 3: 705
wrong counts for the digit 4: 869
wrong counts for the digit 5: 772
wrong counts for the digit 6: 380
wrong counts for the digit 7: 934
wrong counts for the digit 8: 866
wrong counts for the digit 9: 127





In [None]:
# Test with LoRA disabled: the parametrization returns the original (frozen)
# weights unchanged, so this evaluates the base model again.
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 434.79it/s]

Accuracy: 0.954
wrong counts for the digit 0: 25
wrong counts for the digit 1: 16
wrong counts for the digit 2: 107
wrong counts for the digit 3: 38
wrong counts for the digit 4: 41
wrong counts for the digit 5: 35
wrong counts for the digit 6: 19
wrong counts for the digit 7: 62
wrong counts for the digit 8: 84
wrong counts for the digit 9: 31



