In [None]:
#lets see basic principle of rank decomposition using SVD

In [2]:
import torch
import numpy as np
_ = torch.manual_seed(0)

In [3]:
d, k = 10, 10

# This way we can generate a rank-deficient matrix
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)
print(W)

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

In [4]:
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')


Rank of W: 2


In [5]:
# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)

# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
print(U_r.shape)
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # Transpose V_r to get the right dimensions

# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.Size([10, 2])
Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [6]:
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)

# Compute y = Wx + bias
y = W @ x + bias
# Compute y' = (B*A)x + bias
y_prime = (B @ A) @ x + bias

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_prime)

Original y using W:
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1639e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])

y' computed using BA:
 tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1638e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])


In [7]:
print("Total parameters of W: ", W.nelement())
print("Total parameters of B and A: ", B.nelement() + A.nelement())

Total parameters of W:  100
Total parameters of B and A:  40


In [8]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [9]:
# Make torch deterministic
_ = torch.manual_seed(0)

Load MNIST datset

In [10]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [02:35<00:00, 63744.72it/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 134691.74it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 571821.85it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1114717.89it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [38]:
#create model 
class Ubaidsnet(nn.Module):
    def __init__(self,hidden_size1=800,hidden_size_2=1500,hidden_size3=2500):
        super(Ubaidsnet,self).__init__()
        self.layer1=nn.Linear(28*28,hidden_size1)
        self.layer2=nn.Linear(hidden_size1,hidden_size_2)
        self.layer3=nn.Linear(hidden_size_2,10)
        self.relu=nn.ReLU()
    def forward(self,img):
        x=img.view(-1,28*28)
        x=self.relu(self.layer1(x))
        x=self.relu(self.layer2(x))
        x=self.layer3(x)
        return x

model=Ubaidsnet().to(device) 
print(list(model.parameters()))  

[Parameter containing:
tensor([[-0.0274, -0.0074,  0.0302,  ...,  0.0198,  0.0153, -0.0281],
        [-0.0025, -0.0151,  0.0252,  ...,  0.0207, -0.0247, -0.0227],
        [ 0.0301,  0.0009, -0.0177,  ...,  0.0029, -0.0068,  0.0271],
        ...,
        [-0.0193,  0.0080, -0.0290,  ..., -0.0092, -0.0289, -0.0240],
        [-0.0012,  0.0114,  0.0177,  ..., -0.0343,  0.0165, -0.0319],
        [-0.0002, -0.0023,  0.0319,  ...,  0.0174,  0.0070,  0.0092]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.1759e-02,  2.2536e-02, -1.5426e-02, -2.9577e-02,  2.8823e-02,
        -1.7362e-02, -1.3010e-02,  1.7051e-02, -1.6580e-03,  3.2968e-02,
         2.2185e-02,  1.3882e-03, -2.7509e-02,  2.2329e-02, -3.5247e-02,
         1.8279e-02,  2.3810e-04,  2.0273e-02,  1.1014e-03,  2.6392e-02,
        -1.7021e-02,  9.5358e-03,  3.2214e-02,  1.4676e-02,  2.3283e-02,
         5.2577e-03, -6.0180e-03,  2.5439e-02, -3.1110e-02,  3.6793e-03,
        -1.9304e-02,  1.2575e-02,  1.0

In [39]:
#traing 
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, model, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:18<00:00, 322.92it/s, loss=0.234]


Lets  make our own classifier to predict digits on mNIST dataset

Now Lets See LORA 

Lets save the origionalweights

In [41]:
origional_weights={}
for name,param in model.named_parameters():
    origional_weights[name]=param.clone().detach()

In [67]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()  #c;leary digt 2 has been wrongly classfiofeied lets fine tine only for digit 2 class 

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 429.89it/s]

Accuracy: 0.963
wrong counts for the digit 0: 23
wrong counts for the digit 1: 12
wrong counts for the digit 2: 64
wrong counts for the digit 3: 37
wrong counts for the digit 4: 42
wrong counts for the digit 5: 37
wrong counts for the digit 6: 28
wrong counts for the digit 7: 54
wrong counts for the digit 8: 29
wrong counts for the digit 9: 49





Let's visualize how many parameters are in the original network, before introducing the LoRA matrices.

In [46]:
total_param_org=0
for index,layer in enumerate([model.layer1,model.layer2,model.layer3]):
    total_param_org+=layer.weight.nelement()+layer.bias.nelement()
    print(f'layer{index+1}: W:{layer.weight.shape}+B:{layer.bias.shape}')
print(f'Total pmts:{total_param_org/10**6:,}M')


layer1: W:torch.Size([800, 784])+B:torch.Size([800])
layer2: W:torch.Size([1500, 800])+B:torch.Size([1500])
layer3: W:torch.Size([10, 1500])+B:torch.Size([10])
Total pmts:1.84451M


In [57]:
class LoraReparametrization(nn.Module):
    def __init__(self,features_in,features_out,rank=1,alpha=1,device='cpu'):
        super().__init__()
        #we use random intialization for A and zero for B, so ∆W = BA is zero at the beginning 
        self.loraA=nn.Parameter(torch.zeros((rank,features_in))).to(device)
        self.loraB=nn.Parameter(torch.zeros((features_out,rank))).to(device)
        nn.init.normal_(self.loraA,mean=0,std=1)
        
       # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale=alpha/rank
        self.enabled=True
    
    def forward(self,orignal_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return orignal_weights+torch.matmul(self.loraB,self.loraA).view(orignal_weights.shape)*self.scale# 2we alter the orginal weights by using B and A
        else:
            return  orignal_weights
        



Add the parameterization to our network.

In [98]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    
    features_in, features_out = layer.weight.shape
    return LoraReparametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    model.layer1, "weight", linear_layer_parameterization(model.layer1, device)#to replace the weights with the W+delw weights
)
parametrize.register_parametrization(
    model.layer2, "weight", linear_layer_parameterization(model.layer2, device)
)
parametrize.register_parametrization(
    model.layer3, "weight", linear_layer_parameterization(model.layer3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [model.layer1, model.layer2, model.layer3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [99]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.layer1, model.layer2, model.layer3]):
    total_parameters_lora += layer.parametrizations["weight"][0].loraA.nelement() + layer.parametrizations["weight"][0].loraB.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].loraA.shape} + Lora_B: {layer.parametrizations["weight"][0].loraB.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_param_org
print(f'Total number of parameters (original): {total_parameters_non_lora/10**6:,}')
print(f'Total number of parameters (original + LoRA): {(total_parameters_lora + total_parameters_non_lora)/10**6:,}M')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')#arameters introduced by LoRA: 5,394 only thes parameters would be trained by the lORA other wont be
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([800, 784]) + B: torch.Size([800]) + Lora_A: torch.Size([1, 800]) + Lora_B: torch.Size([784, 1])
Layer 2: W: torch.Size([1500, 800]) + B: torch.Size([1500]) + Lora_A: torch.Size([1, 1500]) + Lora_B: torch.Size([800, 1])
Layer 3: W: torch.Size([10, 1500]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 10]) + Lora_B: torch.Size([1500, 1])
Total number of parameters (original): 1.84451
Total number of parameters (original + LoRA): 1.849904M
Parameters introduced by LoRA: 5,394
Parameters incremment: 0.292%


Freeze all the parameters of the original network and only fine tuning the ones introduced by LoRA. Then fine-tune the model on the digit 2 and only for 100 batches.

In [100]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 2
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter layer1.bias
Freezing non-LoRA parameter layer1.parametrizations.weight.original
Freezing non-LoRA parameter layer2.bias
Freezing non-LoRA parameter layer2.parametrizations.weight.original
Freezing non-LoRA parameter layer3.bias
Freezing non-LoRA parameter layer3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 242.60it/s, loss=0.184]


