### "**LoRA Implementation**"

In [14]:
import torch 
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.utils.parametrize as parametrize

_ = torch.manual_seed(2023)

In [5]:
# prepare the dataset:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# load the mnist dataset:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_data_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_data_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# define the device:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# define the model to classify the mnist dataset:
class DemoModel(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(DemoModel, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

# initialize the model:
model = DemoModel().to(device)

In [7]:
# train the network for 1 epoch:
def train(train_data_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    
    total_iterations = 0
    
    for epoch in range(epochs):
        net.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_data_loader, desc='Epoch {}'.format(epoch+1))
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28)) 
            loss = cross_el(output, y)  
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

# train the network for 1 epoch:
train(train_data_loader, model, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:30<00:00, 195.25it/s, loss=0.243]


In [9]:
# copy of the original model:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()
print(original_weights)

{'linear1.weight': tensor([[ 0.0201,  0.0409,  0.0572,  ...,  0.0585,  0.0407,  0.0263],
        [ 0.0055,  0.0478,  0.0310,  ...,  0.0062,  0.0027,  0.0230],
        [ 0.0319,  0.0066,  0.0430,  ...,  0.0170,  0.0270,  0.0062],
        ...,
        [ 0.0599,  0.0681,  0.0363,  ...,  0.0808,  0.0345,  0.0717],
        [ 0.0146,  0.0251,  0.0674,  ...,  0.0696,  0.0343,  0.0500],
        [-0.0124,  0.0121,  0.0030,  ..., -0.0288, -0.0194,  0.0263]],
       device='cuda:0'), 'linear1.bias': tensor([-3.7218e-02, -3.5061e-02,  6.7275e-03, -2.9251e-02, -5.8652e-02,
        -5.8553e-02, -3.6413e-02, -3.0102e-02, -4.4707e-02, -2.0931e-02,
        -5.1178e-02, -3.6080e-02, -4.1683e-02, -2.2179e-02,  2.5836e-03,
         3.1643e-03, -5.0738e-02,  7.7742e-03, -4.0873e-02, -4.8966e-02,
        -3.8418e-02, -4.4142e-02, -4.9995e-03, -9.9650e-03, -4.2807e-02,
        -5.8977e-02, -1.9777e-02, -1.5686e-02, -4.5655e-02, -1.4754e-02,
        -3.4274e-02, -6.9317e-03, -2.0472e-02, -4.2401e-03, -1.7719e

In [10]:
# test the performance of the trained model:
def test():
    correct = 0
    total = 0
    
    wrong_counts = [0 for i in range(10)]
    
    with torch.no_grad():
        for data in tqdm(test_data_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    
    print(f"Accuracy: {round(correct/total, 3)}")
    for i in range(len(wrong_counts)):
        print(f"Number {i} was wrong {wrong_counts[i]} times")
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 292.36it/s]

Accuracy: 0.962
Number 0 was wrong 16 times
Number 1 was wrong 19 times
Number 2 was wrong 27 times
Number 3 was wrong 53 times
Number 4 was wrong 67 times
Number 5 was wrong 24 times
Number 6 was wrong 20 times
Number 7 was wrong 53 times
Number 8 was wrong 50 times
Number 9 was wrong 51 times





In [12]:
# inspecting the total number of parameters:
total_params_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_params_original += layer.weight.numel() + layer.bias.numel()
    print(f"Layer {index+1} has {layer.weight.shape} weights and {layer.bias.shape} biases")
print(f"Total number of parameters in the original model: {total_params_original}")

Layer 1 has torch.Size([1000, 784]) weights and torch.Size([1000]) biases
Layer 2 has torch.Size([2000, 1000]) weights and torch.Size([2000]) biases
Layer 3 has torch.Size([10, 2000]) weights and torch.Size([10]) biases
Total number of parameters in the original model: 2807010


In [15]:
# defining the LoRA parametrization:https://github.com/hkproj/pytorch-lora/blob/main/lora.ipynb
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        # Section 4.1 of the paper:
        # We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        # We then scale ∆Wx by α/r , where α is a constant in r.
        # When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
        # As a result, we simply set α to the first r we try and do not tune it.
        # This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return (
                original_weights
                + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
                * self.scale
            )
        else:
            return original_weights


# applying the LoRA parametrization to the model:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias
    # From section 4.2 of the paper:
    # We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules
    # (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    # We leave the empirical investigation of [...], and biases to a future work.

    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )


parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)


