# 📊 Overview of the Implementation of LoRA on an ANN

In this notebook, we will explore the implementation of **Low-Rank Adaptation (LoRA)** on an Artificial Neural Network (ANN). LoRA is a technique designed to enhance the efficiency of model training by introducing low-rank parameterization, allowing for effective fine-tuning with fewer parameters.

## 🔍 Objectives:
- **Understand LoRA**: Grasp the fundamental concepts behind Low-Rank Adaptation and its advantages in parameter-efficient training.
- **Setup the ANN**: Implement a basic ANN structure that we will adapt using LoRA.
- **Parameterization**: Learn how to parameterize the weights of the network to include LoRA.
- **Training**: Train the modified network and observe the impact on performance with reduced parameters.
- **Analysis**: Compare the performance and parameters of the original network versus the LoRA-adapted network.

## 🚀 Let's dive in and implement LoRA to make our ANN more efficient!


In [1]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm

### 📥 Loading the Dataset

In this section, we will load the **MNIST dataset**, a widely used dataset for training image processing systems. It consists of handwritten digits from 0 to 9, which we will use to train our ANN with the LoRA technique.

#### 🔄 Data Transformations:
- **Tensor Conversion**: The images will be converted to tensors, which are essential for operations in PyTorch.
- **Normalization**: Each image is normalized with mean 0.1307 and standard deviation 0.3081, the statistics of the MNIST training pixels, so that the inputs are roughly zero-mean with unit variance.
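To make the normalization step concrete, here is a minimal sketch in plain Python of what `transforms.Normalize((0.1307,), (0.3081,))` does to a single pixel once `ToTensor()` has scaled it to the range [0, 1]. The helper name `normalize_pixel` is hypothetical and exists only for illustration.

```python
MNIST_MEAN = 0.1307  # mean passed to Normalize in this notebook
MNIST_STD = 0.3081   # standard deviation passed to Normalize in this notebook

def normalize_pixel(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Hypothetical helper: apply (value - mean) / std to one pixel in [0, 1]."""
    return (value - mean) / std

black = normalize_pixel(0.0)  # background pixel
white = normalize_pixel(1.0)  # fully inked pixel
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

Background pixels end up slightly below zero and ink pixels well above it, which is the range the network actually sees.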

#### 🔄 Data Loaders:
- We create **data loaders** for both the training and test sets, allowing us to efficiently batch and shuffle our data during training. Each batch will contain 10 images.
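As a quick sanity check on the batch size, the sketch below computes how many batches each loader yields per epoch. It assumes the standard MNIST split sizes (60,000 training and 10,000 test images); `batches_per_epoch` is a hypothetical helper, not part of PyTorch.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)

train_batches = batches_per_epoch(60_000, 10)  # MNIST training split
test_batches = batches_per_epoch(10_000, 10)   # MNIST test split
print(train_batches, test_batches)
```

These counts are the totals that the training and testing progress bars in this notebook report.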

#### 🚀 Leveraging GPU:
- We check if a GPU is available and set our device accordingly. Using a GPU will significantly speed up the training process, making it easier to work with larger models and datasets.

### Let's prepare the data for training our model with LoRA!


In [2]:
transform = transforms.Compose([
    transforms.ToTensor(), # converting to tensors
    transforms.Normalize((0.1307,), (0.3081,)) # normalize with the mean and standard deviation of the MNIST training pixels
])

# we would be using the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# creating the data loaders (batches of 10 images)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# trying to leverage my baby GPU hahahaha ;)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### 🧠 "Complex" ANN for This Task

In this section, we define a **Complex Artificial Neural Network (ANN)** designed to handle the MNIST digit classification task. This architecture consists of multiple layers, allowing it to learn intricate patterns in the data.

#### 🔍 Network Architecture:
- **Input Layer**: The network takes images of size \(28 \times 28\) pixels, flattened into a vector of size 784.
- **Hidden Layers**:
  - **First Hidden Layer**: 1000 neurons
  - **Second Hidden Layer**: 1500 neurons
  - **Third Hidden Layer**: 1300 neurons
- **Output Layer**: The final layer outputs predictions for 10 classes, corresponding to the digits 0 through 9.

#### ⚙️ Activation Function:
- **ReLU Activation**: The Rectified Linear Unit (ReLU) is applied after each hidden layer to introduce non-linearity, helping the model to learn complex relationships in the data.

#### 🚀 Moving to Device:
- The model is transferred to the specified device (GPU or CPU) for efficient computation.

With this architecture, we aim to leverage the power of deep learning to accurately classify handwritten digits from the MNIST dataset!


In [3]:
class NeuralNetwork(nn.Module):
    def __init__(self, hidden_layer_1 = 1000,hidden_layer_2 = 1500, hidden_layer_3 = 1300):
        super(NeuralNetwork,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_layer_1)
        self.linear2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.linear3 = nn.Linear(hidden_layer_2, hidden_layer_3)
        self.linear4 = nn.Linear(hidden_layer_3, 10)
        self.relu = nn.ReLU()
        
    def forward(self,img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.linear4(x)
        return x

model = NeuralNetwork().to(device)

### ⚙️ Model Training

In this section, we implement the training routine for our Complex ANN using the **Adam** optimizer and **Cross Entropy Loss**. The training process will involve iterating over the training dataset, updating the model parameters, and monitoring the loss.

#### 🔑 Key Components:

- **Optimizer**: 
  - **Adam**: A popular optimization algorithm that adapts the learning rate for each parameter, improving convergence speed.
  
- **Loss Function**: 
  - **Cross Entropy Loss**: Suitable for multi-class classification problems, it measures the performance of the model by comparing the predicted class probabilities with the true labels.
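The loss described above can be sketched for a single sample in plain Python. This is an illustration of the formula, not the PyTorch implementation: `nn.CrossEntropyLoss` takes raw logits, applies log-softmax internally, and averages over the batch by default. The function name `cross_entropy` is hypothetical.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# A model that has learned nothing: all 10 classes equally likely
uniform = cross_entropy([0.0] * 10, target=3)
# A model that is confident and correct about class 0
confident = cross_entropy([10.0] + [0.0] * 9, target=0)
print(f"uniform: {uniform:.4f}, confident: {confident:.6f}")
```

A loss near log(10) therefore means the network is guessing, and the loss approaches zero as the correct class receives most of the probability mass.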

#### 📈 Training Loop:
1. **Epochs**: The model is trained for a specified number of epochs.
2. **Data Loader**: The training data is loaded in batches for efficient processing.
3. **Forward Pass**:
   - The input images are reshaped and passed through the model to obtain predictions.
4. **Loss Calculation**:
   - The loss is computed by comparing the model’s predictions with the actual labels.
5. **Backward Pass**:
   - The gradients are calculated through backpropagation, and the optimizer updates the model parameters accordingly.
6. **Monitoring**:
   - The average loss is calculated and displayed in real-time using a progress bar.

This training routine allows the model to learn from the data effectively, improving its performance on the MNIST classification task. 


In [4]:
def train(train_loader, model, epochs=1, total_iterations_limit=None):
    # optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_function = nn.CrossEntropyLoss() # since this is a classification problem.

    total_iterations = 0  # Keep track of how many total iterations we've done

    for epoch in range(epochs):
        model.train()

        loss_sum = 0  # Sum of all the losses to calculate the average loss
        num_iterations = 0  # Keep track of the iterations in this epoch
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')

        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data # 'data' is a batch (x, y), where x is the input (image), and y is the label (digit)
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = loss_function(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            # If a total iteration limit is set, stop training once the limit is reached
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, model, epochs= 1)

Epoch 1: 100%|█████████████████████████████████████████████████████████| 6000/6000 [07:15<00:00, 13.77it/s, loss=0.277]


### We keep the original weights of the model before fine-tuning.

In [5]:
o_weights = {}
for name, param in model.named_parameters():
    o_weights[name] = param.clone().detach()

### 📊 Model Evaluation

After training the model, it is crucial to evaluate its performance on the test dataset to understand its effectiveness in classifying the MNIST digits. In this section, we will implement a testing function to calculate accuracy and analyze misclassifications.

#### 🔑 Evaluation Objectives:
1. **Accuracy Calculation**:
   - Determine the percentage of correctly classified instances out of the total number of test samples.
   
2. **Misclassification Analysis**:
   - Count and report the number of incorrect predictions for each digit (0-9) to identify specific weaknesses in the model.

#### 🧪 Testing Function:
- **No Gradient Tracking**: We utilize `torch.no_grad()` to disable gradient calculation, which reduces memory usage and speeds up computation during testing.
- **Data Iteration**: The test data is iterated over in batches, similar to training.
- **Output Comparison**:
  - The predicted digit is compared to the actual label to count correct and incorrect classifications.

#### 📈 Performance Metrics:
- **Accuracy**: Calculated as the ratio of correct predictions to total predictions.
- **Wrong Counts**: A detailed breakdown of the number of misclassifications for each digit.

This evaluation will provide insights into the model's strengths and areas for improvement, guiding future enhancements and refinements.
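The two metrics can be sketched independently of PyTorch. The snippet below works on plain lists of predicted and true digits; the function name `evaluate` and the toy data are made up for illustration and are not results from this notebook.

```python
def evaluate(predictions, labels, num_classes=10):
    """Return (accuracy, wrong_counts); wrong_counts[d] counts errors on true digit d."""
    wrong_counts = [0] * num_classes
    correct = 0
    for pred, label in zip(predictions, labels):
        if pred == label:
            correct += 1
        else:
            wrong_counts[label] += 1  # indexed by the true label, not the prediction
    return correct / len(labels), wrong_counts

# Made-up toy data: an 8 mistaken for a 3 and a 9 mistaken for a 4
labels      = [3, 8, 8, 9, 0, 1]
predictions = [3, 3, 8, 4, 0, 1]
accuracy, wrong_counts = evaluate(predictions, labels)
print(round(accuracy, 3), wrong_counts)
```

Note that the wrong counts are attributed to the true digit, so a high count for 8 means that images of 8 are often misread, not that the model predicts 8 too often.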


In [6]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 176.63it/s]

Accuracy: 0.948
wrong counts for the digit 0: 18
wrong counts for the digit 1: 19
wrong counts for the digit 2: 30
wrong counts for the digit 3: 81
wrong counts for the digit 4: 42
wrong counts for the digit 5: 11
wrong counts for the digit 6: 60
wrong counts for the digit 7: 30
wrong counts for the digit 8: 150
wrong counts for the digit 9: 81





### 📝 Model Evaluation Results

After conducting the testing phase, the model's performance yielded the following results:

#### 🔍 Accuracy
- **Overall Accuracy**: 94.8%
  
This indicates that the model is performing well, correctly classifying the vast majority of the digits in the MNIST test set.

#### 📉 Misclassification Analysis
The breakdown of misclassified digits reveals specific areas where the model struggles:

- **Digit 0**: 18 misclassifications
- **Digit 1**: 19 misclassifications
- **Digit 2**: 30 misclassifications
- **Digit 3**: 81 misclassifications
- **Digit 4**: 42 misclassifications
- **Digit 5**: 11 misclassifications
- **Digit 6**: 60 misclassifications
- **Digit 7**: 30 misclassifications
- **Digit 8**: 150 misclassifications
- **Digit 9**: 81 misclassifications

#### 📊 Insights
From the analysis, we can see that the model struggles significantly with the digits **3**, **6**, **8**, and **9**, each having over **50** misclassifications. This suggests that further fine-tuning is necessary to improve the model's performance on these specific digits.

### 🎯 Next Steps
To enhance the model's ability to classify these challenging digits, we will focus on fine-tuning the model. This adjustment aims to address the misclassification issues and improve overall performance, particularly in a real-world application where accuracy is critical.



In [7]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3,model.linear4]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([1500, 1000]) + B: torch.Size([1500])
Layer 3: W: torch.Size([1300, 1500]) + B: torch.Size([1300])
Layer 4: W: torch.Size([10, 1300]) + B: torch.Size([10])
Total number of parameters: 4,250,810


#### 📊 Total Parameter Count
The total number of parameters in the original artificial neural network is **4,250,810**.
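The same total can be reproduced by hand from the layer sizes, since a linear layer with `in_features` inputs and `out_features` outputs has `in_features * out_features` weights plus `out_features` biases. The helper `linear_params` below is a hypothetical name used only for this check.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

layer_sizes = [784, 1000, 1500, 1300, 10]  # input, three hidden layers, output
per_layer = [linear_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, f"{total:,}")
```

Almost all of the parameters sit in the second and third layers, which connect the widest pairs of layers.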

### 🎯 Implications
This substantial number of parameters indicates the model's capacity for learning complex patterns, though it may also pose challenges such as overfitting. The count also serves as the baseline for LoRA, which adds a small set of trainable parameters on top of these weights instead of retraining all of them.


### 📈 LoRA Parameterization Class

The **LoRAParametrization** class implements the Low-Rank Adaptation (LoRA) method, enhancing our neural network's adaptability while keeping the number of trainable parameters in check. Below are the key components and functionalities:

#### 🛠️ Initialization
- **Parameters**:
  - **lora_A**: Initialized with random Gaussian values; this is the matrix \( A \) of the low-rank update \( \Delta W = BA \).
  - **lora_B**: Initialized to zero, so \( \Delta W = BA \) is zero at the start of training and the adapted network initially behaves exactly like the original one.
- **Scaling Factor**: 
  - The update is scaled by $$ \frac{\alpha}{r} $$, where \( \alpha \) is a constant and \( r \) is the rank. According to the LoRA paper, this scaling reduces the need to retune hyperparameters when \( r \) is varied.
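
Written out for a single linear layer, the components above combine as follows. The symbols \( W_0 \), \( d_{\text{in}} \) and \( d_{\text{out}} \) are notation assumed here for illustration; in the code they correspond to the frozen weight, the input size and the output size of the layer.

$$
W = W_0 + \frac{\alpha}{r}\, B A, \qquad B \in \mathbb{R}^{d_{\text{out}} \times r}, \quad A \in \mathbb{R}^{r \times d_{\text{in}}}, \quad r \ll \min(d_{\text{in}}, d_{\text{out}})
$$

Because \( B = 0 \) at initialization, \( BA = 0 \) and therefore \( W = W_0 \) before any fine-tuning step. The update adds \( r\,(d_{\text{in}} + d_{\text{out}}) \) trainable values, compared with \( d_{\text{in}} \cdot d_{\text{out}} \) for the full weight matrix.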

#### 🔄 Forward Method
- The **forward** method computes the adapted weights as follows:
  - It adds the original weights to the product of \( B \) and \( A \), scaled appropriately.
  - When the adaptation is disabled, it simply returns the original weights, allowing for flexibility during training.
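
The forward rule can be illustrated with small nested lists instead of tensors. This is a minimal sketch with made-up numbers; `matmul` and `lora_forward` are hypothetical helpers, and the notebook's actual implementation uses `torch.matmul` on `nn.Parameter` objects.

```python
def matmul(X, Y):
    """Multiply two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_forward(W, B, A, alpha=1, rank=1, enabled=True):
    """Return W + (alpha / rank) * (B @ A), or W unchanged when LoRA is disabled."""
    if not enabled:
        return W
    scale = alpha / rank
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]

W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]    # frozen weight, shape (2, 3)
B = [[1.0], [2.0]]       # shape (2, 1), rank 1
A = [[0.5, 0.0, -1.0]]   # shape (1, 3)

adapted = lora_forward(W, B, A)
print(adapted)  # [[1.5, 2.0, 2.0], [5.0, 5.0, 4.0]]
```

With `enabled=False`, or with `B` set to all zeros as at initialization, the function returns the original weights.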

### 🌟 Purpose
The parametrization lets the network learn a small task-specific update on top of frozen weights, so it can be specialized for a task without training or storing a second full copy of the model. In this notebook we use it to fine-tune on selected digits of the MNIST dataset.


In [8]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device=device):
        super().__init__()
        # Section 4.1 of the paper: 
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

### 🎯 Purpose of the LoRA Parameterization Code

The provided code implements the Low-Rank Adaptation (LoRA) technique within a neural network model by applying parameterization to the weight matrices of the linear layers. This approach is specifically designed to enhance the model's adaptability and efficiency during fine-tuning for downstream tasks. Below are the key objectives of this implementation:

1. **Parameterization of Linear Layers**:
   - The code defines a function `linear_layer_parameterization` that applies LoRA only to the weight matrices of the linear layers and leaves the biases untouched. The LoRA paper, which studies Transformers, adapts only the attention weights and leaves biases to future work; since this network has no attention layers, we adapt the weights of all four linear layers instead.

2. **Registration of Parameterizations**:
   - Each linear layer (`linear1`, `linear2`, `linear3`, `linear4`) in the model is registered for LoRA parameterization. This enables the model to utilize the benefits of LoRA, allowing it to adapt its weights effectively while minimizing the number of additional parameters introduced during training.

3. **Toggle LoRA Adaptation**:
   - The function `enable_disable_lora` allows for easy enabling or disabling of the LoRA adaptation across all specified layers. This feature facilitates experimentation, enabling researchers to compare the performance of the model with and without LoRA.

4. **Focus on Efficiency**:
   - Only the small matrices `lora_A` and `lora_B` are meant to be trained, while the original weights and biases stay frozen. This keeps the number of trainable parameters low, which is especially beneficial in scenarios where training resources are limited or where rapid adaptation to new tasks is required.

In summary, this code serves to implement a flexible, efficient mechanism for adapting neural network weights, optimizing performance for specific tasks without the need for retraining the entire model.
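
To see the registration and the toggle in isolation, the following sketch applies a parametrization to one tiny `nn.Linear` layer. `TinyLoRA` is a hypothetical, simplified stand-in for the notebook's `LoRAParametrization`, and filling `lora_B` with 0.1 merely simulates a trained adapter.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class TinyLoRA(nn.Module):
    """Hypothetical, simplified stand-in for LoRAParametrization."""
    def __init__(self, out_features, in_features, rank=1, alpha=1):
        super().__init__()
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.lora_A = nn.Parameter(torch.randn(rank, in_features))
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if not self.enabled:
            return original_weights
        return original_weights + (self.lora_B @ self.lora_A) * self.scale

torch.manual_seed(0)
layer = nn.Linear(4, 3)
w0 = layer.weight.detach().clone()

# nn.Linear stores its weight as (out_features, in_features)
parametrize.register_parametrization(layer, "weight", TinyLoRA(*layer.weight.shape))

same_at_init = torch.equal(layer.weight, w0)  # B is zero, so nothing changes yet

with torch.no_grad():  # simulate a trained adapter
    layer.parametrizations.weight[0].lora_B.fill_(0.1)
changed_when_enabled = not torch.equal(layer.weight, w0)

layer.parametrizations.weight[0].enabled = False
restored_when_disabled = torch.equal(layer.weight, w0)
print(same_at_init, changed_when_enabled, restored_when_disabled)
```

After registration, `layer.weight` is recomputed from `parametrizations.weight.original` on every access, which is why flipping `enabled` takes effect immediately.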


In [9]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    
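    # Note: nn.Linear stores its weight as (out_features, in_features), so `features_in`
    # holds the layer's output size and `features_out` its input size. The product
    # lora_B @ lora_A still has exactly the shape of the weight matrix.
    features_in, features_out = layer.weight.shape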
    return LoRAParametrization(
        features_in, features_out,
        rank=rank, 
        alpha=lora_alpha, 
        device=device
    )

parametrize.register_parametrization(
    model.linear1, 
    "weight", 
    linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, 
    "weight", 
    linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, 
    "weight", 
    linear_layer_parameterization(model.linear3, device)
)
parametrize.register_parametrization(
    model.linear4, 
    "weight", 
    linear_layer_parameterization(model.linear4, device)
)


def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3, model.linear4]:
        layer.parametrizations["weight"][0].enabled = enabled

### Number of parameters added by LoRA

In [10]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3,model.linear4]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([1500, 1000]) + B: torch.Size([1500]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([1500, 1])
Layer 3: W: torch.Size([1300, 1500]) + B: torch.Size([1300]) + Lora_A: torch.Size([1, 1500]) + Lora_B: torch.Size([1300, 1])
Layer 4: W: torch.Size([10, 1300]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 1300]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 4,250,810
Total number of parameters (original + LoRA): 4,259,204
Parameters introduced by LoRA: 8,394
Parameters increment: 0.197%


### 📊 Number of Parameters Added by LoRA

This section evaluates the number of parameters introduced by the Low-Rank Adaptation (LoRA) technique in the neural network model. By comparing the parameters before and after the application of LoRA, we can assess its impact on the model's complexity and efficiency. 

#### Key Objectives:

1. **Parameter Counting**:
   - The code computes the total number of parameters added by LoRA (`Lora_A` and `Lora_B`) for each linear layer in the model.
   - It also counts the original parameters of the model, including both weights and biases.

2. **Comparison with Original Model**:
   - An assertion checks that the count of non-LoRA parameters matches the original parameter count of the network, ensuring that no discrepancies arise from the parameterization process.

3. **Results**:
   - The total number of parameters in the original model is 4,250,810.
   - After applying LoRA, the total parameter count increases to 4,259,204, indicating that the LoRA technique introduces an additional 8,394 parameters.
   - This represents a parameter increment of approximately 0.197%, suggesting that the addition of LoRA parameters is minimal compared to the original model size, highlighting its efficiency in augmenting the model without significant overhead.

This analysis emphasizes the effectiveness of LoRA in enhancing model adaptability while maintaining a relatively low increase in parameter count, thereby preserving computational efficiency.
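
The LoRA count follows directly from the weight shapes: a rank-`r` adapter for a weight of shape `(out, in)` adds `r * (out + in)` values. The sketch below recomputes the reported numbers and, as a purely hypothetical comparison that is not run in this notebook, shows what rank 8 would add. `lora_params` is an illustrative helper name.

```python
def lora_params(out_features, in_features, rank=1):
    """Parameters in B (out x rank) plus A (rank x in)."""
    return rank * (out_features + in_features)

weight_shapes = [(1000, 784), (1500, 1000), (1300, 1500), (10, 1300)]
original_total = 4_250_810  # weights + biases, as reported above

added = sum(lora_params(out_f, in_f) for out_f, in_f in weight_shapes)
increment_pct = added / original_total * 100
# Hypothetical: the same network with rank-8 adapters
added_rank_8 = sum(lora_params(out_f, in_f, rank=8) for out_f, in_f in weight_shapes)
print(f"{added:,} ({increment_pct:.3f}%), rank 8: {added_rank_8:,}")
```

The number of added parameters grows linearly with the rank, so even a rank several times larger would remain a small fraction of the original model.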


### 🔒 Freezing Original Parameters for Fine-Tuning with LoRA

In this section, we implement a strategy to enhance the model's performance on specific target digits (3, 6, 8, 9) from the MNIST dataset by freezing the original model parameters while allowing only the LoRA parameters to be trained.

#### Key Steps:

1. **Freezing Non-LoRA Parameters**:
   - All parameters in the model that are not part of the LoRA parameterization are frozen, meaning they will not be updated during the training process.
   - This approach preserves the original learned features of the model while allowing for adaptation through LoRA.

2. **Loading the MNIST Dataset**:
   - The MNIST dataset is loaded, filtered to include only the target digits specified (3, 6, 8, 9).
   - A data loader is created to facilitate batch training, with a batch size of 10 and shuffling enabled.

3. **Training Configuration**:
   - The model is trained only on the target digits (3, 6, 8, 9) for a limited number of iterations (100 batches) to test whether LoRA improves performance on these digits.

#### Results:
- The fine-tuning run processes its 100 batches in about 2 seconds at roughly 36 iterations per second, compared with roughly 14 iterations per second during the original training. Only the LoRA parameters receive updates while the original parameters remain fixed, which makes this an inexpensive way to target specific weaknesses of the model.

This methodology illustrates how freezing certain parameters can optimize the training process, particularly in scenarios where quick adaptations are desired without overhauling the entire model.
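
The effect of freezing can be checked on a single small layer. In this sketch, `TinyLoRA` is a hypothetical, simplified stand-in for the notebook's `LoRAParametrization`, and the random input and squared-output loss are placeholders chosen only to produce a gradient.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class TinyLoRA(nn.Module):
    """Hypothetical, simplified stand-in for LoRAParametrization."""
    def __init__(self, out_features, in_features, rank=1, alpha=1):
        super().__init__()
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.lora_A = nn.Parameter(torch.randn(rank, in_features))
        self.scale = alpha / rank

    def forward(self, original_weights):
        return original_weights + (self.lora_B @ self.lora_A) * self.scale

torch.manual_seed(0)
layer = nn.Linear(4, 3)
parametrize.register_parametrization(layer, "weight", TinyLoRA(*layer.weight.shape))

# Freeze every parameter whose name does not contain "lora"
for name, param in layer.named_parameters():
    if "lora" not in name:
        param.requires_grad = False

trainable = [name for name, p in layer.named_parameters() if p.requires_grad]
frozen_before = layer.parametrizations.weight.original.detach().clone()

# One optimizer step on placeholder data
optimizer = torch.optim.Adam([p for p in layer.parameters() if p.requires_grad], lr=0.001)
loss = layer(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

original_unchanged = torch.equal(layer.parametrizations.weight.original, frozen_before)
lora_B_updated = bool(layer.parametrizations.weight[0].lora_B.abs().sum() > 0)
print(trainable, original_unchanged, lora_B_updated)
```

Only the two LoRA matrices remain trainable, and after the step the original weight is bit-for-bit identical while `lora_B` has moved away from zero.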


In [11]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again

target_digits = [3,6,8,9]
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
include_indices = torch.tensor([target in target_digits for target in mnist_trainset.targets])
mnist_trainset.data = mnist_trainset.data[include_indices]
mnist_trainset.targets = mnist_trainset.targets[include_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the target digits (3, 6, 8, 9) and only for 100 batches (hoping that it would improve the performance on these digits)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
Freezing non-LoRA parameter linear4.bias
Freezing non-LoRA parameter linear4.parametrizations.weight.original


Epoch 1:  99%|███████████████████████████████████████████████████████████▍| 99/100 [00:02<00:00, 36.22it/s, loss=0.182]


### Checking that the original weights are frozen

In [22]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.linear1.parametrizations.weight.original == o_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == o_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == o_weights['linear3.weight'])
assert torch.all(model.linear4.parametrizations.weight.original == o_weights['linear4.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.linear1.weight, o_weights['linear1.weight'])

AssertionError: 

### Testing the network with LoRA and Original weights

In [12]:
enable_disable_lora(enabled=True)
test()

Testing: 100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 93.01it/s]

Accuracy: 0.956
wrong counts for the digit 0: 17
wrong counts for the digit 1: 30
wrong counts for the digit 2: 40
wrong counts for the digit 3: 24
wrong counts for the digit 4: 52
wrong counts for the digit 5: 78
wrong counts for the digit 6: 26
wrong counts for the digit 7: 85
wrong counts for the digit 8: 39
wrong counts for the digit 9: 54





In [13]:
enable_disable_lora(enabled=False)
test()

Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 182.81it/s]

Accuracy: 0.948
wrong counts for the digit 0: 18
wrong counts for the digit 1: 19
wrong counts for the digit 2: 30
wrong counts for the digit 3: 81
wrong counts for the digit 4: 42
wrong counts for the digit 5: 11
wrong counts for the digit 6: 60
wrong counts for the digit 7: 30
wrong counts for the digit 8: 150
wrong counts for the digit 9: 81





### 🧪 Testing the Network with LoRA and Original Weights

In this section, we evaluate the performance of the network using both LoRA-parameterized weights and the original weights to compare their effectiveness in classifying the MNIST digits.

#### Testing Procedure

1. **Enabling LoRA Parameters**:
   - The LoRA parameters are enabled for testing to assess the impact of the adaptations made during training.

2. **Testing with LoRA**:
   - The model is tested on the full MNIST test set (1,000 batches of 10 images, i.e. 10,000 samples), and the accuracy is calculated along with the count of misclassifications for each digit.
   - The results show an accuracy of **95.6%** with specific wrong counts for each digit.

3. **Disabling LoRA Parameters**:
   - After the initial test, the LoRA parameters are disabled, and the model is tested again using the original weights.

4. **Testing with Original Weights**:
   - The model is evaluated again on the same set of samples, resulting in an accuracy of **94.8%** and a different distribution of wrong counts across the digits.

#### Results Summary

- **LoRA Enabled**:
  - **Accuracy**: 95.6%
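  - Errors on the fine-tuned digits drop sharply (3: 81 to 24, 6: 60 to 26, 8: 150 to 39, 9: 81 to 54), but errors rise on digits that were not part of the fine-tuning set, most notably 5 (11 to 78) and 7 (30 to 85).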

- **LoRA Disabled** (Original Weights):
  - **Accuracy**: 94.8%
  - The original weights make most of their errors on the digits 8 (150) and 3 (81), which is why these digits were included in the fine-tuning set.
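
The two result tables can be summarized with a few lines of Python. The wrong counts below are copied from the test outputs in this notebook; the helper `summarize` is a hypothetical name introduced for this comparison.

```python
wrong_original = [18, 19, 30, 81, 42, 11, 60, 30, 150, 81]  # LoRA disabled
wrong_lora     = [17, 30, 40, 24, 52, 78, 26, 85, 39, 54]   # LoRA enabled
target_digits = [3, 6, 8, 9]
test_size = 10_000  # 1,000 batches of 10 images

def summarize(wrong_counts):
    """Return (errors on target digits, errors on other digits, accuracy)."""
    on_target = sum(wrong_counts[d] for d in target_digits)
    off_target = sum(wrong_counts) - on_target
    accuracy = 1 - sum(wrong_counts) / test_size
    return on_target, off_target, accuracy

for label, counts in [("original", wrong_original), ("LoRA", wrong_lora)]:
    on_target, off_target, accuracy = summarize(counts)
    print(f"{label}: target errors={on_target}, other errors={off_target}, accuracy={accuracy:.4f}")
```

Splitting the errors this way shows where the accuracy change comes from: errors on the target digits fall from 372 to 143, while errors on the remaining digits rise from 150 to 302.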

### Conclusion

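With LoRA enabled, overall accuracy rises from 94.8% to 95.6%. The gain comes from the target digits 3, 6, 8, and 9, whose combined misclassifications fall from 372 to 143, and it is partly offset by the other digits, whose combined misclassifications rise from 150 to 302. Fine-tuning with LoRA can therefore improve performance on targeted classes at a very small parameter cost, but training only on those classes trades away some accuracy elsewhere. Because the original weights are untouched, disabling LoRA reproduces the original results exactly.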
