In [15]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """
    A drop-in Linear with LoRA adaptation.
    - base: a frozen big Linear (weight/bias not trainable)
    - lora_A: (r, in_features)
    - lora_B: (out_features, r)
    forward: x @ W^T + (alpha/r) * x @ A^T @ B^T + b
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.1,
        bias: bool = True,
        init_scale: float = 1e-4, # small init for A, zeros for B (or vice-versa)
    ):
        super().__init__()
        assert r > 0, "LoRA rank r must be > 0"

        self.base = nn.Linear(in_features, out_features, bias=bias)
        for p in self.base.parameters():
            p.requires_grad = False # frozen

        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        # A with small init, B with zeros init
        nn.init.normal_(self.lora_A, std=init_scale)
        nn.init.zeros_(self.lora_B)

        self.dropout = nn.Dropout(dropout)
        self.merged = False # add BA to base.weight when inference

    def forward(self, x):
        if self.merged:
            return self.base(x)

        # forward + LoRA path
        base_out = self.base(x)
        lora_out = F.linear(self.dropout(x), self.lora_A)           # x @ A^T
        lora_out = F.linear(lora_out, self.lora_B) * self.scaling   # (xA^T) @ B^T
        return base_out + lora_out

    @torch.no_grad()
    def merge_weights_(self):
        """add BA to base.weight when inference"""
        if self.merged:
            return
        delta_w = (self.lora_B @ self.lora_A) * self.scaling  # [out, in]
        self.base.weight += delta_w
        self.merged = True

In [20]:
class TinyAdapter(nn.Module):
    """
    two-layer MLP, hidden layer is frozen base Linear,
    second layer is LoRA for task adaptation.
    """
    def __init__(self, d_in=128, d_hidden=256, d_out=16, r=8, alpha=16.0):
        super().__init__()
        # backbone: frozen
        self.backbone = nn.Linear(d_in, d_hidden, bias=True)
        # LoRA: trainable
        self.head = LoRALinear(d_hidden, d_out, r=r, alpha=alpha, bias=True)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.act(self.backbone(x))
        x = self.head(x)
        return x

adapter = TinyAdapter()
x = torch.randn(10, 128)
adapter(x).shape

torch.Size([10, 16])

这里的数据构造相对巧妙，通过两个略有区别的线性映射构造标签。

In [None]:
def make_task_matrices(d_in=128, d_hidden=256, d_out=16, device="cpu"):
    torch.manual_seed(0)
    W_old1 = torch.randn(d_hidden, d_in) / math.sqrt(d_in)
    b_old1 = torch.randn(d_hidden) * 0.1
    W_old2 = torch.randn(d_out, d_hidden) / math.sqrt(d_hidden)
    b_old2 = torch.randn(d_out) * 0.1
    shift = 0.15
    W_new2 = W_old2 + shift * torch.randn_like(W_old2)
    b_new2 = b_old2 + shift * torch.randn_like(b_old2)
    mats = dict(
        W_old1=W_old1.to(device), b_old1=b_old1.to(device),
        W_old2=W_old2.to(device), b_old2=b_old2.to(device),
        W_new2=W_new2.to(device), b_new2=b_new2.to(device),
    )
    return mats

def synth_data(n=4096, d_in=128, device="cpu"):
    x = torch.randn(n, d_in, device=device)
    return x

def forward_with_mats(x, W1, b1, W2, b2):
    h = F.gelu(F.linear(x, W1, b1))
    y = F.linear(h, W2, b2)
    return y

device = "mps" if torch.cuda.is_available() else "cpu"
d_in, d_hidden, d_out = 128, 256, 16

mats = make_task_matrices(d_in, d_hidden, d_out, device=device)
x_train = synth_data(4096, d_in, device=device)
x_val = synth_data(512, d_in, device=device)

# target labels: from old/new task
y_old_train = forward_with_mats(
    x_train, mats["W_old1"], mats["b_old1"], mats["W_old2"], mats["b_old2"]
)
y_new_train = forward_with_mats(
    x_train, mats["W_old1"], mats["b_old1"], mats["W_new2"], mats["b_new2"]
)
y_new_val = forward_with_mats(
    x_val, mats["W_old1"], mats["b_old1"], mats["W_new2"], mats["b_new2"]
)

y_new_train.shape, y_new_val.shape

(torch.Size([4096, 16]), torch.Size([512, 16]))

如下是预训练和后训练的简单模拟，核心在于
- 通过 requires_grad 指定是否需要梯度；
- 为 AdamW 优化器提供 lambda 表达式以指定哪些参数是需要梯度更新的；
- 最后，merged 后 base.weight += delta_w 即可；

In [21]:
r, alpha = 8, 16.0
model = TinyAdapter(d_in, d_hidden, d_out, r=r, alpha=alpha).to(device)

# pretrain
for p in model.backbone.parameters():
    p.requires_grad = True
for p in model.head.parameters():
    p.requires_grad = False

opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-3)
loss_fn = nn.MSELoss()

for step in range(800):
    opt.zero_grad()
    pred = model(x_train)
    loss = loss_fn(pred, y_old_train)
    loss.backward()
    opt.step()
    if (step+1) % 100 == 0:
        print(f"[Pretrain] step {step+1}, loss={loss.item():.4f}")

# finetune
for p in model.backbone.parameters():
    p.requires_grad = False
for n, p in model.head.named_parameters():
    p.requires_grad = ("lora_" in n) or ("bias" in n)  # train LoRA A/B and optional bias

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"\nTrainable params after freezing: {trainable}/{total} ({100*trainable/total:.2f}%). only LoRA is trainable.")

opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-3)

for step in range(600):
    opt.zero_grad()
    pred = model(x_train)
    loss = loss_fn(pred, y_new_train)
    loss.backward()
    opt.step()
    if (step+1) % 100 == 0:
        with torch.no_grad():
            val_pred = model(x_val)
            val_loss = loss_fn(val_pred, y_new_val).item()
        print(f"[LoRA finetune] step {step+1}, train_loss={loss.item():.4f}, val_loss={val_loss:.4f}")

# merge weights
model.head.merge_weights_()  # after forward, no need to use A/B
with torch.no_grad():
    val_pred = model(x_val)
    val_loss = loss_fn(val_pred, y_new_val).item()
print(f"\nMerged for inference. Val loss (merged): {val_loss:.4f}")

[Pretrain] step 100, loss=0.0661
[Pretrain] step 200, loss=0.0438
[Pretrain] step 300, loss=0.0338
[Pretrain] step 400, loss=0.0285
[Pretrain] step 500, loss=0.0251
[Pretrain] step 600, loss=0.0228
[Pretrain] step 700, loss=0.0211
[Pretrain] step 800, loss=0.0198

Trainable params after freezing: 2192/39312 (5.58%). only LoRA is trainable.
[LoRA finetune] step 100, train_loss=1.3054, val_loss=1.4206
[LoRA finetune] step 200, train_loss=1.2314, val_loss=1.3741
[LoRA finetune] step 300, train_loss=1.2226, val_loss=1.3675
[LoRA finetune] step 400, train_loss=1.2155, val_loss=1.3663
[LoRA finetune] step 500, train_loss=1.2152, val_loss=1.3567
[LoRA finetune] step 600, train_loss=1.2190, val_loss=1.3624

Merged for inference. Val loss (merged): 1.2944
