Enable selective recompute for norm_out in GDN layers #4715

Draft
xuantengh wants to merge 1 commit into NVIDIA:main from xuantengh:xuantengh/gdn_recompute

Conversation

@xuantengh

This PR enables selective recompute for the gated norm step in GDN layers.

You can enable it by adding the following options to an MBridge recipe:

model.recompute_granularity=selective
model.recompute_modules=[gdn_norm_out]

The recompute module name may be subject to change.
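
For reference, a minimal sketch of the equivalent Megatron-Core TransformerConfig; only the two recompute fields come from this PR, while the import path and all other values are assumed placeholders:

from megatron.core.transformer import TransformerConfig

# Illustrative config: recompute_granularity/recompute_modules mirror the
# recipe options above; the model sizes are arbitrary placeholders.
config = TransformerConfig(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=8,
    recompute_granularity="selective",
    recompute_modules=["gdn_norm_out"],
)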

@copy-pr-bot

copy-pr-bot Bot commented May 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

@xuantengh
Author

/claude review

Comment on lines +428 to +442
if self.recompute_norm_out:
    # Checkpoint the gated norm: run it now, saving only its inputs.
    self.norm_out_checkpoint = tensor_parallel.CheckpointWithoutOutput()
    norm_out = self.norm_out_checkpoint.checkpoint(
        self._run_gated_norm_and_a2a, core_attn_out, gate
    )
else:
    norm_out = self._run_gated_norm_and_a2a(core_attn_out, gate)

# Output projection
nvtx_range_push(suffix="out_proj")
out, out_bias = self.out_proj(norm_out)
nvtx_range_pop(suffix="out_proj")

if self.recompute_norm_out:
    # Free norm_out's activation memory now that out_proj has consumed it;
    # it will be recomputed from the saved inputs during backward.
    self.norm_out_checkpoint.discard_output_and_register_recompute(out)
Contributor

Nit: Other recompute modules (mla_up_proj, layernorm, moe_act) have unit test coverage for the forward/backward pass with recompute enabled. Consider adding a test case in tests/unit_tests/ssm/test_gated_delta_net.py that sets recompute_granularity="selective" and recompute_modules=["gdn_norm_out"] and verifies the forward+backward pass produces correct gradients. This would help catch regressions if the CheckpointWithoutOutput contract changes.
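
As a sketch, such a test might look like the following; build_gdn_layer is a hypothetical fixture standing in for whatever helper test_gated_delta_net.py provides, the layer is assumed to return a single tensor, and only the two config fields come from this PR:

import torch

def test_gdn_norm_out_recompute():
    # Build two identically initialized layers, one with selective recompute.
    torch.manual_seed(0)
    baseline = build_gdn_layer()  # hypothetical fixture
    torch.manual_seed(0)
    recomputed = build_gdn_layer(
        recompute_granularity="selective",
        recompute_modules=["gdn_norm_out"],
    )

    hidden = torch.randn(8, 2, baseline.config.hidden_size, device="cuda")
    ref_out = baseline(hidden)
    rc_out = recomputed(hidden)
    torch.testing.assert_close(rc_out, ref_out)

    # Recompute should change memory usage, not math: gradients must match.
    ref_out.sum().backward()
    rc_out.sum().backward()
    for p_ref, p_rc in zip(baseline.parameters(), recomputed.parameters()):
        torch.testing.assert_close(p_rc.grad, p_ref.grad)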

Contributor

@claude Bot left a comment

LGTM. The CheckpointWithoutOutput pattern matches the established usage in transformer_layer.py and multi_latent_attention.py. Config validation for gdn_norm_out is correct. Left a minor suggestion about adding unit test coverage.
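
For readers unfamiliar with the pattern, here is a minimal standalone sketch of the CheckpointWithoutOutput lifecycle this PR applies; gated_norm is an illustrative stand-in for _run_gated_norm_and_a2a, and only the checkpoint and discard_output_and_register_recompute calls are taken from the diff:

import torch
from megatron.core import tensor_parallel

def gated_norm(x, gate):
    # Stand-in for the real gated norm + all-to-all in the GDN layer.
    return torch.nn.functional.layer_norm(x, x.shape[-1:]) * torch.sigmoid(gate)

x = torch.randn(4, 8, requires_grad=True)
gate = torch.randn(4, 8, requires_grad=True)

ckpt = tensor_parallel.CheckpointWithoutOutput()
# 1) Run the function now, saving its inputs but not its output.
y = ckpt.checkpoint(gated_norm, x, gate)
# 2) Let the downstream op (out_proj in the PR) consume the output.
out = y.sum()
# 3) Free the checkpointed output and register a backward hook on `out`
#    so gated_norm is re-run just before gradients flow through it.
ckpt.discard_output_and_register_recompute(out)
out.backward()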

@xuantengh self-assigned this May 10, 2026
@xuantengh
Author

/ok to test cdaca66
