
[Common][PyTorch] Add z_loss_weight to parallel_cross_entropy #2707

Open
bassoy wants to merge 2 commits into NVIDIA:main from bassoy:add_zloss_to_parallel_cross_entropy

Conversation


@bassoy bassoy commented Feb 26, 2026

PR Draft — #2707

Title: [Common][PyTorch] Add z_loss_weight to parallel_cross_entropy


Description

Adds z-loss regularization to parallel_cross_entropy. Z-loss penalizes large logit magnitudes by adding z_loss_weight * log(Z)^2 per token to the loss, where log(Z) = log(sum(exp(logits))) is the log-sum-exp (see ST-MoE, arxiv.org/abs/2202.08906). This stabilizes training by keeping logits in a numerically well-behaved range.
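As a plain illustration of the penalty itself, a minimal pure-Python reference sketch (not the kernel code) could look like:

```python
import math

def z_loss_penalty(logits, z_loss_weight):
    """z_loss_weight * log(Z)^2 for one token, where log(Z) = logsumexp(logits)."""
    m = max(logits)  # shift by the max for numerical stability, as the kernel does
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return z_loss_weight * log_z ** 2

# Larger logit magnitudes are penalized quadratically in log(Z).
assert z_loss_penalty([10.0, -20.0, 30.0], 1e-3) > z_loss_penalty([0.1, -0.2, 0.3], 1e-3)
```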

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

The Triton kernel already computes m (running max) and d (sum of shifted exponentials) as part of the online softmax. lse = m + log(d) is therefore a free byproduct: no extra data movement is required.
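A sketch of why lse comes for free: a single streaming pass that maintains the running max m and the rescaled exponential sum d (the quantities the kernel already tracks) determines lse with no further reads. Illustrative pure Python, not the Triton code:

```python
import math

def online_softmax_stats(logits):
    """One streaming pass over the logits: running max m and sum of shifted exponentials d."""
    m = float("-inf")
    d = 0.0
    for x in logits:
        m_new = max(m, x)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)  # rescale when the max moves
        m = m_new
    return m, d

logits = [1.0, 3.0, -2.0, 0.5]
m, d = online_softmax_stats(logits)
lse = m + math.log(d)  # free byproduct
assert abs(lse - math.log(sum(map(math.exp, logits)))) < 1e-12
```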

  • Forward: compute lse = m + log(d) from these existing variables and add z_loss_weight * lse^2 to the per-token loss. Reuse lse in the label-smoothing path (smooth_loss previously recomputed m + log(d)).
  • Backward: scale the softmax gradient by (1 + 2 * z_loss_weight * lse), derived from d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
  • Dead-code elimination: z_loss_weight is a tl.constexpr, enabling Triton to eliminate all z-loss branches at compile time when z_loss_weight=0.0.
  • API: parallel_cross_entropy(..., z_loss_weight=0.0) is backward compatible; the default behavior is unchanged.
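The backward formula above can be sanity-checked with a finite-difference probe (standalone pure-Python sketch, independent of the kernel):

```python
import math

def lse(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def z_loss(xs, z):
    return z * lse(xs) ** 2

z, xs = 1e-3, [0.4, -1.2, 2.0, 0.3]
L = lse(xs)
softmax = [math.exp(x - L) for x in xs]

for i in range(len(xs)):
    analytic = 2.0 * z * L * softmax[i]  # claimed d/dx_i [z * lse^2]
    bumped = list(xs)
    bumped[i] += 1e-6
    numeric = (z_loss(bumped, z) - z_loss(xs, z)) / 1e-6
    assert abs(analytic - numeric) < 1e-8
```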

Tests

Z-loss tests extend the existing test infrastructure rather than introducing separate helpers. generate_infra accepts an optional z_loss_weight parameter: when > 0, it builds a PyTorch reference function that computes F.cross_entropy + z_loss_weight * lse^2 in FP32. one_iteration_test passes z_loss_weight through to parallel_cross_entropy. This keeps all z-loss tests in the same pattern as the existing suite.

The z_loss_weight == 0.0 path in generate_infra uses torch.nn.CrossEntropyLoss with ignore_index forwarded, keeping it consistent with the z-loss reference path.

7 new tests, 15 total pass (A40, single GPU):

| Test | What it verifies |
| --- | --- |
| test_z_loss | FP32 loss and gradients match PyTorch reference (5 random iterations, random swap_dim) |
| test_z_loss_bfloat16 | Same as above with BF16 input (3 iterations) |
| test_z_loss_with_ignore_idx | Z-loss + ignored tokens: loss and gradients correct (5 iterations) |
| test_z_loss_zero_weight | z_loss_weight=0.0 produces bit-identical results to the default (no z-loss) |
| test_z_loss_reduced | Z-loss + reduce_loss=True: reduced loss and gradients correct (5 iterations) |
| test_z_loss_reduced_with_ignore_idx | Z-loss + reduce_loss=True + ignored tokens: correct (5 iterations) |
| test_z_loss_label_smoothing | Z-loss + label_smoothing=0.1: both features interact correctly (3 iterations) |

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from 7f11aa2 to 70d3f84 Compare March 20, 2026 21:49
@bassoy bassoy changed the title [Common][PyTorch] Add z_loss_weight and log_sum_exp output to parallel_cross_entropy [Common][PyTorch] Add z_loss_weight to parallel_cross_entropy Mar 20, 2026
@bassoy bassoy marked this pull request as ready for review March 20, 2026 21:57

greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR adds z-loss regularization (z_loss_weight * log(Z)^2 per token) to parallel_cross_entropy. The feature is backward-compatible (z_loss_weight=0.0 default) and leverages the intermediate lse = m + log(d) already computed by the online-softmax kernel, requiring no extra data movement.

Key points:

  • Forward: lse is now computed once and reused for both label-smoothing and z-loss forward terms, eliminating the previous redundant m + log(d) computation in the label-smoothing path.
  • Backward: z-loss gradient (2 * z_loss_weight * lse * softmax(x_i)) is applied to the pure softmax value before the label-smoothing eps subtraction, making the combined gradient softmax * (1 + 2*z*lse) - eps — mathematically correct.
  • Validation: The guard not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")) correctly rejects negative values, NaN (since nan >= 0.0 is False), and inf.
  • Tests: All previously raised concerns (backward coverage for zero-weight test, ignore_idx parameterisation, reduce_loss + ignore_idx combination) are addressed in the new suite.
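As a standalone illustration of the validation guard quoted above (the exact placement in the PR may differ; this is a sketch of the same logic):

```python
def validate_z_loss_weight(z_loss_weight: float) -> None:
    # nan >= 0.0 evaluates to False, so the first clause rejects NaN;
    # the inequality against float("inf") rejects positive infinity.
    if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")):
        raise ValueError(
            f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}"
        )

for bad in (-0.5, float("nan"), float("inf")):
    try:
        validate_z_loss_weight(bad)
        raise AssertionError(f"{bad!r} should have been rejected")
    except ValueError:
        pass

validate_z_loss_weight(0.0)   # default: accepted
validate_z_loss_weight(1e-3)  # typical weight: accepted
```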

Confidence Score: 5/5

  • This PR is safe to merge; all previously flagged concerns have been resolved and no new issues were found.
  • Every concern raised in prior review threads has been addressed: the gradient is applied to pure softmax before eps subtraction (correcting the label-smoothing interaction), the zero-weight test now exercises the backward path, validation correctly rejects NaN/inf, generate_infra properly forwards ignore_idx in all code paths, and the missing reduce_loss + ignore_idx test is present. The mathematical derivation for both forward and backward is correct, and the 7 new tests cover all meaningful combinations.
  • No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/triton/cross_entropy.py | Adds z_loss_weight as tl.constexpr to cross_entropy_kernel; computes lse = m + log(d) once, reuses it in label-smoothing path, z-loss forward term, and z-loss gradient (applied to softmax before eps subtraction — mathematically correct). Previously flagged gradient issue with label_smoothing is resolved. |
| transformer_engine/pytorch/cross_entropy.py | Adds z_loss_weight parameter to parallel_cross_entropy and CrossEntropyFunction.forward; input validation correctly rejects negative, NaN, and infinite values; the backward does not need z_loss_weight because gradients are pre-computed in-place during the forward kernel. |
| transformer_engine/pytorch/triton/cross_entropy.py | Wires z_loss_weight through cross_entropy_forward to the Triton kernel call; default value of 0.0 preserves backward compatibility. |
| tests/pytorch/test_parallel_cross_entropy.py | Adds 7 z-loss tests covering FP32/BF16, ignore_idx, reduce_loss, label_smoothing, and zero-weight bit-identical regression; generate_infra now parameterises ignore_idx and correctly forwards it in both the z_loss and non-z_loss reference paths. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["parallel_cross_entropy(inp, target, z_loss_weight)"] --> B{z_loss_weight valid?}
    B -- "no: neg/nan/inf" --> C[raise ValueError]
    B -- yes --> D[CrossEntropyFunction.apply]
    D --> E[cross_entropy_forward]
    E --> F["online_softmax_kernel: compute m, d, X_y per row"]
    F --> G["cross_entropy_kernel: one program per row"]
    G --> H{y == ignore_idx?}
    H -- yes --> I["zero gradient, return early, loss = 0"]
    H -- no --> J["Compute lse = m + log(d)"]
    J --> K["Gradient pass: softmax then optionally scale by 1+2z*lse then subtract eps then divide by n_non_ignore"]
    J --> L["Loss: lse - ori_X_y, add label_smoothing term, add z*lse*lse"]
    K --> M["Special-case dx_y: subtract 1-ls divided by n_non_ignore"]
    L --> N[Store per-token loss]
    M --> N
    N --> O["element_mul_kernel: scale grad by upstream grad_output"]
    O --> P[Return loss tensor]

Reviews (5). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +195 to +197
# Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
if z_loss_weight > 0:
    X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse)

P1 Incorrect z-loss gradient when label_smoothing > 0

At this point X_block = (softmax(x_i) - eps) / N (or softmax - eps without reduction). Multiplying the combined CE gradient by (1 + 2 * z_loss_weight * lse) expands to:

(softmax - eps)/N * (1 + 2*z*lse)
= (softmax - eps)/N + (softmax - eps)/N * 2*z*lse

But the correct z-loss gradient is purely softmax/N * 2 * z * lse — the z-loss term should be additive on top of the CE gradient, not multiplicative against the entire (softmax - eps) expression. The error introduced is -eps/N * 2 * z_loss_weight * lse per element.

For typical training settings (label_smoothing=0.1, V=64000, z_loss_weight=0.001, lse≈11) the error is on the order of 3e-8, which is below float32 precision for large vocabularies — explaining why test_z_loss_label_smoothing still passes. However, for small vocabularies (e.g. V=32) the error becomes measurable and the implementation is mathematically incorrect.

The correct approach is to add the z-loss gradient additively, using the pre-eps softmax value:

        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        # Z-loss gradient: 2 * z_loss_weight * lse * softmax(x_i), additive to CE gradient.
        if z_loss_weight > 0:
            softmax_i = tl.exp(X_block_fp32 - m) / d  # pure softmax, before subtracting eps
            if reduce_loss:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse) / n_non_ignore
            else:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse)

where X_block_fp32 is the logit block before the CE computation (currently loaded at the top of the loop).
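The error magnitude quoted above can be reproduced from the stated settings (arithmetic check only):

```python
# Per-element error from scaling (softmax - eps) instead of adding the pure
# z-loss term: magnitude eps * 2 * z_loss_weight * lse (before the 1/N reduction).
label_smoothing, V = 0.1, 64000
z_loss_weight, lse = 0.001, 11.0  # typical values quoted in the comment

eps = label_smoothing / V
error = eps * 2.0 * z_loss_weight * lse
assert 1e-8 < error < 1e-7  # ~3.4e-8, below fp32 resolution at typical loss scales
```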

Comment on lines +204 to +215
def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)
    loss_base = self.test_loss_func(self.input_test.clone(), self.tar_test)
    loss_zero = self.test_loss_func(self.input_test.clone(), self.tar_test, z_loss_weight=0.0)
    assert torch.equal(
        loss_base, loss_zero
    ), "z_loss_weight=0.0 must be bit-identical to the default"
    self.input_test = None
    self.input_ref = None
    self.tar_test = None
    self.tar_ref = None

P2 test_z_loss_zero_weight only validates the forward pass

The test clones the input tensor but never calls .requires_grad_(True), so no gradient is accumulated and the backward path is never exercised. The Triton kernel eliminates the z-loss branches at compile time via tl.constexpr, so validating that the gradient is also bit-identical for z_loss_weight=0.0 would strengthen the regression value of this test.

Consider adding backward verification:

def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)

    inp_base = self.input_test.clone().requires_grad_(True)
    inp_zero = self.input_test.clone().requires_grad_(True)

    loss_base = self.test_loss_func(inp_base, self.tar_test)
    loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)

    assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default"

    loss_base.sum().backward()
    loss_zero.sum().backward()

    assert torch.equal(inp_base.grad, inp_zero.grad), \
        "Gradients with z_loss_weight=0.0 must be bit-identical to the default"

    self.input_test = self.input_ref = self.tar_test = self.tar_ref = None

dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
z_loss_weight: float = 0.0,

P2 No validation that z_loss_weight is non-negative

A negative value is mathematically well-formed but semantically inverts the regularization (rewarding large logit magnitudes). Given the docstring describes this as a "regularization weight", an early guard against negative values would make the API safer and the intent explicit:


Consider adding before the CrossEntropyFunction.apply(...) call:

if z_loss_weight < 0.0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from 056ce5f to bb47312 Compare March 20, 2026 22:19
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
z_loss_weight: tl.constexpr,

P2 tl.constexpr specialization will recompile for every unique float value

z_loss_weight is declared tl.constexpr, which means Triton compiles a separate kernel for each unique Python float value passed. The PR describes this as intentional for dead-code elimination when z_loss_weight=0.0, and for a fixed training hyperparameter that's fine. However, if callers ever want to anneal or schedule z_loss_weight across training steps (e.g. a warmup from 0 → 0.001), every distinct float encountered will trigger a fresh JIT compilation, stalling the training loop.

Consider documenting this behaviour in the docstring of both cross_entropy_kernel and parallel_cross_entropy:

z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
    This value is used as a Triton compile-time constant (tl.constexpr); varying it across
    calls will trigger kernel recompilation. Use a fixed value during training.
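If a schedule for z_loss_weight is ever wanted despite the constexpr specialization, one possible mitigation (purely hypothetical, not part of this PR) is to snap the annealed value to a coarse grid so only a bounded number of kernel variants is ever compiled:

```python
def quantize_weight(z: float, step: float = 1e-4) -> float:
    """Hypothetical helper: snap z to multiples of `step`, bounding the set of
    distinct constexpr values (and thus JIT recompilations) a warmup can produce."""
    return round(z / step) * step

# A 0 -> 0.001 warmup over 1000 steps now hits at most 11 distinct values
# instead of one fresh compilation per step.
schedule = [quantize_weight(0.001 * t / 1000) for t in range(1001)]
assert len(set(schedule)) <= 11
```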

inp = inp.float()
ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
z_pen[tar == -100] = 0.0

P2 Ignore index hardcoded to -100 in reference function

z_pen[tar == -100] = 0.0 hardcodes the ignore index rather than closing over the ignore_idx value that will be used in the actual parallel_cross_entropy call. Currently all tests in the suite pass ignore_idx=-100 (the default), so this is consistent — but if a future test exercises a non-default ignore index, the reference would silently zero the wrong tokens and the comparison would yield a false positive.

Consider parameterising the closure:

Suggested change:

- z_pen[tar == -100] = 0.0
+ z_pen[tar == ignore_idx] = 0.0

And updating generate_infra to accept ignore_idx: int = -100 so it can be forwarded.

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from bb47312 to ed83839 Compare March 20, 2026 22:29
@bassoy bassoy marked this pull request as draft March 20, 2026 22:31
@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from ed83839 to da503ed Compare March 20, 2026 22:33
@bassoy bassoy marked this pull request as ready for review March 20, 2026 22:36
Comment on lines +249 to +259
def test_z_loss_reduced(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            z_loss_weight=0.001,
        )

P2 Missing test combination: reduce_loss=True + z-loss + ignore_idx

The test suite covers reduce_loss=True + z-loss (test_z_loss_reduced) and reduce_loss=False + z-loss + ignore_idx (test_z_loss_with_ignore_idx), but no test exercises all three together. This is the most semantically interesting combination: n_non_ignore is used to normalize both the loss value (in Python) and the gradient (in the Triton kernel, line 198), so an incorrect interaction would only appear when tokens are actually masked and reduction is active.

Consider adding:

def test_z_loss_reduced_with_ignore_idx(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            ignore_idx=True,
            z_loss_weight=0.001,
        )

Comment on lines +32 to +40
def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    loss = ce + z_pen
    if reduce_loss:
        loss = loss.sum() / (tar != ignore_idx).sum()
    return loss

P2 F.cross_entropy does not receive the parameterized ignore_index

generate_infra now accepts an ignore_idx parameter and correctly uses it to zero out z_pen, but F.cross_entropy is called without ignore_index=ignore_idx. PyTorch's default is -100, so all current tests pass since generate_input always uses -100. However, if a future test passes a non-default ignore_idx, the CE component of the reference would still ignore -100 while the real kernel would ignore the custom index, producing a silently incorrect reference loss and false-passing gradient tests.

def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(
        inp, tar, reduction="none",
        label_smoothing=label_smoothing,
        ignore_index=ignore_idx,   # <-- forward the parameter
    )
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    ...

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from b337f49 to 13ffa35 Compare March 20, 2026 22:55
Comment on lines +27 to +29
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)

P1 ignore_idx not forwarded to CrossEntropyLoss

generate_infra now accepts an ignore_idx parameter, and the z_loss_weight > 0 branch correctly passes it to both F.cross_entropy(..., ignore_index=ignore_idx) and z_pen[tar == ignore_idx]. However, the z_loss_weight == 0.0 branch silently falls back to PyTorch's default (-100), ignoring the parameter:

self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
    # ignore_index=ignore_idx is missing here
)

All current tests happen to pass ignore_idx=-100 (the default), so there is no visible failure now. But if any test ever calls generate_infra(..., z_loss_weight=0.0, ignore_idx=42), the reference would still ignore token id -100 instead of 42, producing a silently incorrect reference and a false-passing comparison.

Suggested change:

- self.ref_loss_func = torch.nn.CrossEntropyLoss(
-     label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
- )
+ self.ref_loss_func = torch.nn.CrossEntropyLoss(
+     label_smoothing=label_smoothing,
+     reduction="mean" if reduce_loss else "none",
+     ignore_index=ignore_idx,
+ )

Comment on lines +152 to +153
if z_loss_weight < 0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")

P1 Non-finite z_loss_weight bypasses validation

The current guard only rejects strictly negative values. float('nan') satisfies nan < 0 == False (NaN comparisons always return False) and float('inf') satisfies inf < 0 == False, so both slip through undetected:

  • z_loss_weight=float('nan'): the Triton tl.constexpr comparison nan > 0 evaluates to False at compile time, so the z-loss branch is silently skipped — users see no z-loss even though they passed a non-zero value.
  • z_loss_weight=float('inf'): the branch is taken, and loss += inf * lse * lse will produce inf/nan losses, immediately destabilising training.

Consider expanding the guard to also exclude non-finite values:

Suggested change:

- if z_loss_weight < 0:
-     raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")
+ if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")):
+     raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}")

Add z-loss regularization (z_loss_weight * log(Z)^2 per token) to the
Triton cross-entropy kernel. The z_loss_weight parameter is a
tl.constexpr, so it is dead-code-eliminated when set to 0.0.

Forward: adds z_loss_weight * lse^2 to per-token loss.
Backward: scales softmax gradient by (1 + 2 * z_loss_weight * lse).

Tests: extend existing test infrastructure with z_loss_weight parameter
in generate_infra and one_iteration_test. Z-loss tests cover FP32,
BF16, ignore_idx, and zero-weight identity.

Signed-off-by: Cem Bassoy <cem.bassoy@deepl.com>
@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from 33677c0 to e572feb Compare March 22, 2026 11:21