先回顾一下 Adam 优化器。

设当前步的梯度为 $g_t=\nabla_\theta \mathcal{L}_t(\theta_t)$

- 一阶矩（动量）与二阶矩（RMS）
    - $m_t=\beta_1 m_{t-1} + (1-\beta_1)g_t$；
    - $v_t=\beta_2 v_{t-1} + (1-\beta_2)g_t\odot g_t$；
- 偏置校正
    - $\hat m_t=\frac{m_t}{1-\beta_1^t}$；
    - $\hat v_t=\frac{v_t}{1-\beta_2^t}$；
- 更新（不含权重衰减项）
    - $\theta_{t+1}=\theta_t-\alpha\;\frac{\hat m_t}{\sqrt{\hat v_t}+\varepsilon}$；
- 直觉：$\beta_1$ 控制一阶动量的平滑（常用 0.9）；$\beta_2$ 控制二阶动量（梯度平方）的平滑（如 0.999，相当于约 $1/(1-0.999)\approx 1000$ 步的均值窗口）；数值越大，方差估计越平滑、收敛更稳但响应更慢；

**weight-decay**
- 本质是对参数施加 L2 正则化的强度系数，用来限制参数幅度、抑制过拟合、提升泛化；
- L2 正则项：把 $\frac{\lambda}{2}\|\theta\|^2$ 加进损失里，导致梯度里多了 $\lambda\theta$；
- 但在 Adam 里这样做会让衰减项经过自适应缩放不理想；
- Decoupled Weight Decay（AdamW）：把权重衰减从损失梯度里解耦，直接在更新中衰减参数，效果更稳定、也成业界默认；
- AdamW 的参数更新（含衰减）可写成：
$$\theta \leftarrow \theta - \alpha \,\frac{\hat m}{\sqrt{\hat v}+\varepsilon}\;-\; \alpha\,\lambda\,\theta$$
- 或等价地先做 $\theta \leftarrow (1-\alpha\lambda)\theta$，再做 Adam 梯度步；
- Megatron-LM 默认用 AdamW，--weight-decay .1 即 \lambda=0.1；
- 通常偏置、LayerNorm / Norm 权重不施加衰减；

In [14]:
import math
import torch
from torch.optim.optimizer import Optimizer

class AdamW(Optimizer):
    def __init__(self, params, lr=1e-4, betas=(.9, .999), eps=1e-8, weight_decay=0.01) -> None:
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']

            for p in group["params"]: # p.data p.grad
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("AdamW not support sparse gradient")
                
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                
                exp_avg = state["exp_avg"] # m_t
                exp_avg_sq = state["exp_avg_sq"] # v_t

                state["step"] += 1
                t = state["step"]

                # different order
                exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # bias correction
                exp_avg.div_(1 - beta1 ** t)
                exp_avg_sq.div_(1 - beta2 ** t)

                if wd != 0:
                    p.mul_(1 - lr * wd)

                p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)

        return loss

In [34]:
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(768, 256),
    nn.LayerNorm(256),
    nn.GELU(),
    nn.Linear(256, 64),
    nn.LayerNorm(64),
    nn.GELU(),
    nn.Linear(64, 10)
).to("mps")

# no decay for bias and LayerNorm
decay, no_decay = [], []
for name, p in model.named_parameters():
    if p.requires_grad is False:
        continue
    if p.ndim == 1 or name.endswith(".bias") or "norm" in name.lower():
        no_decay.append(p)
    else:
        decay.append(p)

optimizer = AdamW([
    {"params": decay, "weight_decay": 0.1},
    {"params": no_decay, "weight_decay": 0.0},
], lr=5e-4, betas=(0.9, 0.999), eps=1e-8)

x = torch.randn(32, 768, device="mps")
y = torch.randn(32, 10, device="mps")
criterion = nn.MSELoss()

model.train()
for step in range(10):
    optimizer.zero_grad(set_to_none=True)
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()

    # gradient clipping and AMP scaling if needed
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    # torch.mps.amp.scale_loss(loss, optimizer)

    optimizer.step()
    print(f"step {step:02d} | loss {loss.item():.6f}")

step 00 | loss 1.148895
step 01 | loss 0.780953
step 02 | loss 0.713939
step 03 | loss 0.702153
step 04 | loss 0.700210
step 05 | loss 0.699914
step 06 | loss 0.699876
step 07 | loss 0.699876
step 08 | loss 0.699883
step 09 | loss 0.699890
