[Video Link](https://www.youtube.com/watch?v=PXWYUTMt-AU)

In [1]:
import torch
import numpy as np

# SVD

In [2]:
# Create a low rank matrix
d, k = 10, 10
w_rank = 2

W = torch.rand(d, w_rank) @ torch.rand(w_rank, k)
print(f"Shape of W: {W.shape}")
print(f"rank of W: {np.linalg.matrix_rank(W)}")

Shape of W: torch.Size([10, 10])
rank of W: 2


In [3]:
# Apply SVD
U, S, V = torch.svd(W)
print(f"Shape of U: {U.shape} | Shape of S: {S.shape} | Shape of V: {V.shape}")

U_r = U[:, :w_rank]
S_r = torch.diag(S[:w_rank])
V_r = V[:, :w_rank].T

B = U_r @ S_r
A = V_r

print(f"Shape of B: {B.shape}")
print(f"Shape of A: {A.shape}")

Shape of U: torch.Size([10, 10]) | Shape of S: torch.Size([10]) | Shape of V: torch.Size([10, 10])
Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


In [4]:
# For the same input, check the results using both the original and the low rank matrices
bias = torch.rand(d)
x = torch.rand(k)

y = W @ x + bias
y_prime = B @ A @ x + bias

print(f"y: {y}")
print(f"y_prime: {y_prime}")
print(f"Difference: {torch.sum(y - y_prime)}")

print(f"Total number of params in W: {W.numel()}")
print(f"Total number of params in A & B: {A.numel() + B.numel()}")

y: tensor([2.2722, 1.7504, 2.8158, 2.1827, 2.3436, 1.9618, 1.6580, 2.9906, 1.9360,
        2.2995])
y_prime: tensor([2.2722, 1.7504, 2.8158, 2.1827, 2.3436, 1.9618, 1.6580, 2.9906, 1.9360,
        2.2995])
Difference: -4.649162292480469e-06
Total number of params in W: 100
Total number of params in A & B: 40


# LoRA

In [5]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import tqdm

seed = torch.manual_seed(0)

In [6]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

device = "mps"

In [7]:
# Create unnecessarily large model
class UnnecessarilyLargeModel(nn.Module):
    def __init__(self, hidden1 = 1000, hidden2 = 2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

model = UnnecessarilyLargeModel().to(device)

In [8]:
# train for a single epoch
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
total_loss = 0

progress_bar = tqdm.tqdm(train_loader, desc='Training')

for i, (data, target) in enumerate(progress_bar):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    progress_bar.set_postfix({'loss': total_loss/(i+1)})

Training:   0%|          | 0/938 [00:00<?, ?it/s]

Training: 100%|██████████| 938/938 [00:06<00:00, 139.27it/s, loss=0.194]


In [9]:
original_weights = {name: param.clone().detach() for name, param in model.named_parameters()}

In [10]:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

In [11]:
correct = 0
total = 0
wrong_counts = [0 for i in range(10)]

with torch.no_grad():
    for data in test_loader:
        x, y = data
        x = x.to(device)
        y = y.to(device)
        output = model(x.view(-1, 784))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct +=1
            else:
                wrong_counts[y[idx]] +=1
            total +=1
print(f'Accuracy: {round(correct/total, 3)}')
for i in range(len(wrong_counts)):
    print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

Accuracy: 0.967
wrong counts for the digit 0: 13
wrong counts for the digit 1: 15
wrong counts for the digit 2: 31
wrong counts for the digit 3: 50
wrong counts for the digit 4: 22
wrong counts for the digit 5: 14
wrong counts for the digit 6: 23
wrong counts for the digit 7: 49
wrong counts for the digit 8: 51
wrong counts for the digit 9: 58


## LoRA Parameter

In [12]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank = 1, alpha = 1, device = "mps"):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, features_out, device=device))
        self.lora_B = nn.Parameter(torch.randn(features_in, rank, device=device))
        # In the paper A is initialized with a normal distribution
        nn.init.normal_(self.lora_A, mean=0, std=1)

        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            return original_weights + (self.lora_B @ self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

In [13]:
# Add parameterization to the model (https://pytorch.org/tutorials/intermediate/parametrizations.html)
# Instead of accessing the weights, it is now to access this function 
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

torch.nn.utils.parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
torch.nn.utils.parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
torch.nn.utils.parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [14]:
# get total number of parameters (original model)
print(f"Total number of parameters in the original model: {sum(p.numel() for p in model.parameters()):,}")
# get total number of parameters (LoRA model)
total_parameters_lora = sum([layer.parametrizations['weight'][0].lora_A.numel() + layer.parametrizations['weight'][0].lora_B.numel() for layer in [model.linear1, model.linear2, model.linear3]])
print(f"Total number of parameters in the LoRA model: {total_parameters_lora:,}")
# percentage
print(f"Percentage of parameters in the LoRA model: {total_parameters_lora / sum(p.numel() for p in model.parameters()) * 100:.2f}%")

Total number of parameters in the original model: 2,813,804
Total number of parameters in the LoRA model: 6,794
Percentage of parameters in the LoRA model: 0.24%


## Train LoRA model

In [15]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
exclude_indices = train_dataset.targets == 9
train_dataset.data = train_dataset.data[exclude_indices]
train_dataset.targets = train_dataset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)

# train for a single epoch
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
total_loss = 0

progress_bar = tqdm.tqdm(train_loader, desc='Training')

for i, (data, target) in enumerate(progress_bar):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    progress_bar.set_postfix({'loss': total_loss/(i+1)})
    if i > 100:
        break

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Training:  17%|█▋        | 101/595 [00:00<00:03, 135.06it/s, loss=1.26]


In [16]:
# Check that the frozen parameters are still unchanged by the fimodeluning
assert torch.all(model.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)

In [17]:
# Test with LoRA
enable_disable_lora(enabled=True)
correct = 0
total = 0
wrong_counts = [0 for i in range(10)]

with torch.no_grad():
    for data in test_loader:
        x, y = data
        x = x.to(device)
        y = y.to(device)
        output = model(x.view(-1, 784))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct +=1
            else:
                wrong_counts[y[idx]] +=1
            total +=1
print(f'Accuracy: {round(correct/total, 3)}')
for i in range(len(wrong_counts)):
    print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

Accuracy: 0.103
wrong counts for the digit 0: 980
wrong counts for the digit 1: 1131
wrong counts for the digit 2: 1032
wrong counts for the digit 3: 1009
wrong counts for the digit 4: 981
wrong counts for the digit 5: 891
wrong counts for the digit 6: 957
wrong counts for the digit 7: 1016
wrong counts for the digit 8: 973
wrong counts for the digit 9: 0


In [18]:
# Test without LoRA
enable_disable_lora(enabled=False)
correct = 0
total = 0
wrong_counts = [0 for i in range(10)]

with torch.no_grad():
    for data in test_loader:
        x, y = data
        x = x.to(device)
        y = y.to(device)
        output = model(x.view(-1, 784))
        for idx, i in enumerate(output):
            if torch.argmax(i) == y[idx]:
                correct +=1
            else:
                wrong_counts[y[idx]] +=1
            total +=1
print(f'Accuracy: {round(correct/total, 3)}')
for i in range(len(wrong_counts)):
    print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

Accuracy: 0.967
wrong counts for the digit 0: 13
wrong counts for the digit 1: 15
wrong counts for the digit 2: 31
wrong counts for the digit 3: 50
wrong counts for the digit 4: 22
wrong counts for the digit 5: 14
wrong counts for the digit 6: 23
wrong counts for the digit 7: 49
wrong counts for the digit 8: 51
wrong counts for the digit 9: 58
