[PyTorch] Change order of args in another permutation triton kernel #2488

Merged
tdophung merged 1 commit into NVIDIA:main from tdophung:teddy/more_arg_reroder_triton_perm
Dec 9, 2025

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Dec 9, 2025

Description

Change the order of arguments in a kernel function that I missed in #2416.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Change the order of args in _unpermute_bwd_with_merging_probs_kernel

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

tdophung commented Dec 9, 2025

/te_ci L0 pytorch

@greptile-apps
Contributor

greptile-apps Bot commented Dec 9, 2025

Greptile Overview

Greptile Summary

Completes the argument reordering work from PR #2416 by updating _unpermute_bwd_with_merging_probs_kernel to follow the JAX-Triton compatible parameter pattern: input pointers, strides, output pointers, then metas.

  • Moved fwd_input_grad_ptr and merging_probs_grad_ptr from input section to output section
  • Moved num_experts and hidden_size from sizes section to metas section
  • Updated the kernel call site to match the new parameter order
  • All arguments correctly align between caller and kernel definition
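
The grouping described above (input pointers, strides, output pointers, then metas) can be checked mechanically. The following is a minimal sketch in plain Python, not the actual Triton kernel: `kernel_stub`, `classify`, and `follows_group_order` are hypothetical helpers, and apart from `fwd_input_grad_ptr`, `merging_probs_grad_ptr`, `num_experts`, and `hidden_size`, every parameter name is an assumption made for illustration.

```python
import inspect


# Hypothetical stand-in for the reordered kernel signature. Only the two
# output pointers and the two metas are named in this PR; the input pointer
# and stride names below are assumptions.
def kernel_stub(
    # 1. input pointers
    fwd_output_grad_ptr,
    fwd_input_ptr,
    # 2. strides
    stride_fwd_output_grad_token,
    stride_fwd_input_token,
    # 3. output pointers
    fwd_input_grad_ptr,
    merging_probs_grad_ptr,
    # 4. metas
    num_experts,
    hidden_size,
):
    pass


OUTPUT_POINTERS = {"fwd_input_grad_ptr", "merging_probs_grad_ptr"}
GROUP_RANK = {"input": 0, "stride": 1, "output": 2, "meta": 3}


def classify(name):
    """Assign a parameter to one of the four groups based on its name."""
    if name.startswith("stride_"):
        return "stride"
    if name in OUTPUT_POINTERS:
        return "output"
    if name.endswith("_ptr"):
        return "input"
    return "meta"


def follows_group_order(fn):
    """Return True if the groups appear as inputs, strides, outputs, metas."""
    ranks = [GROUP_RANK[classify(p)] for p in inspect.signature(fn).parameters]
    return ranks == sorted(ranks)


print(follows_group_order(kernel_stub))
```

A signature that places an output pointer among the inputs, or a size before the strides, would make `follows_group_order` return `False`.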

Confidence Score: 5/5

  • This PR is safe to merge with no risk
  • This is a straightforward parameter reordering that completes the work from PR #2416 ([PyTorch] Change arguments order in triton kernels to make jax-triton work). The changes are mechanical, both the kernel definition and call site are updated consistently, and the argument alignment has been verified to be correct. No logic changes were made.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/triton/permutation.py | 5/5 | Reordered kernel parameters to follow standard pattern: input pointers, strides, output pointers, then metas |
| transformer_engine/pytorch/triton/permutation.py | 5/5 | Updated kernel call to match reordered parameters, ensuring correct argument alignment |
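
The call site has to change together with the kernel definition because positional arguments bind by position only, so a stale call order is accepted silently. The sketch below illustrates this with a hypothetical four-parameter stand-in; `kernel_new`, `bound_names`, `stride_token`, and the argument values are assumptions for illustration, not the real kernel or wrapper.

```python
import inspect


def kernel_new(fwd_output_grad_ptr, stride_token, fwd_input_grad_ptr, hidden_size):
    """Hypothetical reordered kernel: input, stride, output, meta."""


def bound_names(fn, *args):
    """Map each positional argument to the parameter it would bind to."""
    return dict(inspect.signature(fn).bind(*args).arguments)


# A call site still written for an assumed old order, where the output
# pointer came before the stride.
stale_call = ("grad_out", "act_grad", 128, 4096)
# The same call site updated to the new order.
fresh_call = ("grad_out", 128, "act_grad", 4096)

stale = bound_names(kernel_new, *stale_call)
fresh = bound_names(kernel_new, *fresh_call)

# No exception is raised for the stale call; the values simply land on
# the wrong parameters.
print(stale["stride_token"])
print(fresh["fwd_input_grad_ptr"])
```

Binding with `inspect.signature(...).bind` is one way to audit a call site against a signature without launching anything on a GPU.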

Sequence Diagram

sequenceDiagram
    participant Caller as PyTorch Layer<br/>(permutation.py)
    participant Wrapper as unpermute_with_mask_map_bwd_with_merging_probs<br/>(triton/permutation.py)
    participant Kernel as _unpermute_bwd_with_merging_probs_kernel<br/>(common/triton/permutation.py)

    Caller->>Wrapper: Call with tensors and parameters
    Note over Wrapper: Create output tensors:<br/>act_grad, merging_probs_grad
    Wrapper->>Kernel: Pass args in JAX-Triton compatible order:<br/>1. Input pointers (fwd_output_grad, fwd_input, etc.)<br/>2. Strides (8 stride parameters)<br/>3. Output pointers (act_grad, merging_probs_grad)<br/>4. Metas (num_experts, hidden_size, etc.)
    Note over Kernel: Process unpermute backward pass<br/>with merging probabilities
    Kernel-->>Wrapper: Writes to output tensors
    Wrapper-->>Caller: Return act_grad, merging_probs_grad

Contributor

@greptile-apps greptile-apps Bot left a comment

2 files reviewed, no comments

Member

@ksivaman ksivaman left a comment

LGTM

@jberchtold-nvidia
Collaborator

/te-ci L0 pytorch

@tdophung tdophung merged commit e05f87e into NVIDIA:main Dec 9, 2025
15 of 16 checks passed
KshitijLakhani pushed a commit that referenced this pull request Dec 15, 2025
…2488)

change order

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung deleted the teddy/more_arg_reroder_triton_perm branch December 23, 2025 00:22
4 participants