In [9]:
# LoRA implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoraLinear(nn.Module):
    """Linear layer whose weight can be augmented with a low-rank (LoRA) update.

    The layer wraps an ``nn.Linear(in_features, out_features)``. When
    ``rank > 0`` two extra trainable matrices are created:

    * ``lora_a`` with shape ``(out_features, rank)``
    * ``lora_b`` with shape ``(rank, in_features)``

    and the wrapped layer's weight is frozen (its bias is left trainable).
    When ``merge`` is true, the forward pass uses the effective weight
    ``linear.weight + scale * (lora_a @ lora_b)`` with
    ``scale = lora_alpha / rank``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        merge: If true (and ``rank > 0``), add the low-rank update to the
            base weight in ``forward``; otherwise only the base layer is used.
        rank: Rank of the low-rank update. ``rank <= 0`` disables LoRA and
            creates no extra parameters.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / rank``.
        dropout_rate: Dropout probability applied to the output of the
            merged branch. ``<= 0`` replaces dropout with ``nn.Identity``.
    """

    def __init__(self, in_features: int, out_features: int, merge: bool,
                 rank: int = 16, lora_alpha: float = 16, dropout_rate: float = 0.5):
        # Must run first: nn.Module refuses to register sub-modules or
        # parameters before its own __init__ has set up its internal state.
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank
        # Only meaningful when LoRA is enabled; guarding also avoids a
        # ZeroDivisionError for rank == 0.
        self.scale = lora_alpha / rank if rank > 0 else 0.0

        self.linear = nn.Linear(in_features, out_features)

        if rank > 0:
            self.lora_a = nn.Parameter(torch.zeros(out_features, rank))
            self.lora_b = nn.Parameter(torch.zeros(rank, in_features))
            # Freeze the base weight so that only the LoRA factors (and the
            # bias, which is left untouched here) receive gradients.
            self.linear.weight.requires_grad = False

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = nn.Identity()  # no-op

        self.initial_weights()

    def initial_weights(self) -> None:
        """Initialise the LoRA factors; does nothing when LoRA is disabled.

        ``lora_a`` gets a Kaiming-uniform init and ``lora_b`` is zeroed, so
        the product ``lora_a @ lora_b`` starts at zero and the layer initially
        behaves exactly like the wrapped ``nn.Linear``.
        """
        if self.rank <= 0:
            # No lora_a / lora_b were created in __init__.
            return
        # ``a`` is the negative-slope argument and must be a Python number;
        # 5 ** 0.5 == sqrt(5).
        nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
        nn.init.zeros_(self.lora_b)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``X``.

        With LoRA enabled and ``merge`` true, the scaled low-rank update is
        added to the base weight and dropout is applied to the result.
        Otherwise the plain wrapped linear layer is used (no dropout).
        """
        if self.rank > 0 and self.merge:
            # (out_features, rank) @ (rank, in_features) matches the shape of
            # linear.weight, i.e. (out_features, in_features).
            delta_weight = (self.lora_a @ self.lora_b) * self.scale
            output = F.linear(X, self.linear.weight + delta_weight, self.linear.bias)
            output = self.dropout(output)
        else:
            output = self.linear(X)
        return output
            
