# LoRA From Scratch – Implement Low-Rank Adaptation for LLMs in PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar elements over all parameters of ``model``.

    Both trainable and frozen parameters are counted.
    """
    return sum(p.numel() for p in model.parameters())


def count_trainable_parameters(model: nn.Module) -> int:
    """Return the number of scalar elements in parameters that require gradients.

    Parameters whose ``requires_grad`` flag is ``False`` (i.e. frozen ones)
    are skipped.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [6]:
class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation (LoRA) update expressed as a product of two thin matrices.

    The layer maps ``x`` to ``alpha * (x @ A @ B)``. ``A`` is drawn from a
    standard normal distribution scaled by ``rank ** -0.5`` and ``B`` starts
    as all zeros, so a freshly constructed layer outputs zeros.

    Args:
        fan_in: Number of input features (int).
        fan_out: Number of output features (int).
        rank: Inner dimension shared by ``A`` and ``B`` (int).
        alpha: Factor the low-rank product is multiplied by (float).

    Attributes:
        alpha: Factor the low-rank product is multiplied by (float).
        A: Trainable matrix of shape (fan_in, rank) (float tensor).
        B: Trainable matrix of shape (rank, fan_out) (float tensor).

    Inputs:
        x: Input tensor of shape (..., fan_in) (float tensor).

    Outputs:
        Tensor of shape (..., fan_out) (float tensor).
    """

    def __init__(self, fan_in: int, fan_out: int, rank: int, alpha: float):
        super().__init__()
        # Scale factor for A's random entries: 1 / sqrt(rank).
        init_scale = torch.tensor(rank, dtype=torch.float).pow(-0.5)
        # A is registered before B so that parameter ordering stays A, B.
        self.A = nn.Parameter(init_scale * torch.randn(fan_in, rank))
        # Zero-initialised B makes the initial output exactly zero.
        self.B = nn.Parameter(torch.zeros(rank, fan_out))
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project down to `rank` features first, then back up to `fan_out`.
        down_projected = x @ self.A
        return self.alpha * (down_projected @ self.B)

In [8]:
# Compare parameter counts: a dense 10_000 x 10_000 layer vs. a rank-8 LoRA layer of the same shape.
linear_layer = nn.Linear(in_features=10_000, out_features=10_000)
lora_layer = LoRALayer(fan_in=10_000, fan_out=10_000, rank=8, alpha=4)
# Dense: 10_000 * 10_000 weights + 10_000 biases; LoRA: A and B hold 10_000 * 8 entries each.
count_parameters(linear_layer), count_parameters(lora_layer)

(100010000, 160000)

In [11]:
# Feed one random batch of 50 rows through both layers and compare output shapes.
x = torch.randn(50, 10_000)
out_linear, out_lora = linear_layer(x), lora_layer(x)
out_linear.shape, out_lora.shape  # both are (50, 10_000)

(torch.Size([50, 10000]), torch.Size([50, 10000]))

In [None]:
class LinearWithLoRA(nn.Module):
    """
    Wraps a linear layer and adds a LoRALayer in parallel.

    The output is the sum of the wrapped linear layer's output and the
    LoRALayer's output for the same input.

    Note:
        The wrapped layer is frozen in place: ``requires_grad_(False)`` is
        called on the object passed in, so the caller's reference is
        affected as well.

    Args:
        linear_layer: The original linear layer to be wrapped (nn.Module).
            It must expose ``in_features``, ``out_features`` and ``weight``.
        rank: Rank of the low-rank factors in the LoRALayer (int).
        alpha: Hyperparameter scaling the LoRALayer output (float).

    Attributes:
        linear_layer: Original linear layer (frozen, nn.Module).
        alpha: Hyperparameter scaling the LoRALayer output (float).
        lora_layer: LoRALayer instance, placed on the same device as
            ``linear_layer.weight`` (nn.Module).

    Inputs:
        x: Input tensor of shape (..., in_features) (float tensor).

    Outputs:
        Transformed tensor of shape (..., out_features) (float tensor).
    """

    def __init__(self, linear_layer: nn.Module, rank: int, alpha: float):
        super().__init__()
        # requires_grad_ mutates the module in place and returns it.
        self.linear_layer = linear_layer.requires_grad_(False)  # Freeze the weights of the linear layer
        self.alpha = alpha
        adapter = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)
        # The adapter is created on the default device; move it next to the
        # wrapped weights so forward() does not mix devices when the linear
        # layer was already moved (e.g. to a GPU) before wrapping.
        self.lora_layer = adapter.to(linear_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer(x) + self.lora_layer(x)