# Single Value Decomposition : Idea behind LoRA

In [9]:
import torch
import numpy as np
_ = torch.manual_seed(0)

In [10]:
d, k = 10, 10

# This matrix is rank-deficient 
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)
print(W)
# the matrix is 10 by 10 but the rank of this matrix is still 2

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

#### Evaluating the rank of matrix 

In [11]:
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of matrix is {W_rank}')

Rank of matrix is 2


##### Calaculate SVD decompostion of this matrix 

In [17]:
# Perform SVD
U, S, V = torch.svd(W)

# For rank-r factorization, kepp only first r singular values

U_r = U[:,:W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:,:W_rank].t()

print(f"S_r is {S_r.shape}")
print(f"V_r is {V_r.shape}")
print(f"U_r is {U_r.shape}")


B = U_r @ S_r
A = V_r

print(f'Shape of B : {B.shape}')
print(f'Shape of A : {A.shape}')

S_r is torch.Size([2, 2])
V_r is torch.Size([2, 10])
U_r is torch.Size([10, 2])
Shape of B : torch.Size([10, 2])
Shape of A : torch.Size([2, 10])


##### Given the sampe input, check the output using the original W  matrix and the matrics resulting from the decompsotion


In [18]:
bias = torch.randn(d)
x = torch.randn(d)

# y = Wx + b
y = W @ x + bias

# y' = CRx = b
y_prime = (B @ A)@ x + bias

print("Original y using W \n", y)
print()
print("Computed y using BA \n", y)


Original y using W 
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1639e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])

Computed y using BA 
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1639e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])


In [21]:
print(f"Total Params in W : {W.nelement()}")
print(f"Total Params in B and A : {B.nelement() + A.nelement()}")

Total Params in W : 100
Total Params in B and A : 40


# LoRA Impplementation

In [22]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt 
from tqdm import tqdm

In [23]:
_ = torch.manual_seed(0)

In [25]:
print(torch.backends.mps.is_available())

True


In [26]:
import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 3572544.77it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 5915695.36it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 1882655.42it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2545841.08it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [34]:
class RichBoyNet(nn.Module):

    def __init__(self, hidden_size_1 = 1000, hidden_size_2 = 2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2,10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
net = RichBoyNet().to(device)


In [40]:
def train(train_loader, ne, epochs=5, total_iterations_limit=None):
    cross_entropy = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x,y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1,28*28))
            loss = cross_entropy(output,y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
    
    print(x.shape)

train(train_loader,net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:58<00:00, 102.02it/s, loss=0.125]

torch.Size([10, 1, 28, 28])





### Keep copy of the original weights

In [41]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()
    

In [42]:
original_weights

{'linear1.weight': tensor([[ 0.0491,  0.0412,  0.0896,  ...,  0.1083,  0.0500,  0.1037],
         [ 0.0276,  0.0295,  0.0413,  ..., -0.0105, -0.0150, -0.0233],
         [ 0.0405, -0.0032,  0.0142,  ...,  0.0140, -0.0163,  0.0253],
         ...,
         [ 0.0337,  0.0420,  0.0775,  ...,  0.0680,  0.0738,  0.0745],
         [ 0.1268,  0.0771,  0.0724,  ...,  0.1025,  0.0667,  0.0914],
         [ 0.0077,  0.0371,  0.0322,  ...,  0.0069, -0.0023,  0.0207]],
        device='mps:0'),
 'linear1.bias': tensor([-4.9370e-02, -1.7618e-02,  7.2536e-03, -4.5131e-03, -5.8461e-02,
         -3.1394e-02, -9.1590e-02, -1.0883e-01, -1.0716e-01, -6.1761e-02,
         -7.0232e-02, -1.5700e-02, -4.8010e-03, -9.8721e-02, -1.0534e-01,
         -7.4751e-02, -5.3460e-02, -9.9413e-03, -1.2278e-02, -9.6957e-02,
         -8.8563e-02, -2.4805e-02, -5.2471e-02, -4.4481e-02, -6.3462e-02,
         -8.7155e-02, -9.0936e-02,  8.1895e-03, -4.9283e-02, -4.5271e-02,
         -4.2645e-02,  1.4025e-02, -2.6896e-02, -1.5687e

In [43]:
def test():
    correct = 0
    total  = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x,y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1,784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i} : {wrong_counts[i]}')

    
test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 164.37it/s]

Accuracy: 0.968
wrong counts for the digit 0 : 13
wrong counts for the digit 1 : 18
wrong counts for the digit 2 : 44
wrong counts for the digit 3 : 22
wrong counts for the digit 4 : 24
wrong counts for the digit 5 : 43
wrong counts for the digit 6 : 16
wrong counts for the digit 7 : 32
wrong counts for the digit 8 : 47
wrong counts for the digit 9 : 60





In [60]:
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index + 1}: W {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original}')

Layer 1: W torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2807010


In [50]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()

        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in,rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0,std=1)


        self.scale = alpha / rank
        self.enabled = True


    def forward(self, original_weights):
        if self.enabled:
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

