In [None]:
#| default_exp adapters

# adapters

> Since the training method I am going to check with my technique (ReLoRA) uses low-rank adapters to make high-rank updates (alongside a specific quantization method), I will implement LoRA here.

## Idea

The basic idea of low-rank adapters is as follows:

$
Linear(W, x) = W x
$
where $W$ is an $(outputFeatures, inputFeatures)$ matrix and $x$ is an $(inputFeatures, batchSize)$ matrix.

So if $W_{changed} = W + \Delta W$:

$
Linear(W_{changed}, x) = W_{changed} x = (W + \Delta W) x = W x + \Delta W x
$

And if $\Delta W$ is a low-rank matrix, then we can represent it as

$
\Delta W = lora_B lora_A
$

where $lora_B$ is an $(outputFeatures, loraRank)$ matrix and $lora_A$ is a $(loraRank, inputFeatures)$ matrix.

So

$
Linear(W_{changed}, x) = W x + (lora_B lora_A) x
$

But since the $\Delta W$ matrix itself is relatively large and matrix multiplication is associative, we can avoid materializing it and compute the result as follows:

$
Linear(W_{changed}, x) = W x + lora_B (lora_A x)
$

## Implementation

In [None]:
#| export
import torch

In [None]:
#| export
class LinearAdapter(torch.nn.Module):
    """
    Interface for adapters attached to a linear layer.

    Every method here raises `NotImplementedError`; concrete adapters are
    expected to override all of them.
    """

    def reset(self) -> None:
        """
        Reset adapter (so it actually will not influence the output)
        """
        raise NotImplementedError()

    def calculate_weight_update(self) -> torch.Tensor:
        # Raw docstring: `\D` is an invalid escape sequence in a normal string
        # literal and triggers a SyntaxWarning on recent Python versions.
        r"""
        Calculate $\Delta W$ matrix for the current adapter state
        """
        raise NotImplementedError()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        raise NotImplementedError()

In [None]:
#| export
class LoRAAdapter(LinearAdapter):
    """
    Low-rank adapter whose weight update is the product `lora_b @ lora_a`.

    `lora_a` has shape `(lora_rank, in_features)` and `lora_b` has shape
    `(out_features, lora_rank)`.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 lora_rank: int,
                 device=None):
        """
        Create both low-rank factors and initialise them via `reset()`.

        Args:
            in_features: size of the last dimension of the adapter input.
            out_features: size of the last dimension of the adapter output.
            lora_rank: inner dimension shared by `lora_a` and `lora_b`.
            device: device passed to `torch.zeros` when allocating the factors.
        """
        super().__init__()
        self.lora_a = torch.nn.Parameter(torch.zeros(lora_rank, in_features, device=device))
        self.lora_b = torch.nn.Parameter(torch.zeros(out_features, lora_rank, device=device))
        self.reset()

    def reset(self) -> None:
        """
        Re-initialise the adapter so that it does not influence the output.

        `lora_a` is drawn from a Xavier-uniform distribution and `lora_b` is
        zeroed, so `lora_b @ lora_a` is the zero matrix. Both parameters are
        marked as requiring gradients afterwards.
        """
        # torch.nn.init functions run under no_grad, so the parameters can be
        # passed directly instead of going through `.data`.
        torch.nn.init.xavier_uniform_(self.lora_a)
        torch.nn.init.zeros_(self.lora_b)
        # Set the flag on the parameters themselves: `.data` returns a fresh
        # detached tensor, so assigning `requires_grad` on it has no effect
        # on the parameter.
        self.lora_a.requires_grad_(True)
        self.lora_b.requires_grad_(True)

    def calculate_weight_update(self) -> torch.Tensor:
        """
        Return the `(out_features, in_features)` matrix `lora_b @ lora_a`.
        """
        return self.lora_b.matmul(self.lora_a)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the adapter to `input` without materialising the full update.

        `input` is first projected by `lora_a` (last dimension becomes
        `lora_rank`) and then by `lora_b` (last dimension becomes
        `out_features`).
        """
        return torch.nn.functional.linear(
            torch.nn.functional.linear(input, self.lora_a),
            self.lora_b
        )

In [None]:
#| export
class MergeableLayer(torch.nn.Module):
    """
    Base class for a layer that carries a `LinearAdapter`.

    Subclasses must provide `merge_adapter`.
    """

    def __init__(self, adapter: LinearAdapter) -> None:
        """
        Keep `adapter` on the layer as `self.adapter`.
        """
        super().__init__()
        self.adapter = adapter

    def merge_adapter(self) -> None:
        """
        Merge the adapter into the layer; not implemented in the base class.
        """
        raise NotImplementedError()

In [None]:
#| hide
# Run the nbdev export step for this notebook's project.
import nbdev
nbdev.nbdev_export()