In [6]:
import torch
import torch.nn as nn
import math

In [7]:
class LoraLayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha, dropout):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.rank = rank

        self.dropout = nn.Dropout(dropout)
        self.A = nn.Linear(in_dim, rank, bias=False)
        self.B = nn.Linear(rank, out_dim, bias=False)
        self.dropout_layer = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x):
        x = self.A(x)
        x = self.dropout_layer(x)
        x = self.B(x)
        return x * self.scaling


class LoraLinear(nn.Module):
    def __init__(self, linear_layer, rank, alpha, dropout):
        super().__init__()
        self.linear_layer = linear_layer
        self.linear_layer.weight.requires_grad = False

        self.lora_layer = LoraLayer(linear_layer.in_features, linear_layer.out_features, rank, alpha, dropout)

    def forward(self, x):
        return self.linear_layer(x) + self.lora_layer(x)


def replace_linear_with_lora(model, rank, alpha, dropout):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            lora_layer = LoraLinear(module, rank, alpha, dropout)
            setattr(model, name, lora_layer)

        else:
            replace_linear_with_lora(module, rank, alpha, dropout)
    return model
        
        
        
        
        

In [8]:
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)


model = TestModel()
print("original model: ", model)
replace_linear_with_lora(model, 10, 1.0, 0.0)
print("lora model: ", model)


for name, param in model.named_parameters():
    print(name, param.requires_grad)




original model:  TestModel(
  (linear1): Linear(in_features=10, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
)
lora model:  TestModel(
  (linear1): LoraLinear(
    (linear_layer): Linear(in_features=10, out_features=10, bias=True)
    (lora_layer): LoraLayer(
      (dropout): Dropout(p=0.0, inplace=False)
      (A): Linear(in_features=10, out_features=10, bias=False)
      (B): Linear(in_features=10, out_features=10, bias=False)
      (dropout_layer): Dropout(p=0.0, inplace=False)
    )
  )
  (linear2): LoraLinear(
    (linear_layer): Linear(in_features=10, out_features=10, bias=True)
    (lora_layer): LoraLayer(
      (dropout): Dropout(p=0.0, inplace=False)
      (A): Linear(in_features=10, out_features=10, bias=False)
      (B): Linear(in_features=10, out_features=10, bias=False)
      (dropout_layer): Dropout(p=0.0, inplace=False)
    )
  )
)
linear1.linear_layer.weight False
linear1.linear_layer.bias True
linear1.lora_layer.A.weight T