In [71]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_aplha=1):

    features_in, features_out = layer.weight.shape
    
    return LoRAParametrization(
        features_in, features_out, rank = rank, alpha=lora_aplha, device = device
    )

parametrize.register_parametrization(
    net.linear1, 'weight', linear_layer_parameterization(net.linear1, device)
)

parametrize.register_parametrization(
    net.linear2, 'weight', linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, 'weight', linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1,net.linear2,net.linear3]:
        layer.parametrizations['weight'][0].enabled = enabled

In [56]:
total_parameters_lora = 0 
total_parameters_non_lora = 0 
for index, layer in enumerate([net.linear1,net.linear2,net.linear3]):
    total_parameters_lora += layer.parametrizations['weight'][0].lora_A.nelement() + layer.parametrizations['weight'][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index +1 }: W: {layer.weight.shape} + B " {layer.bias.shape} + Lora_A : {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B : {layer.parametrizations["weight"][0].lora_B.shape}')
    

print(f'Total number of params (original): {total_parameters_non_lora}')
print(f'Total number of params (original + LoRA): {total_parameters_non_lora + total_parameters_lora}')
print(f'Params introduced by LoRA: {total_parameters_lora}')
params_increment = (total_parameters_lora/ total_parameters_non_lora) * 100
print(f'Params Increment :  {params_increment:.3f}')


Layer 1: W: torch.Size([1000, 784]) + B " torch.Size([1000]) + Lora_A : torch.Size([1, 784]) + Lora_B : torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B " torch.Size([2000]) + Lora_A : torch.Size([1, 1000]) + Lora_B : torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B " torch.Size([10]) + Lora_A : torch.Size([1, 2000]) + Lora_B : torch.Size([10, 1])
Total number of params (original): 2807010
Total number of params (original + LoRA): 2813804
Params introduced by LoRA: 6794
Params Increment :  0.242


In [59]:
[i[0] for i in net.named_parameters()]

['linear1.bias',
 'linear1.parametrizations.weight.original',
 'linear1.parametrizations.weight.0.lora_A',
 'linear1.parametrizations.weight.0.lora_B',
 'linear2.bias',
 'linear2.parametrizations.weight.original',
 'linear2.parametrizations.weight.0.lora_A',
 'linear2.parametrizations.weight.0.lora_B',
 'linear3.bias',
 'linear3.parametrizations.weight.original',
 'linear3.parametrizations.weight.0.lora_A',
 'linear3.parametrizations.weight.0.lora_B']

### Freezing the Non-Lora params

In [65]:
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

mnist_trainset = datasets.MNIST(root='./data', train=True,download=True,transform = transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

train(train_loader,net,epochs=1, total_iterations_limit=1000)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 1:  60%|█████▉    | 595/1000 [00:07<00:05, 80.67it/s, loss=0.000315]

torch.Size([9, 1, 28, 28])





In [None]:
# Verifying that forzen params are still unchanged 

assert torch.all(net.linear1.parametrizations.weight.orginal == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.orginal == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.orginal == original_weights['linear3.weight'])


enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the 'forward' function of our LoRA paramtrization
# The original weights have been moved to net.linear1.parametrization.weight.original 


In [73]:
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:10<00:00, 92.81it/s]

Accuracy: 0.968
wrong counts for the digit 0 : 13
wrong counts for the digit 1 : 18
wrong counts for the digit 2 : 44
wrong counts for the digit 3 : 22
wrong counts for the digit 4 : 24
wrong counts for the digit 5 : 43
wrong counts for the digit 6 : 16
wrong counts for the digit 7 : 32
wrong counts for the digit 8 : 47
wrong counts for the digit 9 : 60





In [74]:
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:15<00:00, 63.41it/s]

Accuracy: 0.322
wrong counts for the digit 0 : 583
wrong counts for the digit 1 : 917
wrong counts for the digit 2 : 634
wrong counts for the digit 3 : 771
wrong counts for the digit 4 : 918
wrong counts for the digit 5 : 825
wrong counts for the digit 6 : 166
wrong counts for the digit 7 : 992
wrong counts for the digit 8 : 973
wrong counts for the digit 9 : 0



