In [1]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

* `torch:` The main PyTorch library used for deep learning.
* `torchvision.datasets:` Contains pre-built datasets, including MNIST (a dataset of handwritten digits).
* `torchvision.transforms:` Helps with data preprocessing, like converting images to tensors and normalizing them.
* `torch.nn:` The module for building neural network architectures.
* `matplotlib.pyplot:` Plotting library, imported for visualizing results such as accuracy or loss. It is not used in the cells below.
* `tqdm:` A progress bar library for tracking the training process.

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

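This makes the run reproducible. Fixing the seed means the random weight initialization, the `DataLoader` shuffling, and other random operations produce the same values every time the notebook runs on the same setup (some GPU operations can still be nondeterministic). The `_ =` only discards the generator object that `torch.manual_seed` returns, so the cell prints nothing.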
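As a quick check of what the seed does, the sketch below re-seeds PyTorch's global generator and draws random numbers twice. The helper `draw` is hypothetical and exists only for this illustration; it runs on the CPU.

```python
import torch

def draw(seed: int) -> torch.Tensor:
    # Re-seeding the global generator restarts PyTorch's random stream
    torch.manual_seed(seed)
    return torch.randn(3)

a = draw(0)
b = draw(0)
c = draw(1)
print(torch.equal(a, b))  # True: same seed, same numbers
```

A different seed (`c`) gives a different stream, which is why every run of the notebook must use the same value to match.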

### Defining Data Transformations

We will be training a network to classify MNIST digits and then fine-tune the network on a particular digit on which it doesn't perform well.

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

* `transforms.Compose([...]):` Combines multiple transformations into one.
* `transforms.ToTensor():` Converts each image to a float tensor of shape (1, 28, 28), scaling pixel values from [0, 255] to [0.0, 1.0].
* `transforms.Normalize((0.1307,), (0.3081,)):`
  * Normalizes each pixel as (x − mean) / std, using the mean (0.1307) and standard deviation (0.3081) computed over the MNIST training set.
  * Centering and scaling the inputs this way typically speeds up convergence during training.
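To see what the two transforms do to a single pixel, here is a minimal pure-Python sketch that mimics them, assuming an 8-bit pixel value (0 to 255) as in MNIST. The helper `to_normalized` is hypothetical; in the notebook, torchvision applies the same arithmetic to whole tensors.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_normalized(pixel: int) -> float:
    """Mimic ToTensor + Normalize for one uint8 pixel value."""
    scaled = pixel / 255.0                     # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD   # Normalize: (x - mean) / std

black = to_normalized(0)    # background pixel
white = to_normalized(255)  # full-intensity stroke pixel
print(f"{black:.4f} {white:.4f}")  # -0.4242 2.8215
```

Background pixels end up slightly negative and stroke pixels around 2.8, so over the training set the inputs have roughly zero mean and unit variance.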
<br>
<br>
### Loading the MNIST Training Dataset

In [4]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

* `datasets.MNIST(...):`
  * Loads the MNIST dataset from ./data (downloads if not available).
  * `train=True:` Loads the training set (60,000 images).
  * `download=True:` Downloads the dataset if it isn't found locally.
  * `transform=transform:` Applies the previously defined transformations (tensor conversion & normalization).
<br>
<br>
### Creating a Dataloader for Training

In [5]:
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

* **DataLoader:** Wraps the dataset and yields it in batches of images and labels for training.
* `batch_size=10:` Each batch contains 10 images.
* `shuffle=True:` Reshuffles the data at the start of every epoch so the model does not see the samples in a fixed order.
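The batching can be illustrated without downloading anything, using a hypothetical stand-in dataset of 25 random tensors shaped like MNIST images. The sketch shows the batch shape the model receives and that the last batch is smaller when the dataset size is not a multiple of `batch_size`. For MNIST, 60,000 / 10 gives exactly 6,000 batches per epoch, which matches the progress bar in the training output.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 25 random "images" with random labels
images = torch.randn(25, 1, 28, 28)
labels = torch.randint(0, 10, (25,))
loader = DataLoader(TensorDataset(images, labels), batch_size=10, shuffle=True)

first_x, first_y = next(iter(loader))
print(first_x.shape)               # torch.Size([10, 1, 28, 28])

batch_sizes = [x.shape[0] for x, y in loader]
print(len(loader), batch_sizes)    # 3 [10, 10, 5]: the last batch is smaller

# For the real training set: 60,000 images / 10 per batch
print(math.ceil(60_000 / 10))      # 6000 iterations per epoch
```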
<br>
<br>
### Load the MNIST Test Dataset

In [6]:
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

* `mnist_testset`: Similar to the training set, but `train=False` ensures that the test set (10,000 images) is loaded instead.
<br>
<br>
### Creating a Dataloader for Testing

In [7]:
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

* `test_loader`: This dataloader will be used to evaluate the model after training.
* Again, **batch size = 10** (so we test the model in small groups).
* Shuffling the test set does not change the accuracy or the error counts. Evaluation loaders commonly use `shuffle=False`, but shuffling is harmless here.

In [8]:
# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Neural Network

We create a neural network to classify the digits, deliberately making it oversized to better show the power of LoRA.

In [9]:
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

* The class `RichBoyNet` inherits from `nn.Module`, meaning it is a PyTorch neural network.
* The `__init__` function initializes three fully connected layers (`nn.Linear`):
  * `self.linear1:` Takes input pixels (28×28 = 784) and maps them to 1000 neurons.
  * `self.linear2:` Maps 1000 neurons to 2000 neurons.
  * `self.linear3:` Maps 2000 neurons to 10 output classes (digits 0-9).
* **ReLU Activation Function** (`self.relu = nn.ReLU()`):
  * It introduces non-linearity, allowing the network to learn complex patterns.

#### Forward Pass
1. Flatten the image: each MNIST image is a 28×28 grayscale image, so a batch arrives with shape (B, 1, 28, 28). `img.view(-1, 28*28)` reshapes it to (B, 784) for the fully connected layers.
2. Pass through the layers: The last layer `self.linear3` outputs raw logits (scores before applying softmax).
3. No activation on the last layer: this is intentional, because PyTorch’s `CrossEntropyLoss` applies log-softmax to the logits internally.
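Both points can be checked with random tensors in place of real data (assumed shapes: a batch of 10 images and 10 classes). The sketch flattens a batch with `view` and confirms that `nn.CrossEntropyLoss` on raw logits equals log-softmax followed by negative log-likelihood (`F.nll_loss`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch = torch.randn(10, 1, 28, 28)   # a DataLoader batch: (B, C, H, W)
flat = batch.view(-1, 28 * 28)       # (B, 784), as in forward()
print(flat.shape)                    # torch.Size([10, 784])

logits = torch.randn(10, 10)         # raw scores, no softmax applied
targets = torch.randint(0, 10, (10,))

ce = nn.CrossEntropyLoss()(logits, targets)
# The same value computed in two steps: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))    # True
```

Adding an explicit softmax before `CrossEntropyLoss` would apply it twice, which is why the last layer returns raw logits.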

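#### Moving the Model to GPU (if available)
`net = RichBoyNet().to(device)` moves the model's parameters to `device`: the first CUDA GPU if one is available, otherwise the CPU. The training and test loops move each batch to the same device.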

## Training

Train the network for only 1 epoch to simulate general pre-training on the full dataset.

In [10]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    # Loss Function & Optimizer
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # Training Loop
    total_iterations = 0
    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        # Progress Bar (TQDM)
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit

        # Iterating Over the Dataset
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)

            # Forward Pass
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)

            # Tracking Loss
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)

            # Backward Pass & Optimization
            loss.backward()
            optimizer.step()

            # Early Stopping
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

# Final Training Call
train(train_loader, net, epochs=1)

Epoch 1: 100%|████████████████████████████████████████████████████████| 6000/6000 [00:18<00:00, 329.46it/s, loss=0.239]


This function trains the model for a given number of epochs (5 by default). The optional argument `total_iterations_limit` stops training after a fixed number of batches; it is used later to limit the LoRA fine-tuning to 100 batches.
<br>
<br>
#### Loss Function & Optimizer
* `CrossEntropyLoss():`
  * This is the standard loss function for multi-class classification.
  * It takes the raw logits, applies log-softmax, and returns the negative log-probability of the correct label, averaged over the batch.

* `Adam Optimizer:`
  * Adam is a widely used optimization algorithm.
  * **Learning rate = 0.001** (controls step size during weight updates).
<br>
<br>
#### Training Loop
* `total_iterations` tracks the total number of training batches.
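* `net.train()` puts the model in training mode, which changes the behavior of layers such as dropout and batch normalization. This network has neither, so the call has no effect here, but it is good practice.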
<br>
<br>
#### Progress Bar (TQDM)
TQDM wraps the dataloader in a progress bar. If `total_iterations_limit` is provided, it is used as the bar's total so the bar reflects the shortened run; the actual stopping happens in the early-stopping check.
<br>
<br>
#### Iterating Over the Dataset
`for data in data_iterator:` iterates through batches of training images. Each batch is unpacked into images `x` and labels `y`, which are moved to `device`.
<br>
<br>
#### Forward Pass
* Clears old gradients (`optimizer.zero_grad()`).
* Runs the model forward (`net(x.view(-1, 28*28))`) to get predictions.
* Computes the loss (`loss = cross_el(output, y)`).
<br>
<br>
#### Tracking Loss
Tracks the running average of the loss. Updates progress bar `set_postfix(loss=avg_loss)`.
<br>
<br>
#### Backward Pass & Optimization
* `loss.backward()` calculates gradients using backpropagation.
* `optimizer.step()` updates model parameters.
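The forward and backward steps can be seen in isolation on a toy problem. This sketch assumes a hypothetical single linear layer and one fixed random batch, trained for 20 steps with the same Adam settings as `train`. Because the model can memorize a single batch, the loss should go down.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny model and one fixed synthetic batch
model = nn.Linear(784, 10)
x = torch.randn(10, 784)
y = torch.randint(0, 10, (10,))

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

losses = []
for step in range(20):
    optimizer.zero_grad()        # clear gradients from the previous step
    loss = loss_fn(model(x), y)  # forward pass
    loss.backward()              # backpropagation fills each parameter's .grad
    optimizer.step()             # Adam updates the weights in place
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```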
<br>
<br>
#### Early Stopping
If `total_iterations_limit` is reached, training stops early.
<br>
<br>
#### Final Training Call
The model is trained for just 1 epoch to simulate pre-training.
The goal is to later fine-tune this expensive model selectively using LoRA.

## Save the Weights

Save a copy of the original model weights before applying LoRA-based fine-tuning. This ensures that we can later verify that the original weights remain unchanged.

In [11]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

* `net.named_parameters()` iterates over all model parameters (weights & biases) and stores them by name.
* `.clone().detach()` ensures that:
  * An independent copy of the weights is stored (instead of just a reference), so later in-place updates by the optimizer do not change it.
  * The copy is detached from the computation graph, so it does not require gradients or record operations.
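The difference between keeping a reference and taking a clone can be sketched on a small hypothetical layer, with an in-place `add_` standing in for an optimizer update (PyTorch optimizers modify parameters in place):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
reference = layer.weight                  # the same tensor object, not a copy
snapshot = layer.weight.clone().detach()  # independent copy, outside autograd

with torch.no_grad():
    layer.weight.add_(1.0)                # simulate an in-place optimizer update

print(torch.equal(reference, layer.weight))  # True: the reference moved too
print(torch.equal(snapshot, layer.weight))   # False: the snapshot kept the old values
print(snapshot.requires_grad)                # False
```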

## Test the Model's Performance

In [12]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 563.06it/s]

Accuracy: 0.957
wrong counts for the digit 0: 13
wrong counts for the digit 1: 18
wrong counts for the digit 2: 39
wrong counts for the digit 3: 81
wrong counts for the digit 4: 22
wrong counts for the digit 5: 22
wrong counts for the digit 6: 59
wrong counts for the digit 7: 62
wrong counts for the digit 8: 17
wrong counts for the digit 9: 98





* `correct` and `total:` Track total correct predictions and the total number of test samples.
* `wrong_counts:` A list of size 10 (for digits 0-9), storing how many times the model misclassifies each digit.
* `torch.no_grad():` Disables gradient computation (makes inference faster and saves memory).
* `tqdm(test_loader, desc='Testing'):` Loops through the test dataset with a progress bar.
* `x.view(-1, 784):`
  * Reshapes each image to a **1D vector of size 784 (28×28 pixels)**, since the fully connected layers expect flat input.
  * `-1` lets PyTorch infer the batch size automatically.

* `torch.argmax(i):` Returns the predicted digit, i.e., the index of the highest logit in the output (which is also the highest probability after softmax).

* `Comparison with y[idx]:`
  * If prediction matches the actual label → Increase correct counter.
  * If incorrect → Increase wrong_counts[y[idx]] (to track how often a digit is misclassified).
  * `Update total:` Keeps track of total predictions made.
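The per-sample Python loop in `test()` can also be expressed with batched tensor operations. This is an alternative formulation, not what the notebook runs: a minimal sketch on hand-made predictions and labels (illustrative values, not outputs of the trained model), using `argmax` along the class dimension and `torch.bincount` to tally errors per true label.

```python
import torch

labels = torch.tensor([0, 1, 4, 4, 9, 9])
predicted = torch.tensor([0, 1, 9, 4, 9, 7])
logits = torch.zeros(6, 10)
logits[torch.arange(6), predicted] = 1.0  # highest score at the predicted class

preds = logits.argmax(dim=1)              # one argmax per row, like the loop
correct = (preds == labels).sum().item()
# Count misclassifications per *true* label, as wrong_counts does
wrong_counts = torch.bincount(labels[preds != labels], minlength=10)

print(correct, len(labels))       # 4 6
print(wrong_counts.tolist())      # [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
```

Working on whole batches avoids a Python-level comparison for every test image.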
<br>
<br>


As we can see, the network performs worst on the digit 9 (98 errors). Let's fine-tune it on the digit 9.


Let's count how many parameters the original network has before introducing the LoRA matrices.

In [13]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


* `total_parameters_original = 0:` Initializes a counter for the total number of parameters in the model.
* Loop through each layer (linear1, linear2, linear3):
  * `layer.weight.nelement():` Counts the total elements in the weight matrix W.
  * `layer.bias.nelement():` Counts the total elements in the bias vector B.
  * The sum of both is added to `total_parameters_original`.
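The total can be checked by hand: an `nn.Linear(n_in, n_out)` layer has `n_in × n_out` weights plus `n_out` biases. A short arithmetic sketch in plain Python:

```python
# Layer sizes of RichBoyNet as (in_features, out_features)
layers = [(28 * 28, 1000), (1000, 2000), (2000, 10)]

def linear_params(n_in: int, n_out: int) -> int:
    # The weight matrix has n_out x n_in entries; the bias has n_out entries
    return n_in * n_out + n_out

per_layer = [linear_params(i, o) for i, o in layers]
print(per_layer)                # [785000, 2002000, 20010]
print(f"{sum(per_layer):,}")    # 2,807,010
```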
<br>
<br>

Define the LoRA parametrization as described in the paper (Hu et al., 2021, *LoRA: Low-Rank Adaptation of Large Language Models*). Details on how PyTorch parametrizations work are here: https://pytorch.org/tutorials/intermediate/parametrizations.html

In [14]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper: 
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B @ A) * scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

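* `features_in:` Number of rows of W (d). Because `nn.Linear` stores its weight as (out_features, in_features), this is actually the layer's output size.
* `features_out:` Number of columns of W (k), i.e., the layer's input size.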
* `rank (r):` Defines the size of the low-rank decomposition (r << min(d, k)).
* `alpha (α):` Scaling factor to control how much LoRA modifies the original weights.
* `device:` Moves parameters to CPU/GPU.
<br>
<br>
* `lora_A (r × k):` Projects the k-dimensional layer input down to r dimensions.
* `lora_B (d × r):` Maps those r dimensions back up to the d-dimensional output.
* Initializes lora_A with a Gaussian distribution (mean=0, std=1).
* lora_B is initialized with zeros (ensuring no initial impact on W).
<br>
<br>
* Scales the LoRA update by α/r. Following the paper, α is set to the first r tried and left untuned, so changing r later does not require retuning the learning rate. Here α = r = 1, so the scale is 1.
<br>
<br>
* `torch.matmul(self.lora_B, self.lora_A)`
  * Multiplies the two low-rank matrices to form a full-sized weight update.
  * Shape: `(d × r) × (r × k) → (d × k)`
* `.view(original_weights.shape)`
  * Reshapes the update to match the original weight matrix (the product already has this shape, so this acts as a safeguard).
* Multiplies by `self.scale`, which sets the magnitude of the update and therefore acts like adjusting its effective learning rate.
<br>
<br>

* When enabled, returns the modified weight matrix `W_new = W + (B @ A) * scale`, where `@` is matrix multiplication; when disabled, returns W unchanged.
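To make the low-rank idea concrete, the sketch below builds ΔW = BA for the shape of `linear2.weight` (d = 2000, k = 1000) with rank r = 1. Both factors are filled randomly for illustration only; in the notebook B starts at zero, so ΔW starts at zero. The product has the full weight shape but rank at most r, and the two factors need only r(d + k) numbers instead of d × k.

```python
import torch

torch.manual_seed(0)
d, k, r = 2000, 1000, 1   # shape of linear2.weight and the LoRA rank used here
B = torch.randn(d, r)
A = torch.randn(r, k)
delta_W = B @ A           # (d x r) @ (r x k) -> (d x k)

print(delta_W.shape)                              # torch.Size([2000, 1000])
print(torch.linalg.matrix_rank(delta_W).item())   # 1
print(d * k, r * (d + k))                         # 2000000 3000
```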


### Add the parametrization to our network

In [15]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    
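    # nn.Linear stores its weight as (out_features, in_features), so "features_in"
    # here is the number of rows d and "features_out" the number of columns k
    features_in, features_out = layer.weight.shape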
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

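`parametrize.register_parametrization` replaces each layer's `weight` with a computed value: every time `layer.weight` is accessed, PyTorch calls `LoRAParametrization.forward` on the original weight, which is now stored in `layer.parametrizations.weight.original`.
`enable_disable_lora` toggles LoRA on or off by setting the `enabled` flag on each layer's parametrization. If `enabled=False`, the original weights are used without LoRA updates.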
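The mechanics of `register_parametrization` and the `enabled` flag can be reproduced on a tiny layer. This sketch uses a hypothetical, cut-down `ToyLoRA` class (rank 1, no scaling) rather than the notebook's `LoRAParametrization`. It checks that the weight starts equal to the stored original, changes once B is nonzero, and reverts when LoRA is disabled.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class ToyLoRA(nn.Module):
    """Hypothetical cut-down LoRA parametrization: rank 1, scale 1."""
    def __init__(self, rows, cols):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(1, cols))
        self.lora_B = nn.Parameter(torch.zeros(rows, 1))  # zero, so B @ A starts at 0
        self.enabled = True

    def forward(self, W):
        return W + self.lora_B @ self.lora_A if self.enabled else W

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight shape: (3, 4)
parametrize.register_parametrization(layer, "weight", ToyLoRA(*layer.weight.shape))

original = layer.parametrizations.weight.original  # the untouched base weight
lora = layer.parametrizations.weight[0]            # the ToyLoRA module
starts_equal = torch.equal(layer.weight, original)

with torch.no_grad():
    lora.lora_B.fill_(1.0)  # pretend fine-tuning changed B
changed = not torch.equal(layer.weight, original)

lora.enabled = False        # what enable_disable_lora(False) does per layer
reverted = torch.equal(layer.weight, original)
print(starts_equal, changed, reverted)  # True True True
```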

### Display the number of parameters added by LoRA

In [16]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
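The LoRA count follows from the shapes: each layer adds `lora_B` (d × r) and `lora_A` (r × k), i.e., r(d + k) parameters. A plain-Python check of the numbers printed above:

```python
# (rows, cols) of each weight matrix, i.e. (out_features, in_features)
weights = [(1000, 784), (2000, 1000), (10, 2000)]
rank = 1
original_total = 2_807_010  # from the parameter count of the original network

# lora_B is (rows x rank) and lora_A is (rank x cols) for every layer
lora_total = sum(rank * (rows + cols) for rows, cols in weights)
print(lora_total)                                    # 6794
print(f"{lora_total / original_total * 100:.3f}%")   # 0.242%
```

The count grows linearly with the rank, so a larger r trades more trainable parameters for a more expressive update.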


Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA. Then fine-tune the model on the digit 9 for only 100 batches.

In [17]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Boolean mask: True for the samples labelled 9, which are the ones we keep
keep_mask = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  99%|██████████████████████████████████████████████████████████▍| 99/100 [00:00<00:00, 251.54it/s, loss=0.125]


* This freezes all parameters except LoRA-specific ones (lora_A, lora_B). This ensures that only the LoRA components are trained while keeping the base model unchanged.
* Loads the MNIST dataset but keeps only the images of digit 9.
* This dataset will be used to fine-tune LoRA to improve the model’s recognition of the digit 9.
* Trains the model only for 100 batches, optimizing only the LoRA parameters.
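One way to confirm that the freezing worked is to list the parameters that still require gradients. The sketch below applies the same name-based rule to a hypothetical `TinyLoRALinear` module. Frozen parameters receive no gradient, and `torch.optim.Adam` skips parameters whose `.grad` is `None`, which is why passing all of `net.parameters()` to the optimizer in `train` still updates only the LoRA matrices.

```python
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Hypothetical module mixing base parameters with LoRA parameters."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)                  # 12 weights + 3 biases
        self.lora_A = nn.Parameter(torch.randn(1, 4))  # 4 entries
        self.lora_B = nn.Parameter(torch.zeros(3, 1))  # 3 entries

model = TinyLoRALinear()
# Same rule as the notebook: freeze every parameter whose name lacks 'lora'
for name, param in model.named_parameters():
    if 'lora' not in name:
        param.requires_grad = False

trainable = {n: p.numel() for n, p in model.named_parameters() if p.requires_grad}
print(trainable)                # {'lora_A': 4, 'lora_B': 3}
print(sum(trainable.values()))  # 7
```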


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [18]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

* Checks that the original non-LoRA parameters have not changed after fine-tuning.
* Ensures that LoRA is applied correctly.
* The new weight matrix is computed as: **`W = W_og + (B x A) x scale`**
* Disables LoRA and checks that the weight matrix returns to its original state.
* This confirms that LoRA operates as an additive adjustment rather than modifying the base weights permanently.


### Test the network with LoRA enabled
The digit 9 should be classified better.

In [19]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 454.58it/s]

Accuracy: 0.891
wrong counts for the digit 0: 30
wrong counts for the digit 1: 16
wrong counts for the digit 2: 60
wrong counts for the digit 3: 150
wrong counts for the digit 4: 434
wrong counts for the digit 5: 58
wrong counts for the digit 6: 73
wrong counts for the digit 7: 205
wrong counts for the digit 8: 53
wrong counts for the digit 9: 13
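
Errors on the digit 9 drop from 98 to 13, but overall accuracy falls from 0.957 to 0.891. Errors on other digits grow, most sharply for 4 (22 → 434) and 7 (62 → 205). Since every fine-tuning label was 9, the LoRA update has likely learned to push predictions toward 9 in general, so digits that resemble a 9 are now often misclassified.
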
### Test the network with LoRA disabled
The accuracy and error counts must be the same as for the original network.

In [20]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 503.61it/s]

Accuracy: 0.957
wrong counts for the digit 0: 13
wrong counts for the digit 1: 18
wrong counts for the digit 2: 39
wrong counts for the digit 3: 81
wrong counts for the digit 4: 22
wrong counts for the digit 5: 22
wrong counts for the digit 6: 59
wrong counts for the digit 7: 62
wrong counts for the digit 8: 17
wrong counts for the digit 9: 98



