# LoRA Layer

In [3]:
import torch


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.W_a @ self.W_b)
        return x


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

In [5]:
torch.manual_seed(123)

# a simple linear layer with 10 inputs and 1 output
# requires_grad=False makes it non-trainable
linear_layer = torch.nn.Linear(10, 1)
linear_layer.requires_grad_(False)

# a simple example input
x = torch.rand((1, 10))

linear_layer(x)

tensor([[-0.5745]])

# Replace linear layer with LoRA layer:

In [6]:
lora_layer = LinearWithLoRA(linear=linear_layer, rank=8, alpha=1)
lora_layer(x)

tensor([[-0.5745]], grad_fn=<AddBackward0>)

In [7]:
lora_layer.lora.W_b = torch.nn.Parameter(lora_layer.lora.W_b + 0.01 * x[0])

In [8]:
lora_layer(x)

tensor([[-0.5863, -0.5758, -0.5779, -0.5800, -0.5814, -0.5766, -0.5887, -0.5811,
         -0.5859, -0.5892]], grad_fn=<AddBackward0>)