
[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls#2392

Merged
KshitijLakhani merged 5 commits into NVIDIA:main from KshitijLakhani:klakhani/main/set-bshd-default-mha-and-dpa
Nov 21, 2025

Conversation

@KshitijLakhani (Collaborator) commented Nov 17, 2025

Description

Maintains consistency across layers by setting BSHD as the default layout for DPA, Unfused DPA, and MHA, matching the existing BSHD defaults in TransformerLayer and Fused DPA.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

MHA, DPA (and Unfused DPA) API changes to accept BSHD as the default
TE JAX tutorials updated to reflect the same

In the user-facing APIs, DPA and MHA, I have set the default to None, and I set transpose_batch_sequence to False in __post_init__ if transpose_batch_sequence is None.
I am basically using the None default as a way to indicate to __post_init__ whether the user explicitly passed the param or not. Three cases can arise:

  1. If the user is explicitly passing True, no changes for them
  2. If the user is explicitly passing False, no changes for them
  3. If the user passes nothing, thereby relying on the API's default (None in this PR), a warning informs them about the changed default and transpose_batch_sequence is set to False (i.e. the new default)
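The three cases above can be sketched with a plain Python dataclass. This is an illustrative stand-in, not the actual Flax module: the class name and the num_heads field are hypothetical; only transpose_batch_sequence matches the real parameter.

```python
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class DotProductAttentionSketch:
    """Hypothetical stand-in for the real TE JAX attention module."""
    num_heads: int = 8
    transpose_batch_sequence: Optional[bool] = None  # None = user passed nothing

    def __post_init__(self):
        # Case 3: the user relied on the default -> warn, then apply the
        # new default (False, i.e. the BSHD layout).
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence now defaults to False (BSHD); pass"
                " transpose_batch_sequence=True to keep the old SBHD layout.",
                FutureWarning,
                stacklevel=2,
            )
            self.transpose_batch_sequence = False
        # Cases 1 and 2: an explicit True or False passes through unchanged.
```

An explicit `DotProductAttentionSketch(transpose_batch_sequence=True)` keeps SBHD with no warning, while `DotProductAttentionSketch()` warns once and resolves to BSHD.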

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
@KshitijLakhani (Collaborator, Author)

/te-ci jax

1 similar comment
@KshitijLakhani (Collaborator, Author)

/te-ci jax

@KshitijLakhani KshitijLakhani marked this pull request as ready for review November 17, 2025 21:57
@greptile-apps greptile-apps Bot (Contributor) commented Nov 17, 2025

Greptile Overview

Greptile Summary

This PR standardizes the default tensor layout to BSHD (batch, sequence, heads, dimension) across TE JAX attention modules by changing transpose_batch_sequence defaults from True to False in DotProductAttention and MultiHeadAttention. The approach uses a temporary None default with conditional warning in __post_init__ to inform users relying on the old defaults about the breaking change.

Key Changes:

  • Changed transpose_batch_sequence type from bool to bool | None with default None in DotProductAttention and MultiHeadAttention
  • Added __post_init__ logic that only warns when users don't explicitly pass the parameter (when value is None)
  • Sets transpose_batch_sequence=False after warning to establish new BSHD default
  • Updated _UnfusedDotProductAttention to default to False directly
  • Removed explicit transpose_batch_sequence=False from quickstart notebook examples since it's now the default
  • Added TODO comments to reset to bool with False default in v2.12

Previous Feedback Addressed:
The latest commit (76eafb7) adds if self.transpose_batch_sequence is None: checks before warnings, ensuring warnings only fire when users rely on defaults rather than unconditionally on every instantiation.

Confidence Score: 4/5

  • This PR is safe to merge with the caveat that it's an intentional breaking change requiring user code updates
  • The implementation correctly addresses previous feedback by conditionally warning only when users rely on defaults. The approach using None as a sentinel value to detect default reliance is appropriate for this transitional period. Documentation is updated consistently. Score is 4 rather than 5 because: (1) this is a breaking change that will affect existing user code, (2) mutating dataclass attributes in __post_init__ works in Flax but is unconventional, and (3) warnings may still appear multiple times if users instantiate modules repeatedly in training loops, though Python's default warning behavior should deduplicate them.
  • No files require special attention - the changes are well-implemented and consistent
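The deduplication behavior mentioned in the last bullet can be checked with the standard warnings module alone (a minimal demo, independent of TE): under the stock "default" filter action, a warning raised repeatedly from the same source line is reported only once.

```python
import warnings

# With the "default" action, repeated warnings from the same
# (module, line number) location are suppressed after the first one.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")
    for _ in range(3):
        warnings.warn("transpose_batch_sequence default changed", FutureWarning)

print(len(caught))  # → 1: the repeated warning is recorded only once
```

So even a module instantiated inside a training loop should emit the deprecation warning once per location, not once per step.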

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/flax/transformer.py 4/5 Changes default for transpose_batch_sequence from True to None (resolving to False) in DotProductAttention and MultiHeadAttention, with conditional warnings when users rely on defaults. Addresses previous feedback by adding None checks before warnings.
docs/examples/quickstart_jax.ipynb 5/5 Removes explicit transpose_batch_sequence=False parameter from DotProductAttention calls since it's now the default. Clean documentation update that aligns with API changes.

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant MHA as MultiHeadAttention
    participant PostInit as __post_init__

    Note over User,PostInit: Scenario 1: User relies on defaults (no explicit arg)
    User->>DPA: DotProductAttention(num_heads=8)
    Note right of DPA: transpose_batch_sequence=None
    DPA->>PostInit: __post_init__()
    PostInit->>PostInit: Check: transpose_batch_sequence is None?
    PostInit->>PostInit: Yes → Emit warning
    PostInit->>PostInit: Set transpose_batch_sequence=False
    PostInit-->>User: Instance created (BSHD format)

    Note over User,PostInit: Scenario 2: User explicitly sets True
    User->>DPA: DotProductAttention(transpose_batch_sequence=True)
    Note right of DPA: transpose_batch_sequence=True
    DPA->>PostInit: __post_init__()
    PostInit->>PostInit: Check: transpose_batch_sequence is None?
    PostInit->>PostInit: No → Skip warning
    PostInit-->>User: Instance created (SBHD format, no warning)

    Note over User,PostInit: Scenario 3: User explicitly sets False
    User->>MHA: MultiHeadAttention(transpose_batch_sequence=False)
    Note right of MHA: transpose_batch_sequence=False
    MHA->>PostInit: __post_init__()
    PostInit->>PostInit: Check: transpose_batch_sequence is None?
    PostInit->>PostInit: No → Skip warning
    PostInit-->>User: Instance created (BSHD format, no warning)

@greptile-apps greptile-apps Bot left a comment

2 files reviewed, 2 comments


Comment on lines +543 to +549
def __post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.13
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"
" TransformerEngine v2.10"
)
super().__post_init__()
logic: Warning fires unconditionally on every DotProductAttention instantiation, even when users haven't relied on the old default. Should only warn when user's code would break (when they didn't explicitly set transpose_batch_sequence and need the old behavior).

Suggested change
def __post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.13
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"
" TransformerEngine v2.10"
)
super().__post_init__()
def __post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.13
# Only warn if behavior might have changed for existing code
if not hasattr(self, '_transpose_batch_sequence_explicitly_set'):
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"
" TransformerEngine v2.10. Explicitly set transpose_batch_sequence=True to"
" preserve old behavior.",
FutureWarning,
stacklevel=2
)
super().__post_init__()

@KshitijLakhani (Collaborator, Author) replied:

Using a FutureWarning and stacklevel=2 makes sense to me, so I will add them.
Finding a way to warn users only if transpose_batch_sequence has been explicitly set seems appealing and was considered, but it may require relying on an additional flag such as _transpose_batch_sequence_explicitly_set as mentioned above, or maybe I could do something like this (which also isn't very appealing):

  • Set the transpose_batch_sequence default to None/a sentinel object in DPA
  • Check whether transpose_batch_sequence is still None/the sentinel object; if so, set it to the default value False, else use the provided value
  • If the default was used, warn; else do not warn

This would probably also leave the auto-generated docs with a not-so-clear picture of the default.
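The sentinel-object variant described in the bullets above could look roughly like this (a hedged sketch on a plain dataclass; the class name is hypothetical, and the real Flax module is not shown):

```python
import warnings
from dataclasses import dataclass
from typing import Any

_UNSET: Any = object()  # module-level sentinel, distinct from None and any bool

@dataclass
class DPAConfigSketch:
    """Hypothetical stand-in, not the actual TE class."""
    transpose_batch_sequence: Any = _UNSET

    def __post_init__(self):
        if self.transpose_batch_sequence is _UNSET:
            # User relied on the default: warn, then apply the new default.
            warnings.warn(
                "transpose_batch_sequence now defaults to False (BSHD)",
                FutureWarning,
                stacklevel=2,
            )
            self.transpose_batch_sequence = False
```

The documentation downside noted in the thread is visible here: an autodoc tool would render the default as something like `<object object at 0x...>` rather than a meaningful value, which is one reason None was ultimately preferred as the sentinel.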

@KshitijLakhani (Collaborator, Author) commented Nov 19, 2025

@mgoldfarb-nvidia your thoughts on this? I think it's unnecessary complexity (and a less clear solution) for code that will be thrown away in the future (a couple of releases down the line at most, although Charlene thinks one is enough, as we'll be adding a note for this in the release notes), even though I agree that the warnings could be an unnecessary nuisance for users.
The WAR does not seem too elegant. I'm inclined to bite the bullet on this one but would like your thoughts.
Thanks!

@KshitijLakhani (Collaborator, Author) commented:

UPDATE: I added a commit to set transpose_batch_sequence to None as the "default"; in __post_init__ I then set the true default to False and emit a warning. This ensures less spam for an unintended audience.

Comment thread transformer_engine/jax/flax/transformer.py Outdated
@KshitijLakhani KshitijLakhani force-pushed the klakhani/main/set-bshd-default-mha-and-dpa branch from efe1a7f to e72c901 Compare November 20, 2025 19:26
@KshitijLakhani (Collaborator, Author)

/te-ci jax

Comment thread docs/examples/quickstart_jax.ipynb Outdated
@@ -589,7 +588,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
Collaborator

I think these notebook changes are just from running this notebook to test it, right? Lmk if these changes are anything functional, otherwise let's remove from this PR. Thanks!

@KshitijLakhani (Collaborator, Author) replied:

Thanks for catching it
My bad. Fixed it!

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…instead of SBHD

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…or transpose_batch_sequence

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/main/set-bshd-default-mha-and-dpa branch from 5a45484 to 76eafb7 Compare November 21, 2025 00:58
@greptile-apps greptile-apps Bot left a comment

2 files reviewed, no comments


@KshitijLakhani (Collaborator, Author)

/te-ci jax

@KshitijLakhani (Collaborator, Author)

CI pipeline passes: 38908560 (jobs had to be re-run)

@jberchtold-nvidia (Collaborator)

LGTM pending CI, thanks!

I checked and MaxText is already setting transpose_batch_sequence to False explicitly so no changes will be needed there: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/attention_op.py#L1400

@KshitijLakhani KshitijLakhani merged commit beed55b into NVIDIA:main Nov 21, 2025
10 of 13 checks passed
@KshitijLakhani (Collaborator, Author)

@ksivaman just to bring to your attention for release documentation / breaking change documentation

KshitijLakhani added a commit that referenced this pull request Nov 23, 2025
* Make BSHD default for Unfused DPA, DPA and MHA in TE JAX

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove explicit transpose_batch set for BSHD for DPA in JAX quickstart

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add warnings in DPA and MHA to warn users of change defaults to BSHD instead of SBHD

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Minimize the scope of when to trigger warnings for changed defaults for transpose_batch_sequence

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
