### Singular Value Decomposition

In [1]:
import torch
import numpy as np
_ = torch.manual_seed(0)

In [3]:
# create a 10x10 matrix with rank  = 2

d, k = 10, 10
w_rank = 2

W = torch.randn(d, w_rank) @ torch.randn(w_rank, k)
print(W)


tensor([[-1.7535e-01, -4.1425e-01, -1.6431e-01,  5.8721e-01, -7.5694e-02,
          9.2418e-01, -1.0040e-01,  1.7763e-01,  2.7832e-01, -3.5177e-01],
        [ 5.0407e-02,  8.5691e-01, -5.6317e-01, -4.2222e-01, -2.0700e+00,
         -2.0448e+00, -1.1324e-01,  9.9658e-02, -1.4092e-01,  1.1584e+00],
        [ 3.1282e-01,  8.3731e-03,  8.9758e-01, -7.9662e-01,  2.2064e+00,
          1.1310e-01,  3.1982e-01, -4.6613e-01, -4.3621e-01, -4.1948e-01],
        [ 2.6040e-01,  1.3233e+00, -3.4182e-01, -1.1152e+00, -1.8951e+00,
         -3.0799e+00,  1.2716e-02, -1.1913e-01, -4.7177e-01,  1.5371e+00],
        [ 8.6263e-01,  9.9178e-02,  2.4122e+00, -2.2229e+00,  5.8687e+00,
          1.2840e-01,  8.6730e-01, -1.2699e+00, -1.2092e+00, -1.0477e+00],
        [ 1.3888e-01, -1.7935e-01,  5.4995e-01, -2.9080e-01,  1.4986e+00,
          4.9165e-01,  1.7725e-01, -2.4434e-01, -1.7855e-01, -4.4858e-01],
        [-3.0799e-03, -1.3581e+00,  1.1147e+00,  4.7428e-01,  3.8284e+00,
          3.2736e+00,  2.5840e-0

In [4]:
# evaluate the rank of the matrix
w_rank = np.linalg.matrix_rank(W)
print("Rank of W : ", w_rank)

Rank of W :  2


### Calculate the SVD of the W matrix

In [5]:
U, S, V = torch.svd(W)

# for the rank factorization keep only the first 3 singular vlaues (and corresponding columns of U and V)
U_r = U[:, :w_rank]
S_r = torch.diag(S[:w_rank])
V_r = V[:, :w_rank].t()

# compute c = U_r, * S_r and R = V_r
B = U_r @ S_r
A = V_r

print("shape of B : ", B.shape)
print("shape of A : ", A.shape)

shape of B :  torch.Size([10, 2])
shape of A :  torch.Size([2, 10])


In [7]:
# generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)

# compute y = Wx + b
y = W @ x + bias

# compute y' = CRx + b
y_prime = (B @ A) @ x + bias

print("Original y using w : \n", y)
print("")
print("y computed using SVD : \n", y_prime)

Original y using w : 
 tensor([ 0.0391,  0.8319, -1.9526,  1.0540, -6.0926, -0.7195, -2.7809,  0.6213,
        -3.8685, -1.9149])

y computed using SVD : 
 tensor([ 0.0391,  0.8319, -1.9526,  1.0540, -6.0926, -0.7195, -2.7809,  0.6213,
        -3.8685, -1.9149])


In [8]:
print("Total number of parameters of W : ", W.nelement())
print("total number of parameters for B and A : ", B.nelement() + A.nelement())

Total number of parameters of W :  100
total number of parameters for B and A :  40


# LoRA implementation using PyTorch

In [13]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

#### Make the model deterministic

In [14]:
_ = torch.manual_seed(0)

#### Network to classify MNIST data

In [17]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

# load the MNIST data
mnist_trainset = datasets.MNIST(root = "./data", train=True, download=True, transform=transform)
# create a dataloade for training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)


# load the mnist test set
mnist_testset = datasets.MNIST(root = "./data", train = False, download=True, transform = transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("mps")

In [24]:
# Neural network to classify the digits

class Net(nn.Module):
    def __init__(self, hidden_size1 = 1000, hidden_size2 = 2000):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
net = Net().to(device)


In [25]:
# train the network only for 1 epoch to similate a complete general pre training on the data
def train(train_loader, net, epochs=5, total_iterations_limit = None):
    cross_e1 = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0
    
    # define the trianing loop
    for epoch in range(epochs):
        # set the model in training mode
        net.train()
        # track the total loss
        loss_sum = 0
        # track the number of iterations for the current epoch
        num_terations = 0

        # init the data iterator 
        data_iterator = tqdm(train_loader, desc = f"Epoch {epoch+1}")

        # set the total_iteration_limit to the data_iterator
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        
        # iterate through the batch in the data iterator
        for data in data_iterator:
            # track the stats
            num_terations += 1
            total_iterations += 1

            # retrive the input and output from the data iterator
            x, y = data
            
            # assing the tensors to the gpu
            x = x.to(device)
            y = y.to(device)    

            # reset the optimizer
            optimizer.zero_grad()
            
            # get the predictions from the neural network
            output = net(x.view(-1, 28*28))

            # calculate the loss
            loss = cross_e1(output, y)

            loss_sum += loss.item()

            # track stats
            avg_loss = loss_sum / num_terations
            data_iterator.set_postfix(loss=avg_loss)

            # backward pass
            loss.backward()

            # update params
            optimizer.step()

            if total_iterations_limit is not None and total_iterations_limit >= total_iterations:
                return
            
train(train_loader, net, epochs=1)


        



Epoch 1: 100%|██████████| 6000/6000 [00:55<00:00, 108.81it/s, loss=0.242]


### keep a copy of the original weigts so that later we can prove that fine tuning with lora doesn't alter the original weights


In [28]:
original_weight = {}
for name, param in net.named_parameters():
    original_weight[name] = param.clone().detach()

### The performance of the pretrained network 

In [31]:
def test():
    correct = 0
    total = 0

    wrong_count = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc = "Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)

            output = net(x.view(-1, 28*28))

            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_count[y[idx]] += 1
                total += 1
    print(f"Accuracy : {round(correct / total, 3)}")
    for i in range(len(wrong_count)):
        print(f"wrong counts for the digit {i} : {wrong_count[i]}")
test()


Testing: 100%|██████████| 1000/1000 [00:06<00:00, 151.06it/s]

Accuracy : 0.961
wrong counts fopr the digit 0 : 28
wrong counts fopr the digit 1 : 16
wrong counts fopr the digit 2 : 88
wrong counts fopr the digit 3 : 37
wrong counts fopr the digit 4 : 39
wrong counts fopr the digit 5 : 27
wrong counts fopr the digit 6 : 18
wrong counts fopr the digit 7 : 50
wrong counts fopr the digit 8 : 50
wrong counts fopr the digit 9 : 35





### Visualizing the number of parameters 

In [33]:
total_params_original = 0

for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_params_original += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer {index+1}: W {layer.weight.shape} + B {layer.bias.shape}")
print(f"Total number of parameters : {total_params_original:,}")

Layer 1: W torch.Size([1000, 784]) + B torch.Size([1000])
Layer 2: W torch.Size([2000, 1000]) + B torch.Size([2000])
Layer 3: W torch.Size([10, 2000]) + B torch.Size([10])
Total number of parameters : 2,807,010


# Introducing LoRA

In [40]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device="mps"):
        super().__init__()

        self.lora_a = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_b = nn.Parameter(torch.zeros((features_in, rank)).to(device))


        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # return W + (B*A) * scale
            return original_weights + torch.matmul(self.lora_b, self.lora_a).view(original_weights.shape) * self.scale
        
        else:
            return original_weights

### Add the parameters

In [47]:
import torch.nn.utils.parametrize as parameterize


def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape

    return LoRAParametrization(
        features_in, features_out, rank, lora_alpha, device
    )

parameterize.register_parametrization(
    net.linear1, "weight", linear_layer_parametrization(net.linear1, device)
)

parameterize.register_parametrization(
    net.linear2, "weight", linear_layer_parametrization(net.linear2, device)
)

parameterize.register_parametrization(
    net.linear3, "weight", linear_layer_parametrization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [43]:
total_params_lora = 0
total_params_non_lora = 0

for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_params_lora += layer.parametrizations["weight"][0].lora_a.nelement() + layer.parametrizations["weight"][0].lora_b.nelement()
    total_params_non_lora += layer.weight.nelement() + layer.bias.nelement()

    print(
        f"Layer {index + 1} : W - {layer.weight.shape} + B - {layer.bias.shape} + Lora_A {layer.parametrizations['weight'][0].lora_a.shape} + Lora_B {layer.parametrizations['weight'][0].lora_b.shape}"
    )

assert total_params_non_lora == total_params_original
print(f"Total params (original) : {total_params_non_lora:,}")
print(f"Total params (original + LoRA) : {total_params_non_lora + total_params_lora:,}")
print(f"Params introduced by LoRA : {total_params_lora:,}")
params_inc = (total_params_lora / total_params_non_lora) * 100
print(f"Prams increment : {params_inc:.3f}%")

Layer 1 : W - torch.Size([1000, 784]) + B - torch.Size([1000]) + Lora_A torch.Size([1, 784]) + Lora_B torch.Size([1000, 1])
Layer 2 : W - torch.Size([2000, 1000]) + B - torch.Size([2000]) + Lora_A torch.Size([1, 1000]) + Lora_B torch.Size([2000, 1])
Layer 3 : W - torch.Size([10, 2000]) + B - torch.Size([10]) + Lora_A torch.Size([1, 2000]) + Lora_B torch.Size([10, 1])
Total params (original) : 2,807,010
Total params (original + LoRA) : 2,813,804
Params introduced by LoRA : 6,794
Prams increment : 0.242%


### Freeze all the params of the original network

In [46]:
for name, param in net.named_parameters():
    if "lora" not in name:
        print(f"Freezing non-lora parameter : {name}")
        param.requires_grad = False

# keeping only the digit 2
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 2
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# create a new dataloader for the tranining
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)


train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-lora parameter : linear1.bias
Freezing non-lora parameter : linear1.parametrizations.weight.original
Freezing non-lora parameter : linear2.bias
Freezing non-lora parameter : linear2.parametrizations.weight.original
Freezing non-lora parameter : linear3.bias
Freezing non-lora parameter : linear3.parametrizations.weight.original


Epoch 1:   0%|          | 0/596 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/100 [00:00<?, ?it/s, loss=0.371]


### Verifying that the finetuning didnt affect the original weights

In [48]:
assert torch.all(net.linear1.parametrizations.weight.original == original_weight["linear1.weight"])
assert torch.all(net.linear2.parametrizations.weight.original == original_weight["linear2.weight"])
assert torch.all(net.linear3.parametrizations.weight.original == original_weight["linear3.weight"])

enable_disable_lora(enabled=True)

assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + net.linear1.parametrizations.weight[0].lora_b @ net.linear1.parametrizations.weight[0].lora_a)

enable_disable_lora(enabled=False)

assert torch.equal(net.linear1.weight, original_weight["linear1.weight"])

### Test the network with Lora on the new dataset

In [49]:
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:11<00:00, 85.70it/s]

Accuracy : 0.961
wrong counts fopr the digit 0 : 28
wrong counts fopr the digit 1 : 16
wrong counts fopr the digit 2 : 88
wrong counts fopr the digit 3 : 37
wrong counts fopr the digit 4 : 39
wrong counts fopr the digit 5 : 27
wrong counts fopr the digit 6 : 18
wrong counts fopr the digit 7 : 50
wrong counts fopr the digit 8 : 50
wrong counts fopr the digit 9 : 35



