# 背景
在大语言模型（LLM）抑或是图文生成等方面，都需要大量的GPU显存来支撑。这背景下就诞生出了各种参数高效（Parameter-Efficient）的方法。  
最受欢迎的就是**LoRA：《LoRA: Low-Rank Adaptation of Large Language Models》**。

## 原理
对于一个大矩阵$W_0 \in \mathbb{R}^{n \times m}$对他进行低秩分解得到两个小矩阵$A \in \mathbb{R}^{n \times r}, B \in \mathbb{R}^{r \times m}$
在训练的过程中不改变$W_0$的参数，而去改变$A, B$的参数：
$$W_{new} = W_0 + AB$$
在训练计算的时候：
$$h = W_0x + ABx = (W_0 + AB)x $$
一般来说，AB会使用缩放因子$\frac{\alpha}{\gamma}$进行缩放，我们最终写为：
$$ h = W_0x + ABx = (W_0 + \frac{\alpha}{\gamma}AB)x $$
其中$r << \min{(n,m)}$，甚至可以设置成1.
![LoRA示意图](./LoRA.png)

In [1]:
import torch
import torch
from torch import nn
from torch.nn import functional as F
import math

In [2]:
class LinearLoRALayer(nn.Module):
    def __init__(self, 
                in_features, 
                out_features,
                rank, 
                lora_alpha,
                dropout, 
                merge = False
                ):
        super ().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank= rank
        self.lora_alpha = lora_alpha
        self.dropout = dropout
        self.merge = merge
        
        
        self.linear = nn.Linear(in_features, out_features, bias=True)
        # linear shape is (out_features, in_features)
        # input x shape is (batch_size, seq_size, in_features)
        # 计算过程是 x @ W^T (细节实现)，所以weight shape is (out_features, in_features)
        
        if rank > 0:
            self.lora_a = nn.Parameter(torch.randn(out_features, rank)) # 高斯分布
            nn.init.kaiming_uniform_(self.lora_a, a=0.01)
            
            self.lora_b = nn.Parameter(torch.randn(rank, in_features))
            self.scale = self.lora_alpha / rank
            
            # linear需要设置为不可训练
            self.linear.weight.requires_grad = False
            self.linear.bias.requires_gradu = False
            
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()


        # merge 是bool类型，如果true则将lora的权重和linear的权重合并
        # 那么会把 lora_a 和 lora_b 两个小矩阵的参数直接放到 linear.weight 中
        if merge:
            self.merge_weight()

    def merge_weight(self, ):
        if self.merge and self.rank > 0:
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)
    
    def unmerge_weight(self, ):
        if self.rank > 0:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)
                
    def forward(self, x):
        if rank > 0:
            output_part1 = self.linear(x)
            output_part2 = self.scale*(x @ (self.lora_a @ self.lora_b).T)
            output = output_part1 + output_part2
        else: 
            output = self.linear(x)
            
        output = self.dropout(output)
        return output
            

In [3]:
# 测试代码
batch_size = 32
seq_len = 120
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1

# Create a test input
x = torch.randn(batch_size, seq_len, in_features)
x.shape

torch.Size([32, 120, 768])

In [4]:
# no-merge
lora_layer = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=False
)

output = lora_layer(x)
print(f"Output shape(unmerge): {output.shape}") #(b, s, out_features)

Output shape(unmerge): torch.Size([32, 120, 512])


In [5]:
# Test merged mode
lora_layer_merged = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=True
)

# Forward pass with merged weights
output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}")

Output shape (merged): torch.Size([32, 120, 512])


In [6]:
# Test weight merging/unmerging
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

print("Max difference after merge/unmerge cycle:", 
      torch.max(torch.abs(output - output_after_unmerge)).item())

Max difference after merge/unmerge cycle: 614.7224731445312