# function to enable/disable the LoRA parametrization:
def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [18]:
# inspecting the total number of parameters after applying the LoRA parametrization:
total_params_lora = 0
total_params_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_params_lora += (
        layer.parametrizations["weight"][0].lora_A.numel()
        + layer.parametrizations["weight"][0].lora_B.numel()
    )
    total_params_non_lora += layer.weight.numel() + layer.bias.numel()
    print(
        f"Layer {index+1} has {layer.parametrizations['weight'][0].lora_A.shape} LoRA parameters and {layer.weight.shape} non-LoRA parameters and {layer.bias.shape} biases and {layer.parametrizations['weight'][0].lora_B.shape} LoRA parameters"
    )

assert total_params_non_lora == total_params_original
print(f"Total number of parameters in original model: {total_params_non_lora}")
print(f"Total number of parameters in LoRA model: {total_params_lora}")
print(
    f"Total number of parameters in original + LoRA model: {total_params_non_lora + total_params_lora}"
)
parameters_increased = (total_params_lora / total_params_non_lora) * 100
print(f"Parameters increased by {round(parameters_increased, 3)}%")

Layer 1 has torch.Size([1, 784]) LoRA parameters and torch.Size([1000, 784]) non-LoRA parameters and torch.Size([1000]) biases and torch.Size([1000, 1]) LoRA parameters
Layer 2 has torch.Size([1, 1000]) LoRA parameters and torch.Size([2000, 1000]) non-LoRA parameters and torch.Size([2000]) biases and torch.Size([2000, 1]) LoRA parameters
Layer 3 has torch.Size([1, 2000]) LoRA parameters and torch.Size([10, 2000]) non-LoRA parameters and torch.Size([10]) biases and torch.Size([10, 1]) LoRA parameters
Total number of parameters in original model: 2807010
Total number of parameters in LoRA model: 6794
Total number of parameters in original + LoRA model: 2813804
Parameters increased by 0.242%


In [19]:
# freeze the non-LoRA parameters:
for name, param in model.named_parameters():
    if "lora" not in name:
        print(f"Freezing non-LoRA parameter {name}")
        param.requires_grad = False

# load the mnist dataset again keeping only the digit 4:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
include_indices = mnist_trainset.targets == 4
mnist_trainset.data = mnist_trainset.data[include_indices]
mnist_trainset.targets = mnist_trainset.targets[include_indices]
train_data_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# train the network for 1 epoch for only 100 batches:
train(train_data_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 173.50it/s, loss=0.121]


In [20]:
# assert that the frozen parameters are still unchanged by the finetuning
assert torch.all(
    model.linear1.parametrizations.weight.original == original_weights["linear1.weight"]
)
assert torch.all(
    model.linear2.parametrizations.weight.original == original_weights["linear2.weight"]
)
assert torch.all(
    model.linear3.parametrizations.weight.original == original_weights["linear3.weight"]
)

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(
    model.linear1.weight,
    model.linear1.parametrizations.weight.original
    + (
        model.linear1.parametrizations.weight[0].lora_B
        @ model.linear1.parametrizations.weight[0].lora_A
    )
    * model.linear1.parametrizations.weight[0].scale,
)

enable_disable_lora(enabled=False)
assert torch.equal(model.linear1.weight, original_weights["linear1.weight"])

In [21]:
# test when the LoRA parametrization is enabled:
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 266.21it/s]

Accuracy: 0.949
Number 0 was wrong 17 times
Number 1 was wrong 22 times
Number 2 was wrong 37 times
Number 3 was wrong 67 times
Number 4 was wrong 16 times
Number 5 was wrong 31 times
Number 6 was wrong 20 times
Number 7 was wrong 58 times
Number 8 was wrong 73 times
Number 9 was wrong 174 times





In [22]:
# test when the LoRA parametrization is disabled:
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 287.23it/s]

Accuracy: 0.962
Number 0 was wrong 16 times
Number 1 was wrong 19 times
Number 2 was wrong 27 times
Number 3 was wrong 53 times
Number 4 was wrong 67 times
Number 5 was wrong 24 times
Number 6 was wrong 20 times
Number 7 was wrong 53 times
Number 8 was wrong 50 times
Number 9 was wrong 51 times



