
[PyTorch] Change arguments order in triton kernels to make jax-triton work#2416

Merged
phu0ngng merged 2 commits into NVIDIA:main from tdophung:teddy/pytorch-triton on Nov 25, 2025

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Nov 24, 2025

Description

jax-triton requires a specific argument order to work: output pointers must come at the end of the argument list, after the input pointers and strides but before all tl.constexpr parameters. This PR changes the argument order in the common Triton kernels accordingly and makes sure the PyTorch wrappers still work. A separate PR will be sent out for the JAX-side implementation.
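The required ordering can be sketched with a small, purely illustrative Python helper (the helper name and argument groups below are assumptions for illustration, not code from this PR):

```python
# Illustrative sketch (not from the PR): jax-triton expects a kernel's
# positional arguments grouped as input pointers, then sizes/strides,
# then output pointers, with all tl.constexpr parameters trailing.
def order_kernel_args(inputs, sizes_and_strides, outputs, constexprs):
    """Assemble a jax-triton-compatible positional argument tuple."""
    return (*inputs, *sizes_and_strides, *outputs, *constexprs)

# Hypothetical argument groups for a permute-style kernel:
args = order_kernel_args(
    inputs=("input_ptr", "row_id_map_ptr", "probs_ptr"),
    sizes_and_strides=("stride_in_token", "stride_in_hidden"),
    outputs=("output_ptr", "permuted_probs_ptr"),
    constexprs=("num_experts", "hidden_size"),
)
print(args[-2:])  # → ('num_experts', 'hidden_size'): constexpr-style params come last
```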

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Change the order of arguments in the common Triton kernels
  • Change the order of inputs passed into the PyTorch wrappers to match the Triton kernels

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung self-assigned this Nov 24, 2025
@phu0ngng phu0ngng requested review from timmoon10 and yaox12 and removed request for jberchtold-nvidia and ksivaman November 24, 2025 21:50
@phu0ngng
Collaborator

/te-ci L0

@greptile-apps
Contributor

greptile-apps Bot commented Nov 24, 2025

Greptile Overview

Greptile Summary

This PR reorders arguments in Triton kernels to comply with jax-triton requirements: output pointers now appear after input pointers and strides but before tl.constexpr parameters. The changes affect 5 kernels and their PyTorch wrappers:

  • _row_id_map_pass_1_kernel: Moved row_id_map_ptr and workspace_ptr after strides
  • _row_id_map_pass_3_kernel: Moved num_experts from sizes to metas section
  • _permute_kernel: Moved output_ptr and permuted_probs_ptr after strides, moved num_experts and hidden_size to metas
  • _unpermute_kernel: Moved output_ptr and unpermuted_probs_ptr after strides, moved num_experts and hidden_size to metas
  • _sort_chunks_by_map_kernel: Moved output_ptr and permuted_probs_ptr after strides, moved hidden_size to metas

All PyTorch wrapper calls were updated to pass arguments in the new order, maintaining consistency between kernel signatures and invocations.
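As a rough illustration of the new `_permute_kernel` layout, a plain-Python stand-in (no GPU or Triton needed; the argument names are simplified and the real kernel has more parameters) can verify the grouping with `inspect`:

```python
import inspect

# Stand-in for the reordered _permute_kernel signature. This is an
# illustrative sketch: the real kernel in
# transformer_engine/common/triton/permutation.py is a @triton.jit
# function with more arguments; tl.constexpr metas are modeled here as
# trailing defaulted parameters.
def _permute_kernel(
    input_ptr, row_id_map_ptr, probs_ptr,  # 1. input pointers
    stride_in_token, stride_in_hidden,     # 2. sizes and strides
    output_ptr, permuted_probs_ptr,        # 3. output pointers
    num_experts=None, hidden_size=None,    # 4. constexpr-style metas last
):
    ...

params = list(inspect.signature(_permute_kernel).parameters)
# Output pointers sit after the strides but before the constexpr-style metas:
assert params.index("output_ptr") > params.index("stride_in_hidden")
assert params.index("output_ptr") < params.index("num_experts")
```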

Confidence Score: 5/5

  • This PR is safe to merge with no identified issues
  • The changes are purely mechanical argument reordering with perfect consistency between kernel definitions and wrapper calls, no logic changes, and existing tests verify correctness
  • No files require special attention

Important Files Changed

File Analysis

  • transformer_engine/common/triton/permutation.py (score 5/5): Reordered kernel arguments to place output pointers before tl.constexpr parameters for jax-triton compatibility
  • transformer_engine/pytorch/triton/permutation.py (score 5/5): Updated wrapper function calls to match the reordered kernel signatures

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant PW as PyTorch Wrapper
    participant TK as Triton Kernel
    
    Note over User,TK: Example: permute_with_mask_map
    
    User->>PW: permute_with_mask_map(inp, row_id_map, probs, scale, ...)
    
    Note over PW: Prepare arguments in new order:<br/>1. Input pointers<br/>2. Sizes & strides<br/>3. Output pointers<br/>4. tl.constexpr params
    
    PW->>TK: _permute_kernel(<br/>  inp, row_id_map, probs, scale, permuted_scale,<br/>  scale_hidden_dim, strides...,<br/>  output, permuted_probs,<br/>  num_experts, hidden_size, flags...)
    
    Note over TK: Execute with jax-triton<br/>compatible argument order
    
    TK-->>PW: Execution complete
    PW-->>User: output, permuted_scale, permuted_probs

Contributor

@greptile-apps greptile-apps Bot left a comment


2 files reviewed, no comments


@jberchtold-nvidia
Collaborator

LGTM, thanks! But I will wait for someone on the TE/PyTorch side to approve.

Member

@yaox12 yaox12 left a comment


LGTM. This only changes the argument order of the Triton kernels and leaves the PyTorch API untouched, so it should be fine on the PyTorch side.

@phu0ngng phu0ngng merged commit 0056b98 into NVIDIA:main Nov 25, 2025
40 of 41 checks passed


4 participants