# LoRA Linear Layer

In [None]:
import torch
import torch.nn as nn
import math

class LoraLinear(nn.Module):
    """Linear layer with a frozen base weight plus a trainable low-rank (LoRA) update.

    The output is ``linear(x) + scale * lora_b(lora_a(x))`` with
    ``scale = alpha / r``. The base ``linear`` weight (and bias, if present)
    is frozen; only ``lora_a`` and ``lora_b`` receive gradients.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        r: Rank of the low-rank update; must be a positive integer.
        alpha: Scaling numerator; the LoRA branch is multiplied by ``alpha / r``.
        bias: Whether the base linear layer has a (frozen) bias term.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, in_dim, out_dim, r, alpha, bias=True):
        super().__init__()
        # Fail fast with a clear message. Without this check, r == 0 raises a
        # bare ZeroDivisionError below, and r < 0 yields a negative scale and
        # then an obscure allocation error inside nn.Linear.
        if r <= 0:
            raise ValueError(f"LoRA rank r must be a positive integer, got {r!r}")

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.r = r
        self.alpha = alpha
        self.scale = self.alpha / self.r

        # Base projection (frozen below) and the two low-rank factors:
        # lora_a maps in_dim -> r, lora_b maps r -> out_dim.
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.lora_a = nn.Linear(in_dim, r, bias=False)
        self.lora_b = nn.Linear(r, out_dim, bias=False)
        self._init_weights()

        # Freeze the original (pretrained) weights so only the LoRA factors train.
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False

    def _init_weights(self):
        """Initialise the LoRA factors.

        ``lora_a`` uses the same Kaiming-uniform scheme nn.Linear uses by
        default. ``lora_b`` is zeroed, so the LoRA branch contributes nothing
        at initialisation and the layer initially matches ``self.linear``.
        """
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x):
        """Return the frozen base projection plus the scaled low-rank update of ``x``.

        ``x`` is passed to nn.Linear modules, so its last dimension must be
        ``in_dim``; leading dimensions are preserved.
        """
        original_output = self.linear(x)
        lora_output = self.lora_b(self.lora_a(x)) * self.scale
        return original_output + lora_output