erify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [None]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.layer1.parametrizations.weight.original == origional_weights['layer1.weight'])
assert torch.all(model.layer2.parametrizations.weight.original == origional_weights['layer2.weight'])
assert torch.all(model.layer3.parametrizations.weight.original == origional_weights['layer3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.layer1.weight, model.layer1.parametrizations.weight.original + (model.layer1.parametrizations.weight[0].loraB @ model.layer1.parametrizations.weight[0].loraA) * model.layer1.parametrizations.weight[0].scale)#w=w+BA*scale

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.layer1.weight, origional_weights['linear1.weight'])

In [102]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 398.53it/s]

Accuracy: 0.963
wrong counts for the digit 0: 23
wrong counts for the digit 1: 12
wrong counts for the digit 2: 64
wrong counts for the digit 3: 37
wrong counts for the digit 4: 42
wrong counts for the digit 5: 37
wrong counts for the digit 6: 28
wrong counts for the digit 7: 54
wrong counts for the digit 8: 29
wrong counts for the digit 9: 49





In [103]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 425.48it/s]

Accuracy: 0.963
wrong counts for the digit 0: 23
wrong counts for the digit 1: 12
wrong counts for the digit 2: 64
wrong counts for the digit 3: 37
wrong counts for the digit 4: 42
wrong counts for the digit 5: 37
wrong counts for the digit 6: 28
wrong counts for the digit 7: 54
wrong counts for the digit 8: 29
wrong counts for the digit 9: 49



