Implementation of the `LoRA: Low-Rank Adaptation of Large Language Models` paper in PyTorch.

As stated in the paper, "LoRA allows us to train some dense layers in a neural network indirectly by optimizing rank decomposition matrices of the dense layers' change during adaptation instead, while keeping the pre-trained weights frozen", as shown in the illustration below.

<div style="text-align: center;"><img src="./public/lora.png"></img></div>

LoRA has several advantages, summarised from the paper:
* A pre-trained model can be shared and used to build many small LoRA modules for different tasks. We can freeze the shared model and efficiently switch tasks by replacing the matrices A and B in the Figure above, reducing the storage requirement and task-switching overhead significantly.
* LoRA makes training more efficient and lowers the hardware barrier to entry by up to 3 times when using adaptive optimizers since we do not need to calculate the gradients or maintain the optimizer states for most parameters. Instead, we only optimize the injected, much smaller low-rank matrices.
* The simple linear design allows us to merge the trainable matrices with the frozen weights when deployed, introducing no inference latency compared to a fully fine-tuned model, by construction.

In summary, LoRA freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture. This makes it possible to fine-tune large language models efficiently by greatly reducing the number of trainable parameters (the paper reports a 10,000 times reduction for GPT-3 175B compared with full fine-tuning using Adam).

In [1]:
import torch
import torch.nn as nn

#### LoRA Theory

A LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix $W_{0} \in \Bbb R^{d \times k}$ of a linear layer: the original weights are frozen, and trainable low-rank update matrices are injected alongside them.

| Symbol      | Meaning | Description |
| ----- | ---------------------- | -------------------------------------------- |
| $k$      | Input dimension          | The number of features going into the layer.      |
| $d$      | Output dimension        | The number of features coming out of the layer.  |
| $r$      | LoRA rank (compression dim) | The intrinsic dimension used to approximate the weight update. It sets the shapes of $A$ and $B$, and the update $\Delta W = BA$ has rank at most $r$.  |
| $A$      | Down projection        | Shape $r \times k$; projects the input from $k$ down to $r$ dimensions.  |
| $B$      | Up projection        | Shape $d \times r$; projects back up from $r$ to $d$ dimensions. |
| $\alpha$      | Scaling constant        | The low-rank update is scaled by $\frac{\alpha}{r}$. |
| $x$      | Input vector        | A vector in $\Bbb R^{k}$; in practice the layer receives a batch of shape $(n, k)$. |
| $BAx$      | Low-rank update        | The update added to the layer's output, in $\Bbb R^{d}$. |
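A quick shape check of the table, assuming arbitrary sizes $n=8$, $k=10$, $d=6$, $r=2$ chosen only for illustration. Because PyTorch stores a batch as rows, $Ax$ and $B(Ax)$ become multiplications by the transposes, `x @ A.T` and then `@ B.T`.

```python
import torch

# Illustrative sizes (assumptions, not from the paper)
n, k, d, r = 8, 10, 6, 2

A = torch.randn(r, k)   # down projection: r x k
B = torch.randn(d, r)   # up projection:   d x r
x = torch.randn(n, k)   # batch of n input vectors, each in R^k

Ax = x @ A.T            # each row is A x, shape (n, r)
BAx = Ax @ B.T          # each row is B(Ax), shape (n, d)

print(A.shape, B.shape, (B @ A).shape)  # torch.Size([2, 10]) torch.Size([6, 2]) torch.Size([6, 10])
print(Ax.shape, BAx.shape)              # torch.Size([8, 2]) torch.Size([8, 6])
```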


If a layer's weight is $W_{0} \in \Bbb R^{d \times k}$, LoRA constrains its update to a low-rank product, so the adapted weight is $W_{0} + \Delta W = W_{0} + BA$, where:
* $B \in \Bbb R^{d \times r}$
* $A \in \Bbb R^{r \times k}$
* the shared inner dimension is much smaller than either outer one: $r \ll \min(d, k)$.
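Why $BA$ is called low-rank can be checked numerically: the product of a $d \times r$ and an $r \times k$ matrix has rank at most $r$, even though it is a full $d \times k$ matrix. The sizes below are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
d, k, r = 64, 48, 4  # illustrative sizes with r << min(d, k)

B = torch.randn(d, r)
A = torch.randn(r, k)
delta_W = B @ A  # d x k, but built from only r * (d + k) numbers

# The rank of a (d x r) @ (r x k) product can never exceed r.
rank = torch.linalg.matrix_rank(delta_W).item()
print(delta_W.shape, rank)
```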

During training, $W_{0}$ stays frozen and only $A$ and $B$ are learned. The forward pass through the adapted layer is then $h = W_{0}x + BAx$, with the update scaled by a factor $\frac{\alpha}{r}$, i.e. $h = W_{0}x + \frac{\alpha}{r}BAx$; according to the paper, this scaling reduces the need to retune hyperparameters when $r$ changes. This adds a small "change" matrix $BA$ to the base layer's output without modifying $W_{0}$.
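A minimal sketch of the forward formula $h = W_{0}x + \frac{\alpha}{r}BAx$ for a single input vector. The sizes and the value of `alpha` are illustrative assumptions. With $B$ initialized to zero (as the paper prescribes), the adapted layer reproduces the pre-trained layer exactly at the start of training.

```python
import torch

torch.manual_seed(0)
d, k, r, alpha = 6, 10, 2, 4  # illustrative sizes; alpha is a hypothetical choice
scaling = alpha / r

W0 = torch.randn(d, k)    # frozen pre-trained weight
A = torch.randn(r, k)     # random init
B = torch.zeros(d, r)     # zero init, so BA = 0 at the start
x = torch.randn(k)

def lora_forward(x):
    # h = W0 x + (alpha / r) * B(Ax)
    return W0 @ x + scaling * (B @ (A @ x))

h = lora_forward(x)
# With B = 0 the adapted layer matches the pre-trained layer.
print(torch.allclose(h, W0 @ x))  # True
```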

A standard dense `Linear` layer with $W_{0} \in \Bbb R^{d \times k}$ and input $x \in \Bbb R^{k}$ computes $W_{0}x + b$. In general this weight matrix has full rank ($\min(d, k)$); LoRA's hypothesis is that the change to it during adaptation has a low "intrinsic rank". So instead of fine-tuning $W_{0}$, it adds two smaller matrices, $B$ and $A$, whose product $BA$ has rank at most $r$. The input $x$ flows through them in sequence (first $Ax \in \Bbb R^{r}$, then $B(Ax) \in \Bbb R^{d}$), and the result is summed with $W_{0}x$.
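Computing $B(Ax)$ gives the same result as forming $BA$ first, but it never materializes a $d \times k$ matrix. The sketch below checks the equivalence in float64 and counts multiply-adds for both orderings; the sizes are illustrative assumptions, and the counts ignore lower-order terms.

```python
import torch

torch.manual_seed(0)
n, d, k, r = 8, 512, 512, 4  # illustrative sizes
x = torch.randn(n, k, dtype=torch.float64)
A = torch.randn(r, k, dtype=torch.float64)
B = torch.randn(d, r, dtype=torch.float64)

two_step = (x @ A.T) @ B.T   # first Ax (n x r), then B(Ax) (n x d)
dense = x @ (B @ A).T        # build the d x k matrix BA first

print(torch.allclose(two_step, dense))  # True

# Multiply-adds for each ordering
two_step_macs = n * k * r + n * r * d   # x A^T, then (x A^T) B^T
dense_macs = d * r * k + n * k * d      # B A, then x (BA)^T
print(two_step_macs, dense_macs)        # 32768 3145728
```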

By choosing a small $r$ (even 1 to 4 for large layers), the number of trainable parameters for the layer drops dramatically, from $dk$ to $r(d + k)$, yet $W_{0} + BA$ still has dimension $d \times k$ and affects the whole layer's output.
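The arithmetic behind that drop, as a small helper. The name `lora_param_counts` and the $4096 \times 4096$, $r = 8$ example are illustrative choices, not taken from the paper; the count covers one weight matrix and ignores the bias.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (full fine-tuning params, LoRA params) for one d x k weight."""
    full = d * k          # every entry of W0 would be trained
    lora = r * (d + k)    # B is d x r, A is r x k
    return full, lora

# Illustrative sizes: a 4096 x 4096 projection with r = 8
full, lora = lora_param_counts(4096, 4096, 8)
print(full, lora, full // lora)  # 16777216 65536 256
```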

#### LoRA Linear Layer Implementation

In [115]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool, r: int, alpha: int = None):
        super().__init__()

        # Notes:
        # x shape         : (batch_size, k)
        # A.T shape       : (k, r)
        # x @ A.T         : (batch_size, r)     == Ax
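        # B.T shape       : (r, d)
        # (Ax) @ B.T      : (batch_size, d)     == B(Ax)

        # If alpha is not provided, set alpha = r so the scaling factor alpha / r is 1.
        # The paper sets alpha to the first r it tries and does not tune it further,
        # since with Adam, tuning alpha is roughly the same as tuning the learning rate.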
        if alpha is None:
            alpha = r

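        # Frozen pre-trained weight W0 (d x k). torch.empty does not initialise it,
        # so the pre-trained values must be copied in before the layer is used.
        self.weight = nn.Parameter(torch.empty(size=(out_features, in_features)))
        self.weight.requires_grad = False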

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features)) # or torch.empty((out_features,))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # scaling delta W by alpha/r as in the paper
        self.scaling = alpha / r

        self.lora_a = nn.Parameter(torch.empty(size=(r, in_features)))
        self.lora_b = nn.Parameter(torch.empty(size=(out_features, r)))

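        # From the paper:
        # "We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training"
        # A Kaiming-uniform (also random) initialization is used for A here instead; what matters
        # is that B starts at zero, so the layer initially behaves exactly like the pre-trained one.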
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

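    def forward(self, x: torch.Tensor):
        output = nn.functional.linear(x, self.weight, bias=self.bias)  # W0 x + b
        # Low-rank path: (alpha / r) * B(Ax), computed as x @ A.T @ B.T for a batch of row vectors
        output = output + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

        return output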
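The third advantage listed at the top, merging for deployment, can be sketched with tensors named after the class attributes (`weight`, `bias`, `lora_a`, `lora_b`, `scaling`). `merge_lora` is a hypothetical helper, not part of the class above: it folds $\frac{\alpha}{r}BA$ into $W_{0}$ so that inference needs a single matrix multiply. The sizes and `alpha` are illustrative, and `lora_b` is random here to stand in for trained values.

```python
import torch
import torch.nn.functional as F

def merge_lora(weight, lora_a, lora_b, scaling):
    """Hypothetical helper: return W0 + scaling * (B @ A) as a new tensor."""
    return weight + scaling * (lora_b @ lora_a)

torch.manual_seed(0)
d, k, r, alpha = 10, 10, 4, 8   # illustrative sizes
weight = torch.randn(d, k)      # frozen W0
bias = torch.randn(d)
lora_a = torch.randn(r, k)
lora_b = torch.randn(d, r)      # nonzero, as it would be after training
scaling = alpha / r

x = torch.randn(8, k)
# Two-path forward, mirroring Linear.forward above
unmerged = F.linear(x, weight, bias) + (x @ lora_a.T @ lora_b.T) * scaling
# Single-path forward with the merged weight
merged = F.linear(x, merge_lora(weight, lora_a, lora_b, scaling), bias)

print(torch.allclose(unmerged, merged, atol=1e-5))  # True
```

Switching tasks then amounts to subtracting one task's `scaling * (lora_b @ lora_a)` and adding another's.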

In [120]:
layer = Linear(
    in_features = 10, # k
    out_features = 10, # d
    bias = True, 
    r = 4,
    alpha = None
)

x = torch.ones((8, 10)) # n, k
y = layer(x)

In [123]:
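print(y.shape)
# Compare against (n, d) explicitly; x.shape == y.shape only holds here because k == d
assert y.shape == (x.shape[0], layer.weight.shape[0]), "Output shape should be (n, d)"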

torch.Size([8, 10])


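Only `lora_a` and `lora_b` should be trainable; `weight` and `bias` are frozen (`requires_grad=False`). Since `weight` was created with `torch.empty` and never filled, the values printed below are arbitrary uninitialised memory rather than pre-trained weights.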
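To see what freezing means during an update, here is a sketch with plain tensors (independent of the `Linear` class so that it runs on its own), taking one SGD step on an illustrative squared-error loss. The frozen $W_{0}$ never changes. A less obvious consequence of initializing $B$ to zero: on the very first step the gradient with respect to $A$ is also zero, because $\partial L / \partial A = \frac{\alpha}{r} B^\top G^\top x$ contains a factor $B^\top$ (where $G = \partial L / \partial h$), so only $B$ moves.

```python
import torch

torch.manual_seed(0)
n, d, k, r = 8, 6, 10, 2          # illustrative sizes
scaling = 1.0                     # alpha = r

W0 = torch.randn(d, k)            # frozen: requires_grad is False by default
A = torch.nn.Parameter(torch.randn(r, k))
B = torch.nn.Parameter(torch.zeros(d, r))
optimizer = torch.optim.SGD([A, B], lr=0.1)  # only the LoRA factors are optimised

x = torch.randn(n, k)
target = torch.randn(n, d)        # arbitrary regression target for illustration

W0_before = W0.clone()
A_before = A.detach().clone()
B_before = B.detach().clone()

h = x @ W0.T + scaling * (x @ A.T @ B.T)
loss = ((h - target) ** 2).mean()
loss.backward()
optimizer.step()

print(torch.equal(W0, W0_before))  # True: W0 is frozen
print(torch.equal(A, A_before))    # True: dL/dA is zero while B == 0
print(torch.equal(B, B_before))    # False: B received a gradient
```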

In [122]:
list(layer.named_parameters())

[('weight',
  Parameter containing:
  tensor([[-4.3908e+06,  4.5783e-41, -4.3908e+06,  4.5783e-41, -2.2521e-16,
            4.5782e-41, -2.2518e-16,  4.5782e-41, -2.2526e-16,  4.5782e-41],
          [-2.2512e-16,  4.5782e-41, -2.2381e-16,  4.5782e-41, -2.2346e-16,
            4.5782e-41, -3.7146e-19,  4.5782e-41, -2.3118e-16,  4.5782e-41],
          [-2.2222e-16,  4.5782e-41, -2.2413e-16,  4.5782e-41, -3.7177e-19,
            4.5782e-41, -2.0052e-16,  4.5782e-41, -2.2522e-16,  4.5782e-41],
          [-2.2514e-16,  4.5782e-41, -2.2540e-16,  4.5782e-41, -2.2527e-16,
            4.5782e-41, -2.2529e-16,  4.5782e-41, -2.2538e-16,  4.5782e-41],
          [-2.2549e-16,  4.5782e-41, -2.2544e-16,  4.5782e-41, -2.1103e-16,
            4.5782e-41, -2.3273e-16,  4.5782e-41, -2.3274e-16,  4.5782e-41],
          [-2.3273e-16,  4.5782e-41, -2.2394e-16,  4.5782e-41, -2.2395e-16,
            4.5782e-41, -2.2393e-16,  4.5782e-41, -2.2338e-16,  4.5782e-41],
          [ 3.5873e-43,  0.0000e+00,  3.8115e-