# **Singular Value Decomposition**

In [None]:
import torch
import numpy as np  # NOTE(review): `np` is not used anywhere in this notebook
# Fix the global PyTorch RNG so the random matrices generated below are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7cf5e47a4ff0>

## Generate a Rank-Deficient Matrix

In [None]:
import torch

def generate_rank_deficient_matrix(shape, rank):
    """
    Generates a random matrix of a given shape and rank.

    The matrix is the product of a (rows, rank) and a (rank, cols)
    standard-normal matrix, which (almost surely) has rank exactly `rank`.

    Args:
        shape (tuple): Shape of the matrix as (rows, cols).
        rank (int): Desired rank of the matrix, 0 <= rank <= min(rows, cols).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The (rows, cols) matrix and its
        numerically measured rank (0-dim tensor from torch.linalg.matrix_rank).

    Raises:
        ValueError: If `rank` is negative or greater than min(rows, cols).
    """
    rows, cols = shape

    # Ensure rank is valid (it can't be higher than the minimum of rows or columns)
    if rank > min(rows, cols):
        raise ValueError("Rank cannot be greater than the minimum of the number of rows or columns.")
    if rank < 0:
        raise ValueError("Rank must be non-negative.")

    # Step 1: two random factors whose inner dimension is `rank`
    A = torch.randn(rows, rank)  # (rows, rank)
    B = torch.randn(rank, cols)  # (rank, cols)

    # Step 2: their product has shape (rows, cols) and rank at most `rank`
    rank_deficient_matrix = A @ B

    # Sanity-check the rank that was actually produced
    generated_rank = torch.linalg.matrix_rank(rank_deficient_matrix)

    assert generated_rank == rank, f"Generated matrix does not have the desired rank: {generated_rank} != {rank}"

    return rank_deficient_matrix, generated_rank

In [None]:
# Example: a square 20 x 20 matrix whose rank is only 2.
m = n = 20
rank = 2
W, W_rank = generate_rank_deficient_matrix(shape=(m, n), rank=rank)
print("Rank of the Matrix:", W_rank.item())

Rank of the Matrix: 2


## Apply SVD to the Generated Matrix (W)

In [None]:
# Perform SVD on W: W = U @ diag(D) @ V^T.
# torch.svd is deprecated in favour of torch.linalg.svd, which returns V^T
# (here `Vh`) instead of V. full_matrices=False gives the reduced SVD, matching
# torch.svd's default (some=True).
U, D, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()

# Keep only the top-r singular values & vectors, r being the integer rank of W
# (converted explicitly instead of slicing with a 0-dim tensor).
r = int(W_rank)
U_r = U[:, :r]   # (m, r)
D_r = D[:r]      # (r,)
V_r = Vh[:r, :]  # (r, n): the first r columns of V, transposed

# Factor W = B @ A with B = U_r * D_r (columns of U_r scaled by the singular
# values) and A = V_r.
A = V_r
B = U_r * D_r
print(f'Shape of A: {A.shape}')
print(f'Shape of B: {B.shape}')

Shape of A: torch.Size([2, 20])
Shape of B: torch.Size([20, 2])


## Check the Difference Between the Outputs

In [None]:
# Generate a random input and bias. x multiplies W from the right, so it needs
# n entries (the columns of W); b is added to W @ x, so it needs m entries.
x = torch.randn(n)
b = torch.randn(m)

# Compute y = Wx + b
y = W @ x + b

# Compute y' = (B @ A) @ x + b
y_r = (B @ A) @ x + b

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_r)

# W has exactly rank r, so the truncated-SVD factorisation is exact up to float
# round-off: check that numerically instead of comparing printouts by eye.
max_abs_diff = (y - y_r).abs().max().item()
assert torch.allclose(y, y_r, rtol=1e-4, atol=1e-4), f"y and y' differ by up to {max_abs_diff:.3e}"

Original y using W:
 tensor([ 7.6824, -3.4827, -2.3524, -3.8416,  3.4325, -4.4672, -2.2232,  1.4218,
         1.5892,  0.1178,  0.4996,  5.2467,  1.8231,  0.4579,  3.7740, -2.8144,
        -1.8249,  4.5885,  0.9263, -5.0501])

y' computed using BA:
 tensor([ 7.6824, -3.4827, -2.3524, -3.8416,  3.4325, -4.4672, -2.2232,  1.4218,
         1.5892,  0.1178,  0.4996,  5.2467,  1.8231,  0.4579,  3.7740, -2.8144,
        -1.8249,  4.5885,  0.9263, -5.0501])


In [None]:
# Dense W stores m * n numbers; the factors B (m x r) and A (r x n) store (m + n) * r.
print("Total parameters of W: ", W.nelement())
print("Total parameters of B and A: ", sum(t.nelement() for t in (B, A)))

Total parameters of W:  400
Total parameters of B and A:  80


# LoRA Implementation

*Training setup adapted from the PyTorch MNIST example: https://github.com/pytorch/examples/blob/main/mnist/main.py*

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt  # NOTE(review): `plt` is not used anywhere in this notebook
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Re-seed so this LoRA section is reproducible regardless of how much RNG the SVD demo consumed.
torch.manual_seed(42)

<torch._C.Generator at 0x7cf5e47a4ff0>

## Load Dataset

In [None]:
# Convert images to tensors and normalise with the mean / std constants used by
# the PyTorch MNIST example this section is based on.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_kwargs = {'batch_size': 512, 'shuffle': True}
# Accuracy does not depend on sample order, so the test set is not shuffled
# (this also keeps evaluation from consuming the global RNG).
test_kwargs = {'batch_size': 512, 'shuffle': False}

# Load MNIST train and test dataset (download both explicitly rather than
# relying on the train split having fetched the test files as a side effect)
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, download=True, transform=transform)

# Create dataloader for training and testing
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Create Neural Network

In [None]:
# Neural network to classify MNIST, deliberately over-parameterised so the
# parameter savings of LoRA later on are easy to see.

class ComplicatedNetwork(nn.Module):
    """Three-layer MLP: 784 -> hidden_layer_1 -> hidden_layer_2 -> 10 logits."""

    def __init__(self, hidden_layer_1=1000, hidden_layer_2=2000):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_layer_1)
        self.fc2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.fc3 = nn.Linear(hidden_layer_2, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten `x` to (-1, 784) and return one row of 10 class logits per sample."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

cnet = ComplicatedNetwork(hidden_layer_1=1000, hidden_layer_2=2000).to(device=device)

## Train

In [None]:
def train(train_loader, model, epochs=10, total_iterations_limit=None, device=None):
  """Train `model` in place with Adam (lr=1e-3) and cross-entropy loss.

  Args:
    train_loader: Iterable of `(data, target)` batches (e.g. a DataLoader).
    model: The `nn.Module` to optimise; its inputs are flattened to (-1, 784).
    epochs: Number of passes over `train_loader`.
    total_iterations_limit: Optional cap on the number of batches processed
      across all epochs. Training returns once the cap is reached; a cap of
      0 or less performs no updates.
    device: Device the batches are moved to. Defaults to the device of the
      model's first parameter (CPU if the model has none).
  """
  import itertools

  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

  loss_fn = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  total_iterations = 0
  for epoch in range(epochs):
    model.train()

    batches = train_loader
    try:
      num_batches = len(train_loader)
    except TypeError:  # iterable without a known length
      num_batches = None

    if total_iterations_limit is not None:
      remaining = total_iterations_limit - total_iterations
      if remaining <= 0:
        return
      # Cap the iterable itself instead of returning mid-loop, so the progress
      # bar reaches its total rather than freezing one step short.
      batches = itertools.islice(train_loader, remaining)
      num_batches = remaining if num_batches is None else min(num_batches, remaining)

    loss_sum = 0.0
    num_iterations = 0
    data_iterator = tqdm(batches, desc=f'Epoch {epoch+1}', total=num_batches)

    for data, target in data_iterator:
      num_iterations += 1
      total_iterations += 1

      data, target = data.to(device), target.to(device)

      optimizer.zero_grad()
      output = model(data.view(-1, 28*28))
      loss_value = loss_fn(output, target)
      loss_sum += loss_value.item()

      # Running mean of the loss over this epoch
      data_iterator.set_postfix(loss=loss_sum / num_iterations)

      loss_value.backward()
      optimizer.step()


In [None]:
# Pre-train on the full MNIST training set for a single epoch.
train(train_loader=train_loader, model=cnet, epochs=1)

Epoch 1: 100%|██████████| 118/118 [00:40<00:00,  2.90it/s, loss=0.258]


In [None]:
# Keep detached copies of every trained parameter so later cells can verify
# that LoRA fine-tuning leaves the original weights untouched.
original_weights = {
    name: param.detach().clone() for name, param in cnet.named_parameters()
}

## Testing

In [None]:
def infer(model=None, loader=None, num_classes=10):
  """Evaluate a classifier and print its accuracy and per-class error counts.

  Args:
    model: Network to evaluate. Defaults to the notebook-level `cnet`.
    loader: Iterable of `(data, target)` batches. Defaults to the
      notebook-level `test_loader`.
    num_classes: Number of classes to report error counts for (10 for MNIST).

  Prints `Accuracy: <fraction rounded to 3 decimals>` (nan for an empty
  loader) followed by one `Wrong predictions for class i: <count>` line per
  class. The model's train/eval mode is restored afterwards.
  """
  model = cnet if model is None else model
  loader = test_loader if loader is None else loader

  # Move batches to wherever the model lives instead of relying on a global.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device("cpu")

  correct = 0
  total = 0
  wrong_count = torch.zeros(num_classes, dtype=torch.long)

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data.view(-1, 28*28))

        # Vectorised per-batch bookkeeping instead of a Python loop per sample.
        hits = output.argmax(dim=1) == target
        correct += int(hits.sum())
        total += target.numel()
        wrong_count += torch.bincount(target[~hits].cpu(), minlength=num_classes)
  finally:
    model.train(was_training)

  accuracy = correct / total if total else float('nan')
  print(f'Accuracy: {round(accuracy, 3)}')
  for i, count in enumerate(wrong_count.tolist()):
    print(f'Wrong predictions for class {i}: {count}')

# Baseline: accuracy of the pre-trained network before any LoRA parametrization is attached.
infer()

Accuracy: 0.965
Wrong predictions for class 0: 17
Wrong predictions for class 1: 16
Wrong predictions for class 2: 30
Wrong predictions for class 3: 32
Wrong predictions for class 4: 15
Wrong predictions for class 5: 35
Wrong predictions for class 6: 29
Wrong predictions for class 7: 40
Wrong predictions for class 8: 39
Wrong predictions for class 9: 97


## Count the Total Parameters in the Network

In [None]:
# total_parameters_original = sum(p.numel() for p in cnet.parameters())
# total_parameters_original

In [None]:
# Count weights + biases of the three linear layers of the pre-trained network.
total_parameters_original = 0
for index, layer in enumerate((cnet.fc1, cnet.fc2, cnet.fc3)):
  print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
  total_parameters_original += sum(p.nelement() for p in (layer.weight, layer.bias))

print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


## LoRA Parameters

In [None]:
class LoRANetwork(nn.Module):
  """LoRA parametrization: adds a trainable low-rank update to a frozen weight.

  Once registered with `parametrize.register_parametrization`, `forward`
  receives the original weight W and returns W + (lora_B @ lora_A) * scale
  while `enabled` is True, or W unchanged otherwise.

  Args:
    feature_in: Number of rows of the wrapped weight. For `nn.Linear` this is
      `out_features`, because `weight` has shape (out_features, in_features).
    feature_out: Number of columns of the wrapped weight (`in_features` for
      `nn.Linear`).
    rank: Rank r of the update.
    alpha: Scaling numerator; the update is multiplied by `alpha / rank`.
    device: Device on which the LoRA matrices are allocated.
  """
  def __init__(self, feature_in, feature_out, rank=1, alpha=1, device='cpu'):
    super().__init__()
    # Allocate directly on `device`. Calling `.to(device)` on an nn.Parameter
    # returns a plain (non-leaf) Tensor whenever the device changes, which
    # would silently drop it from the module's registered, trainable parameters.
    self.lora_A = nn.Parameter(torch.zeros(rank, feature_out, device=device))
    self.lora_B = nn.Parameter(torch.zeros(feature_in, rank, device=device))
    # A ~ N(0, 1) while B stays zero, so B @ A starts at exactly zero and the
    # wrapped layer initially behaves like the original one.
    nn.init.normal_(self.lora_A, mean=0, std=1)

    self.scale = alpha / rank
    self.enabled = True

  def forward(self, original_weights):
    if not self.enabled:
      return original_weights
    return original_weights + (self.lora_B @ self.lora_A).view(original_weights.shape) * self.scale


The cell below imports `torch.nn.utils.parametrize as parametrize`. The `parametrize` module lets you apply constraints, transformations, or reparametrizations to the parameters of a `torch.nn.Module`.

### What Does `parametrize` Do?

In PyTorch, `parametrize` provides a way to register and manage "parameterizations" of a module's parameters. A parameterization is a transformation applied to a parameter of a neural network layer before it is used in computations. For example, if you want to constrain a weight matrix to be symmetric or positive definite, you can use `parametrize` to enforce this constraint.

### Common Use Cases

1. **Enforcing Constraints:** You can use parameterizations to enforce certain constraints, like making a parameter non-negative or symmetric.
2. **Custom Transformations:** It allows you to apply custom transformations to parameters, such as normalizing or projecting them into a specific space.
3. **Reparameterization Tricks:** Useful in variational inference or other cases where you need to reparameterize your model for optimization.

### How Does It Work?

1. **Register a Parameterization:**
   - You can register a parameterization for a parameter in a `torch.nn.Module`. The parameterization is applied every time the parameter is accessed.
2. **Parameterized Attribute:**
   - The original tensor is moved to `module.parametrizations.<name>.original`, and accessing `module.<name>` returns the result of the registered transformation applied to it.

### Example

Here is a basic example of how you might use `parametrize` to enforce a weight matrix in a linear layer to be non-negative:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Define a custom parameterization
class PositiveParametrize(torch.nn.Module):
    def forward(self, X):
        return torch.abs(X)

# Create a simple linear layer
linear = nn.Linear(10, 5)

# Apply the parameterization to the weight parameter of the linear layer
parametrize.register_parametrization(linear, "weight", PositiveParametrize())

# Now, linear.weight will always be non-negative
print(linear.weight)
```

In this example:
- We define a custom parameterization that ensures the weights are non-negative by applying the `torch.abs()` function.
- We register this parameterization with the `linear` layer, so the weight matrix is always transformed to be non-negative whenever it is accessed.

### Summary

`parametrize` in PyTorch provides a flexible way to apply constraints or transformations to module parameters. It is particularly useful for adding custom behavior to parameters while keeping the code clean and modular.

The code below applies a custom parametrization to the weights of the network's three linear layers using PyTorch's `parametrize` module. Each weight is replaced by a parametrized version that adds a Low-Rank Adaptation (LoRA) update to the original, frozen weight.

### Breakdown of the Code

1. **Import the Parametrize Module:**
   ```python
   import torch.nn.utils.parametrize as parametrize
   ```
   This line imports the `parametrize` module, which allows you to register parameterizations for a neural network's parameters.

2. **Define a Function for Linear Layer Parameterization:**
   ```python
   def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
       feature_in, feature_out = layer.weight.shape
       lora_layer = LoRANetwork(feature_in, feature_out, rank, lora_alpha, device)
       return lora_layer
   ```
   - The function `linear_layer_parameterization` takes in a layer (e.g., a linear layer), a device (CPU or GPU), and optional parameters `rank` and `lora_alpha`.
   - It unpacks `layer.weight.shape`. For `nn.Linear` this is `(out_features, in_features)`, so despite their names `feature_in` holds the layer's output size and `feature_out` its input size. `LoRANetwork` uses them as the row and column counts of the weight it wraps.
   - It then creates a `LoRANetwork` object (the custom LoRA parametrization defined above) with the specified rank and alpha values.
   - This function returns the `LoRANetwork` object that will wrap the original weight: the layer's effective weight becomes the original weight plus the scaled low-rank update.

3. **Register Parameterizations for Different Layers:**
   ```python
   parametrize.register_parametrization(cnet.fc1,
         "weight", linear_layer_parameterization(cnet.fc1, device))
   parametrize.register_parametrization(cnet.fc2,
         "weight", linear_layer_parameterization(cnet.fc2, device))
   parametrize.register_parametrization(cnet.fc3,
         "weight", linear_layer_parameterization(cnet.fc3, device))
   ```
   - These lines use the `parametrize.register_parametrization` method to replace the weights of layers `fc1`, `fc2`, and `fc3` of the neural network `cnet`.
   - The parameterization is applied to the "weight" parameter of each layer, using the `linear_layer_parameterization` function.
   - This effectively reparameterizes the original weight with the `LoRANetwork` transformation, allowing for efficient low-rank updates.

### What is LoRA?

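Low-Rank Adaptation (LoRA) is a technique for adapting a pre-trained model by freezing its weights $W$ and learning only a low-rank update $\Delta W = BA$, with $B \in \mathbb{R}^{d_\text{out} \times r}$, $A \in \mathbb{R}^{r \times d_\text{in}}$ and $r \ll \min(d_\text{in}, d_\text{out})$. The adapted weight is $W + \frac{\alpha}{r} BA$. Training only $A$ and $B$ drastically reduces the number of parameters that need to be updated, which makes fine-tuning more efficient, especially in large models.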

### Explanation of the Workflow

1. **The original weights are kept frozen and a scaled low-rank update is added to them** by the `LoRANetwork` parametrization.
2. **The `parametrize` module manages the parameterization**, allowing you to apply transformations while keeping the original model structure.
3. **The registered parameterizations ensure the modified weights are used** whenever the layers are involved in forward passes.

This approach can help fine-tune models with fewer parameters while preserving the structure and efficiency of the original model.

In [None]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
  """Create a `LoRANetwork` parametrization sized to `layer.weight`.

  `layer.weight.shape` is unpacked as (rows, cols) and passed positionally,
  together with `rank`, `lora_alpha` and `device`, to `LoRANetwork`.
  """
  rows, cols = layer.weight.shape
  return LoRANetwork(rows, cols, rank, lora_alpha, device)

# Attach a LoRA parametrization to the weight of each linear layer (fc1, fc2,
# fc3, in that order). Layers whose weight is already parametrized are skipped,
# so re-running this cell does not stack a second LoRA module on top of the
# first (the rest of the notebook only looks at parametrizations["weight"][0]).
for lora_target in (cnet.fc1, cnet.fc2, cnet.fc3):
  if not parametrize.is_parametrized(lora_target, "weight"):
    parametrize.register_parametrization(
        lora_target, "weight", linear_layer_parameterization(lora_target, device))

def enable_disable_lora(enabled=True):
    """Switch the LoRA update of cnet.fc1/fc2/fc3 on (True) or off (False).

    Sets `enabled` on the first parametrization registered for each layer's
    weight; while it is False, `LoRANetwork.forward` returns the original weight.
    """
    for linear in (cnet.fc1, cnet.fc2, cnet.fc3):
        lora_param = linear.parametrizations["weight"][0]
        lora_param.enabled = enabled

The function `enable_disable_lora` allows for toggling the use of the LoRA (Low-Rank Adaptation) parameterization on or off for the weights of the specified layers in the neural network (`cnet`). Here's how it works in detail:

### How the Function Works

1. **Definition of the Function:**
   ```python
   def enable_disable_lora(enabled=True):
       for layer in [cnet.fc1, cnet.fc2, cnet.fc3]:
           layer.parametrizations["weight"][0].enabled = enabled
   ```
   - The function takes a single parameter, `enabled`, which defaults to `True`. This parameter controls whether the LoRA parameterization should be active or not.
   - The function iterates over the linear layers `[cnet.fc1, cnet.fc2, cnet.fc3]` of the network `cnet`.

2. **Accessing the Parameterization:**
   - For each layer, the function accesses `layer.parametrizations["weight"]`. This is a list-like `ParametrizationList`; its first (and here only) element is the `LoRANetwork` object registered for that weight.
   - The function then sets the `enabled` attribute of the `LoRANetwork` object to the value of the `enabled` parameter passed to `enable_disable_lora`.

3. **How It Controls the LoRA Behavior:**
   - The `LoRANetwork` class has an `enabled` attribute that determines whether the parameterization is applied in the `forward` method.
   - If `self.enabled` is `True`, the LoRA adaptation `(self.lora_B @ self.lora_A).view(original_weights.shape) * self.scale` is added to the original weights.
   - If `self.enabled` is `False`, the original weights are returned without any modifications.

### What Happens When You Call the Function

- **`enable_disable_lora(True)`**: Enables the LoRA adaptation for the specified layers, so the scaled low-rank update is added to the weights during forward passes.
- **`enable_disable_lora(False)`**: Disables the LoRA adaptation, making the network use the original weights directly without applying the low-rank modification.

### Use Case

This function is useful for turning the LoRA adaptation on or off during different phases of training or evaluation. For example, you might want to:
- **Enable LoRA during fine-tuning** to use the low-rank adaptation for updating the model.
- **Disable LoRA during evaluation** to measure the performance of the model without the low-rank adjustments.

This approach adds flexibility to control whether the model uses the parameterized weights or the original weights based on the current requirement.

In [None]:
# Compare the parameters added by LoRA with those of the original layers.
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([cnet.fc1, cnet.fc2, cnet.fc3]):
    lora_module = layer.parametrizations["weight"][0]
    total_parameters_lora += lora_module.lora_A.nelement() + lora_module.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {lora_module.lora_A.shape} + Lora_B: {lora_module.lora_B.shape}'
    )
# The non-LoRA parameter count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA. Then fine-tune the model on the digit 8 only, for just 10 batches.

In [None]:
# Freeze the non-LoRA parameters
for name, param in cnet.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

FINETUNE_DIGIT = 8     # the only class kept for fine-tuning
FINETUNE_BATCHES = 10  # number of batches (of 10 images each) to fine-tune on

# Load the MNIST training set again from the same root as the initial load (so
# nothing is downloaded twice), keeping only images of FINETUNE_DIGIT.
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
keep_mask = mnist_trainset.targets == FINETUNE_DIGIT
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]
# A separate name, so the full-MNIST `train_loader` used for pre-training is not overwritten.
finetune_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train only the LoRA parameters, only on FINETUNE_DIGIT and only for
# FINETUNE_BATCHES batches (hoping it improves accuracy on that digit).
train(finetune_loader, cnet, epochs=1, total_iterations_limit=FINETUNE_BATCHES)

Freezing non-LoRA parameter fc1.bias
Freezing non-LoRA parameter fc1.parametrizations.weight.original
Freezing non-LoRA parameter fc2.bias
Freezing non-LoRA parameter fc2.parametrizations.weight.original
Freezing non-LoRA parameter fc3.bias
Freezing non-LoRA parameter fc3.parametrizations.weight.original


Epoch 1:  90%|█████████ | 9/10 [00:00<00:00, 22.10it/s, loss=0.0493]


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [None]:
# Check that every frozen parameter, weights *and* biases, is bit-identical to
# the snapshot taken before LoRA was attached. torch.equal also requires equal
# shapes, unlike torch.all(a == b), which would broadcast.
for layer_name in ("fc1", "fc2", "fc3"):
    frozen_layer = getattr(cnet, layer_name)
    assert torch.equal(frozen_layer.parametrizations.weight.original, original_weights[f'{layer_name}.weight']), f'{layer_name}.weight changed during fine-tuning'
    assert torch.equal(frozen_layer.bias, original_weights[f'{layer_name}.bias']), f'{layer_name}.bias changed during fine-tuning'

In [None]:
enable_disable_lora(enabled=True)
# With LoRA on, fc1.weight is produced by LoRANetwork.forward: the original
# tensor (moved to cnet.fc1.parametrizations.weight.original) plus the scaled
# low-rank update B @ A.
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
fc1_lora = cnet.fc1.parametrizations.weight[0]
fc1_expected = cnet.fc1.parametrizations.weight.original + (fc1_lora.lora_B @ fc1_lora.lora_A) * fc1_lora.scale
assert torch.equal(cnet.fc1.weight, fc1_expected)

enable_disable_lora(enabled=False)
# With LoRA off, fc1.weight is exactly the pre-LoRA snapshot.
fc1_snapshot = original_weights['fc1.weight']
assert torch.equal(cnet.fc1.weight, fc1_snapshot)

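Test the network with LoRA enabled (the digit 8 should be classified better). Because the LoRA update was trained only on images of 8, the other classes can get worse, so overall accuracy may drop even as the errors on digit 8 decrease.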

In [None]:
# Evaluate with the LoRA update applied to every linear layer.
enable_disable_lora(True)
infer()

Accuracy: 0.931
Wrong predictions for class 0: 19
Wrong predictions for class 1: 41
Wrong predictions for class 2: 54
Wrong predictions for class 3: 113
Wrong predictions for class 4: 23
Wrong predictions for class 5: 85
Wrong predictions for class 6: 36
Wrong predictions for class 7: 84
Wrong predictions for class 8: 16
Wrong predictions for class 9: 223


In [None]:
# Evaluate with LoRA switched off, i.e. with the original weights only.
enable_disable_lora(False)
infer()

Accuracy: 0.965
Wrong predictions for class 0: 17
Wrong predictions for class 1: 16
Wrong predictions for class 2: 30
Wrong predictions for class 3: 32
Wrong predictions for class 4: 15
Wrong predictions for class 5: 35
Wrong predictions for class 6: 29
Wrong predictions for class 7: 40
Wrong predictions for class 8: 39
Wrong predictions for class 9: 97
