In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LinearLoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=1.0, dropout=0.0, merge=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout
        self.merge = merge
        
        self.linear = nn.Linear(in_features, out_features)
        # shape of linear weight: (out_features, in_features)
        # input x shape: (batch_size, seq_len, in_features)
        # calculation: x @ weight.T
        # weight shape: (out_features, in_features)

        if rank > 0:
            self.lora_a = nn.Parameter(torch.randn(out_features, rank))
            # gausian distribution
            nn.init.kaiming_uniform_(self.lora_a, a = math.sqrt(5))

            self.lora_b = nn.Parameter(torch.randn(rank, in_features))
            self.scale = alpha / rank
        
        # Dropout layer
        self.dropout_layer = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
        
        if merge:
            self.merge_weight()

    def merge_weight(self, ):
        if self.merge and self.rank > 0:
            # (output_features, rank) @ (rank, in_features)
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)

    def unmerge_weight(self,):
        if self.merge and self.rank > 0:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)

    def forward(self, X):
        
        # X shape: (batch_size, seq_len, in_features)

        if self.rank > 0:
            output_part1 = self.linear(X)
            output_part2 = self.scale * (X @ (self.lora_a @ self.lora_b).T)
            output = output_part1 + output_part2
        else:
            output = self.linear(X)
        
        output = self.linear(X)

        return output


In [3]:
# Test the LoRALinear layer
batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
alpha = 16
dropout = 0.1

# Create a test input
x = torch.randn(batch_size, seq_len, in_features)

# Test regular mode (no merge)
lora_layer = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    alpha= alpha,
    dropout=dropout,
    merge=False
)

# Forward pass
output = lora_layer(x)
print(f"Output shape (no merge): {output.shape}")  # Should be [batch_size, seq_len, out_features]

# Test merged mode
lora_layer_merged = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    alpha=alpha,
    dropout=dropout,
    merge=True
)

# Forward pass with merged weights
output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}")  # Should be [batch_size, seq_len, out_features]

# Test weight merging/unmerging
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

print("Max difference after merge/unmerge cycle:", 
      torch.max(torch.abs(output - output_after_unmerge)).item())

Output shape (no merge): torch.Size([32, 128, 512])
Output shape (merged): torch.Size([32, 128, 512])
Max difference after merge/unmerge cycle: 0.0
