**LARS (Layer-wise Adaptive Rate Scaling)**

Core idea: compute a trust ratio for each layer $\ell$, where $w_\ell$ and $g_\ell$ are that layer's weights and gradient:
$$
\text{trust\_ratio}_\ell
\;=\;
\eta \cdot \frac{\lVert w_\ell\rVert}{\lVert g_\ell \rVert + \epsilon}
$$


This ratio scales the global learning rate:

$$\text{lr}_\ell = \text{lr} \times \text{trust\_ratio}_\ell$$

A standard SGD (+ momentum) update is then applied with $\text{lr}_\ell$.

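- It makes the relative update size $\lVert \Delta w_\ell\rVert/\lVert w_\ell\rVert$ more uniform across layers, so that scale differences between layers are not amplified into instability at very large batch sizes.
- It is typically used for very-large-batch CNN training (e.g. ImageNet with batch sizes of 32K and above).
- BN and bias parameters are usually excluded from both the layer-wise adaptation and weight decay (simply set their trust ratio to 1).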
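
To make the first bullet concrete, here is a minimal pure-Python sketch of one plain LARS step (no momentum, no weight decay) on two toy layers whose weights and gradients differ in scale by several orders of magnitude. The helper names `l2_norm` and `lars_relative_update`, the default hyperparameters, and the toy numbers are illustrative assumptions, not part of any library.

```python
import math

def l2_norm(xs):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in xs))

def lars_relative_update(w, g, lr=0.1, eta=0.001, eps=1e-9):
    """Return ||delta_w|| / ||w|| for one plain LARS step (no momentum)."""
    trust_ratio = eta * l2_norm(w) / (l2_norm(g) + eps)
    layer_lr = lr * trust_ratio            # lr_l = lr * trust_ratio_l
    delta = [-layer_lr * gi for gi in g]   # SGD update with the per-layer rate
    return l2_norm(delta) / l2_norm(w)

# Two toy layers with very different weight and gradient scales (made-up values).
layers = {
    "small_grad": ([1.0, -2.0, 0.5], [1e-4, 2e-4, -1e-4]),
    "large_grad": ([0.01, 0.02, -0.03], [5.0, -3.0, 1.0]),
}
for name, (w, g) in layers.items():
    print(name, lars_relative_update(w, g))
```

Under these assumptions both layers end up with a relative update very close to `lr * eta`, because the gradient norm cancels out of $\lVert \Delta w_\ell\rVert/\lVert w_\ell\rVert$ up to the `eps` term.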

**LAMB (Layer-wise Adaptive Moments for Batch training)**

Core idea: first compute an adaptive step as in Adam / AdamW (normalizing by the second-moment estimate), then apply a layer-wise trust ratio on top:
$$
u_\ell=\frac{\hat m_\ell}{\sqrt{\hat v_\ell}+\epsilon} \quad(\text{Adam direction})
\qquad
r_\ell=\frac{\lVert w_\ell\rVert}{\lVert u_\ell \rVert+\epsilon}
\qquad
\Delta w_\ell=-\text{lr}\cdot r_\ell\cdot u_\ell
$$
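
The three formulas can be traced in a minimal pure-Python sketch of a single LAMB step for one layer. It follows the formulas term by term and leaves out weight decay and the BN/bias exclusion; the function name `lamb_step`, the default values of `beta1`, `beta2`, and `eps`, and the toy inputs are assumptions made for illustration.

```python
import math

def l2_norm(xs):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in xs))

def lamb_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6):
    """One LAMB step for a single layer; returns (new_w, new_m, new_v)."""
    # Adam-style first and second moment estimates
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    # bias correction (t is the 1-based step count)
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # Adam direction u
    u = [mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    # layer-wise trust ratio r = ||w|| / (||u|| + eps)
    r = l2_norm(w) / (l2_norm(u) + eps)
    # delta_w = -lr * r * u
    new_w = [wi - lr * r * ui for wi, ui in zip(w, u)]
    return new_w, m, v

# Toy layer (made-up values), first step with zero-initialized moments.
w = [3.0, 4.0]
g = [0.5, -2.0]
new_w, m, v = lamb_step(w, g, m=[0.0, 0.0], v=[0.0, 0.0], t=1)
rel = l2_norm([a - b for a, b in zip(new_w, w)]) / l2_norm(w)
print(rel)
```

In this sketch the relative update $\lVert \Delta w_\ell\rVert/\lVert w_\ell\rVert$ comes out very close to `lr` itself, since $\lVert u_\ell\rVert$ cancels against the trust ratio up to the `eps` term.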

import torch
from torch.optim.optimizer import Optimizer

def _exclude(p_name):
    # Parameters whose name contains 'bn', 'ln', or 'bias' are excluded from
    # layer-wise adaptation and weight decay (plain substring match)
    n = p_name.lower()
    return ("bn" in n) or ("ln" in n) or ("bias" in n)

class LARS(Optimizer):
    def __init__(self, params, lr=0.1, momentum=0.9, weight_decay=1e-4,
                 eta=0.001, eps=1e-9, trust_clip=None):

        # accept either bare parameters or (name, param) tuples,
        # e.g. from model.named_parameters()
        named = []
        raw = []
        for p in params:
            if isinstance(p, tuple):
                named.append(p)
            else:
                raw.append(p)

        if named:
            param_groups = [{"params":[p for _, p in named], "names":[n for n, _ in named]}]
        else:
            param_groups = [{"params": raw, "names": ["" for _ in raw]}]

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        eta=eta, eps=eps, trust_clip=trust_clip)
        super().__init__(param_groups, defaults)

        # allocate state for each param
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable grad for the closure
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd  = group["weight_decay"]
            eta = group["eta"]
            eps = group["eps"]
            clip= group["trust_clip"]
            names = group.get("names", [""]*len(group["params"]))

            for name, p in zip(names, group["params"]):
                if p.grad is None:
                    continue

                g = p.grad

                # decoupled weight decay
                if wd != 0 and not _exclude(name):
                    p.add_(p, alpha=-lr*wd)

                # momentum
                buf = self.state[p]["momentum_buffer"]
                buf.mul_(mom).add_(g)

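                # layer-wise trust ratio (skipped for BN/LN/bias parameters);
                # note: the momentum buffer norm stands in for the gradient norm
                if _exclude(name):
                    trust_ratio = 1.0
                else:
                    # convert to Python floats so alpha below is a plain number
                    w_norm = p.norm().item()
                    u_norm = buf.norm().item()
                    if w_norm > 0 and u_norm > 0:
                        trust_ratio = eta * w_norm / (u_norm + eps)
                        if clip is not None:
                            trust_ratio = min(trust_ratio, clip)
                    else:
                        trust_ratio = 1.0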

                # update
                p.add_(buf, alpha=-lr * trust_ratio)

        return loss
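
One caveat about the name-based exclusion used above: substring matching can also catch unrelated parameter names. The sketch below reproduces that rule as `exclude_by_substring` and compares it with a hypothetical stricter variant, `exclude_by_token`, which only matches whole name tokens. Both function names, the token pattern, and the example parameter names are assumptions for illustration; real models may use other naming schemes.

```python
import re

def exclude_by_substring(param_name):
    """Same rule as `_exclude` above: plain substring matching."""
    n = param_name.lower()
    return ("bn" in n) or ("ln" in n) or ("bias" in n)

# A token counts as a norm layer if it is 'bn' or 'ln' plus optional digits.
_NORM_TOKEN = re.compile(r"^(bn|ln)\d*$")

def exclude_by_token(param_name):
    """Hypothetical stricter variant: match whole name tokens only."""
    tokens = re.split(r"[._]", param_name.lower())
    return any(t == "bias" or _NORM_TOKEN.match(t) is not None for t in tokens)

names = [
    "layer1.0.bn1.weight",   # batch norm: both rules exclude it
    "encoder.ln_1.weight",   # layer norm: both rules exclude it
    "fc.bias",               # bias: both rules exclude it
    "subnet.conv.weight",    # 'subnet' contains 'bn': only the substring rule fires
]
for name in names:
    print(name, exclude_by_substring(name), exclude_by_token(name))
```

Whichever rule is used, printing the excluded names once before training is a cheap way to check that only normalization and bias parameters are skipped.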