LoRA对线性层、卷积层进行微调，通过低秩矩阵进行分解，要注意初始化方式。
A矩阵使用随机高斯初始化，B矩阵使用全零初始化，主要是为了保证在训练的初始阶段LoRA分支的输出为零，SD的原始权重能够完全生效。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearLoRA(nn.Module):
    """Low-rank adapter (LoRA) branch for a linear layer.

    The update is factorised into two bias-free projections:
    ``down`` (matrix A, ``in_features -> rank``) followed by ``up``
    (matrix B, ``rank -> out_features``).

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation; must be >= 1.
        alpha: Optional scaling numerator. When not ``None`` the output is
            multiplied by ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=None):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.rank = rank
        self.alpha = alpha

        # Both projections are bias-free. With the default ``bias=True`` the
        # randomly initialised bias of ``up`` would make the branch output
        # non-zero at initialisation, defeating the zero-init of B below.
        self.down = nn.Linear(in_features, rank, bias=False)  # matrix A
        self.up = nn.Linear(rank, out_features, bias=False)   # matrix B

        # A: in-place Gaussian init (trailing underscore = in-place) with
        # std = 1 / rank, so the spread of the weights shrinks as rank grows.
        nn.init.normal_(self.down.weight, std=1 / rank)
        # B: all zeros, so the branch contributes exactly nothing before
        # training and the base model's weights act alone.
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the LoRA delta ``up(down(hidden_states))``, optionally scaled."""
        lora_delta = self.up(self.down(hidden_states))
        if self.alpha is not None:
            # Out-of-place multiply: avoids mutating a tensor that autograd
            # is tracking.
            lora_delta = lora_delta * (self.alpha / self.rank)
        return lora_delta

# Blend weight of the LoRA branch: how strongly it influences the result.
scale = 0.1
ori_hidden_states = torch.randn((1, 196, 768))
lora_linear = LinearLoRA(in_features=768, out_features=768)
lora_hidden_states = lora_linear(ori_hidden_states)
# Fuse: add the scaled LoRA delta onto the original hidden states.
output = lora_hidden_states * scale + ori_hidden_states
print(output.size())

torch.Size([1, 196, 768])


In [1]:
import torch
import torch.nn as nn

class ConvLoRA(nn.Module):
    """Low-rank adapter (LoRA) branch for a 2-D convolution layer.

    The update is factorised into two bias-free convolutions: ``down``
    (``in_features -> rank`` channels, using the given kernel, stride and
    padding) followed by a 1x1 ``up`` convolution
    (``rank -> out_features`` channels).

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        rank: Number of channels in the low-rank bottleneck; must be >= 1.
        alpha: Optional scaling numerator. When not ``None`` the output is
            multiplied by ``alpha / rank``.
        kernel_size: Kernel size of the ``down`` convolution.
        stride: Stride of the ``down`` convolution.
        padding: Padding of the ``down`` convolution.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, in_features, out_features, rank=4, alpha=None,
                 kernel_size=(1, 1), stride=(1, 1), padding=0):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.rank = rank
        self.alpha = alpha

        # Both convolutions are bias-free. With the default ``bias=True`` the
        # randomly initialised bias of ``up`` would make the branch output
        # non-zero at initialisation, defeating the zero-init below.
        self.down = nn.Conv2d(in_features, rank, kernel_size, stride, padding,
                              bias=False)
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1),
                            stride=(1, 1), bias=False)

        # down: Gaussian init with std = 1 / rank; up: all zeros, so the
        # branch contributes exactly nothing before training.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the LoRA delta ``up(down(hidden_states))``, optionally scaled."""
        lora_delta = self.up(self.down(hidden_states))
        if self.alpha is not None:
            # Out-of-place multiply: avoids mutating a tensor that autograd
            # is tracking.
            lora_delta = lora_delta * (self.alpha / self.rank)
        return lora_delta
    
# Hyper-parameters of the demo: bottleneck rank, scaling numerator, and the
# blend weight applied to the LoRA branch.
rank = 4
alpha = 4
scale = 0.1
dummy_input = torch.randn((1, 4, 64, 64))
conv_lora = ConvLoRA(in_features=4, out_features=4, rank=rank, alpha=alpha)
hidden_states = conv_lora(dummy_input)
# Fuse: add the scaled LoRA delta onto the original input.
output = hidden_states * scale + dummy_input
print(output.size())

torch.Size([1, 4, 64, 64])
