In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [26]:
# Seed PyTorch's global RNG so the random tensors and layer initialisations
# created in later cells are repeatable from a fresh kernel.
torch.manual_seed(-1)

<torch._C.Generator at 0x11c9565b0>

In [27]:
class LoRALayer(nn.Module):
    """Low-rank adapter computing ``alpha * (x @ A @ B)``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from a normal distribution
    scaled by ``1 / sqrt(rank)``; ``B`` has shape ``(rank, out_dim)`` and starts
    as all zeros, so a freshly built adapter outputs zeros.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        rank: inner dimension shared by ``A`` and ``B``.
        alpha: scalar multiplier applied to the low-rank product.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # Keep the scale as a float32 tensor computation: 1 / sqrt(rank).
        root_rank = torch.tensor(rank).float().sqrt()
        std_dev = 1 / root_rank
        down_init = torch.randn(in_dim, rank) * std_dev
        up_init = torch.zeros(rank, out_dim)
        # Registration order (A, then B) determines parameter ordering.
        self.A = nn.Parameter(down_init)
        self.B = nn.Parameter(up_init)
        self.alpha = alpha

    def forward(self, x):
        """Project ``x`` through ``A`` then ``B`` and scale the result by ``alpha``."""
        projected = x @ self.A
        low_rank_out = projected @ self.B
        return self.alpha * low_rank_out

In [28]:
class LinearWithLora(nn.Module):
    """Wrap a linear layer and add a LoRA branch to its output.

    The wrapped ``linear`` module is stored by reference (not copied), and a
    new ``LoRALayer`` is sized from its ``in_features`` / ``out_features``.

    Args:
        linear: the layer to wrap; must expose ``in_features`` and
            ``out_features``.
        rank: rank passed to the LoRA branch.
        alpha: scaling factor passed to the LoRA branch.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
        )

    def forward(self, x):
        """Return the wrapped layer's output plus the LoRA branch's output."""
        base_out = self.linear(x)
        lora_out = self.lora(x)
        return base_out + lora_out

In [29]:
# Base linear layer (10 inputs -> 2 outputs) that the LoRA wrappers below are built around.
layer = nn.Linear(10, 2)

In [30]:
# One random input row with 10 features, matching the in_features of `layer`.
x = torch.randn(1, 10)

In [31]:
# Reference output of the plain linear layer, for comparison with the LoRA variants.
layer(x)

tensor([[-1.5018, -1.0002]], grad_fn=<AddmmBackward0>)

In [32]:
# Wrap `layer` with a rank-2 LoRA branch scaled by alpha=4. The wrapper keeps a
# reference to `layer` itself rather than a copy.
lora_layer = LinearWithLora(layer, rank=2, alpha=4)

In [33]:
# LoRALayer initialises B to zeros, so the LoRA branch adds nothing yet and the
# result matches layer(x).
lora_layer(x)

tensor([[-1.5018, -1.0002]], grad_fn=<AddBackward0>)

In [36]:
class LinearWithLoraMerged(nn.Module):
    """Linear layer with LoRA whose update is folded into the weight matrix.

    Instead of evaluating the LoRA branch separately, ``forward`` builds
    ``weight + alpha * (A @ B).T`` and performs a single ``F.linear`` call
    with the wrapped layer's bias.

    Args:
        linear: the layer to wrap; must expose ``in_features``,
            ``out_features``, ``weight`` and ``bias``.
        rank: rank passed to the LoRA branch.
        alpha: scaling factor passed to the LoRA branch.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
        )

    def forward(self, x):
        """Apply the linear map using the merged (base + scaled LoRA) weight."""
        adapter = self.lora
        # A @ B is (in_features, out_features); the weight is stored
        # transposed relative to that, hence the .T below.
        low_rank = adapter.A @ adapter.B
        scaled_delta = adapter.alpha * low_rank.T
        combined_weight = self.linear.weight + scaled_delta
        return F.linear(x, combined_weight, self.linear.bias)

In [37]:
# Same base `layer`, rank 2 and alpha 4 as above, but using the merged-weight
# formulation. A new LoRALayer is created inside the wrapper.
merged_lora = LinearWithLoraMerged(layer, 2, 4)

In [38]:
# With B still all zeros the merged weight equals the base weight, so this
# again matches layer(x).
merged_lora(x)

tensor([[-1.5018, -1.0002]], grad_fn=<AddmmBackward0>)

In [11]:
class MultiLayerPerceptron(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        num_features: size of the input feature dimension.
        num_hidden_1: output width of the first linear layer.
        num_hidden_2: output width of the second linear layer.
        num_classes: output width of the final linear layer.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes),
        )

    # Must live at class level: when nested inside __init__ it is only a local
    # function, nn.Module.forward stays unimplemented, and calling the model
    # raises NotImplementedError.
    def forward(self, x):
        """Pass ``x`` through the sequential stack and return its output."""
        return self.layers(x)

In [12]:
# MLP with 768 input features, hidden widths 128 and 256, and 10 outputs.
model = MultiLayerPerceptron(768, 128, 256, 10)

In [13]:
# Show the module tree before any LoRA wrapping is applied.
model

MultiLayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [18]:
# Wrap every plain nn.Linear in the MLP with a LoRA adapter (rank 4, alpha 8).
# Layer 0 is included so that a fresh Restart & Run All reproduces the module
# tree and requires_grad report shown in the following cells. The isinstance
# guard skips layers that are already wrapped, so re-running this cell does
# not nest wrappers.
for idx, module in list(enumerate(model.layers)):
    if isinstance(module, nn.Linear):
        model.layers[idx] = LinearWithLora(module, rank=4, alpha=8)

In [19]:
# Show the module tree again to confirm which layers are now LinearWithLora wrappers.
model

MultiLayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLora(
      (linear): Linear(in_features=768, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLora(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLora(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [20]:
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` found in ``model``.

    Walks ``model`` and all of its nested submodules; each ``nn.Linear``
    instance has ``requires_grad`` set to ``False`` on all of its parameters.
    Parameters owned by other module types are left untouched. The root
    module itself is also checked, so passing a bare ``nn.Linear`` freezes it
    rather than silently doing nothing.

    Args:
        model: the ``nn.Module`` to modify in place.

    Returns:
        None. ``model`` is mutated in place.
    """
    # modules() yields the root and every descendant, replacing the manual
    # recursion over children() that never inspected the root.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.requires_grad_(False)

In [21]:
# Freeze the nn.Linear weights and biases in place; parameters that do not
# belong to an nn.Linear (the LoRA A/B matrices) are not touched.
freeze_linear_layers(model)

In [22]:
# List every parameter with its requires_grad flag to check which ones remain trainable.
for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True
