In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class LinearLoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank, lora_alpha, dropout, merge=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.dropout = dropout
        self.merge = merge

        self.linear = nn.Linear(in_features, out_features)
        # linear： weight 的 Shape 是： (out_features, in_features)
        # input x shape 是 (batch_size, seq_len, in_features)
        # 计算过程是： x @ linear.weight.T
        # 所以 weight 的 shape 是 (out_features, in_features)

        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(out_features, rank))
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))

            self.lora_b = nn.Parameter(torch.randn(rank, in_features))
            self.scale = lora_alpha / rank
        
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # merge 是 bool 类型， 如果为 True， 则将 lora 的权重和 linear 的权重合并
        # 先空着

        # 训练时，冻结linear的权重，只训练lora的权重
        self.linear.weight.requires_grad = False
        
        if merge:
            self.merge_weight()

    def merge_weight(self, ):
        if self.merge and self.rank > 0:
            # (output_features, rank) @ (rank, in_features) = (output_features, in_features)
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)
        
    def unmerge_weight(self, ):
        if self.merge and self.rank > 0:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)

    def forward(self, X):
        # X shape 是 (batch_size, seq_len, in_features)

        if self.rank > 0:
            output_part1 = self.linear(X)
            output_part2 = self.scale * (X @ (self.lora_a @ self.lora_b).T)
            output = output_part1 + output_part2
        else:
            output = self.linear(X)
        output = self.dropout(output)
        return output

In [20]:
# Test the LoRALinear layer
batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1

# Create a test input
x = torch.randn(batch_size, seq_len, in_features)

# Test regular mode (no merge)
lora_layer = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=False
)

# Forward pass
output = lora_layer(x)
print(f"Output shape (no merge): {output.shape}")  # Should be [batch_size, seq_len, out_features]

# Test merged mode
lora_layer_merged = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=True
)

# Forward pass with merged weights
output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}")  # Should be [batch_size, seq_len, out_features]

# Test weight merging/unmerging
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

print("Max difference after merge/unmerge cycle:", 
      torch.max(torch.abs(output - output_after_unmerge)).item())

Output shape (no merge): torch.Size([32, 128, 512])
Output shape (merged): torch.Size([32, 128, 512])
Max difference after merge/unmerge cycle: 201.7412872314453
