[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls #2392
Conversation
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
/te-ci jax

1 similar comment

/te-ci jax
Greptile Overview

Greptile Summary

This PR standardizes the default tensor layout to BSHD (batch, sequence, heads, dimension) across TE JAX attention modules by changing the default value of transpose_batch_sequence.

Key Changes:

- transpose_batch_sequence now defaults to False (BSHD layout) in DotProductAttention, Unfused DPA and MultiHeadAttention, matching the existing BSHD defaults in TransformerLayer and Fused DPA.
- A warning is emitted in __post_init__ when the user relies on the default instead of setting the flag explicitly.
- TE JAX tutorials are updated to reflect the new default.
Previous Feedback Addressed

Confidence Score: 4/5

Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant DPA as DotProductAttention
participant MHA as MultiHeadAttention
participant PostInit as __post_init__
Note over User,PostInit: Scenario 1: User relies on defaults (no explicit arg)
User->>DPA: DotProductAttention(num_heads=8)
Note right of DPA: transpose_batch_sequence=None
DPA->>PostInit: __post_init__()
PostInit->>PostInit: Check: transpose_batch_sequence is None?
PostInit->>PostInit: Yes → Emit warning
PostInit->>PostInit: Set transpose_batch_sequence=False
PostInit-->>User: Instance created (BSHD format)
Note over User,PostInit: Scenario 2: User explicitly sets True
User->>DPA: DotProductAttention(transpose_batch_sequence=True)
Note right of DPA: transpose_batch_sequence=True
DPA->>PostInit: __post_init__()
PostInit->>PostInit: Check: transpose_batch_sequence is None?
PostInit->>PostInit: No → Skip warning
PostInit-->>User: Instance created (SBHD format, no warning)
Note over User,PostInit: Scenario 3: User explicitly sets False
User->>MHA: MultiHeadAttention(transpose_batch_sequence=False)
Note right of MHA: transpose_batch_sequence=False
MHA->>PostInit: __post_init__()
PostInit->>PostInit: Check: transpose_batch_sequence is None?
PostInit->>PostInit: No → Skip warning
PostInit-->>User: Instance created (BSHD format, no warning)
```
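For context on the two layouts in the diagram, a small illustrative JAX snippet (shape values are made up; the mapping of transpose_batch_sequence to layout follows the diagram above):

```python
import jax.numpy as jnp

batch, seqlen, num_heads, head_dim = 2, 128, 8, 64

# BSHD layout (new default, transpose_batch_sequence=False): [batch, seq, heads, dim]
q_bshd = jnp.zeros((batch, seqlen, num_heads, head_dim))

# SBHD layout (old default, transpose_batch_sequence=True): [seq, batch, heads, dim]
q_sbhd = jnp.zeros((seqlen, batch, num_heads, head_dim))

assert q_bshd.shape == (2, 128, 8, 64)
assert q_sbhd.shape == (128, 2, 8, 64)
```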
2 files reviewed, 2 comments
```python
def __post_init__(self):
    # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.13
    warnings.warn(
        "transpose_batch_sequence defaults to False in DotProductAttention starting"
        " TransformerEngine v2.10"
    )
    super().__post_init__()
```
logic: Warning fires unconditionally on every DotProductAttention instantiation, even when users haven't relied on the old default. Should only warn when user's code would break (when they didn't explicitly set transpose_batch_sequence and need the old behavior).
Suggested change:

```diff
 def __post_init__(self):
     # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.13
-    warnings.warn(
-        "transpose_batch_sequence defaults to False in DotProductAttention starting"
-        " TransformerEngine v2.10"
-    )
+    # Only warn if behavior might have changed for existing code
+    if not hasattr(self, '_transpose_batch_sequence_explicitly_set'):
+        warnings.warn(
+            "transpose_batch_sequence defaults to False in DotProductAttention starting"
+            " TransformerEngine v2.10. Explicitly set transpose_batch_sequence=True to"
+            " preserve old behavior.",
+            FutureWarning,
+            stacklevel=2
+        )
     super().__post_init__()
```
Using a FutureWarning and stacklevel=2 makes sense to me, so I will add them.

Finding a way to warn users only if transpose_batch_sequence has not been explicitly set seems appealing and was considered, but it may require relying on an additional flag such as _transpose_batch_sequence_explicitly_set as mentioned above. Alternatively I could do something like this (which also isn't very appealing):

- Set the transpose_batch_sequence default to None (or a sentinel object) in DPA.
- In DPA, check whether transpose_batch_sequence is still None/the sentinel; if so, set it to the default value False, otherwise use the provided value.
- Warn only when the default was used; otherwise do not warn.

This would probably leave the auto-generated docs with a less-than-clear default value.
@mgoldfarb-nvidia your thoughts on this? I think it's unnecessary complexity (and a less clear solution) for code that will be thrown away in the future (a couple of releases down the line at most, although Charlene thinks one is enough since we'll be adding a note for this in the release notes), even though I agree that the warnings could be an unnecessary nuisance for users.

The WAR does not seem too elegant. I'm inclined to bite the bullet on this one but would like your thoughts.

Thanks!
UPDATE: I added a commit to set transpose_batch_sequence to None as the "default", and in __post_init__ I then set the true default to False and emit a warning. This ensures less spamming of an unintended audience.
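For reference, a minimal sketch of that sentinel approach, using a plain dataclass stand-in (the real TE JAX module is a flax dataclass with many more fields, so the class name and exact wording here are illustrative only):

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class DotProductAttentionSketch:
    """Illustrative stand-in for the real TE JAX module (hypothetical name)."""

    # None acts as a sentinel meaning "the user did not pass the argument".
    transpose_batch_sequence: Optional[bool] = None

    def __post_init__(self):
        if self.transpose_batch_sequence is None:
            # Only users who relied on the old implicit default see the warning.
            warnings.warn(
                "transpose_batch_sequence defaults to False (BSHD) starting"
                " TransformerEngine v2.10; pass it explicitly to silence this warning.",
                FutureWarning,
                stacklevel=2,
            )
            self.transpose_batch_sequence = False
```

With this, DotProductAttentionSketch() warns and resolves to BSHD, while DotProductAttentionSketch(transpose_batch_sequence=True) keeps SBHD silently.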
Force-pushed from efe1a7f to e72c901
/te-ci jax
```diff
@@ -589,7 +588,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
```
I think these notebook changes are just from running this notebook to test it, right? Lmk if these changes are anything functional, otherwise let's remove from this PR. Thanks!
Thanks for catching it
My bad. Fixed it!
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…instead of SBHD Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…or transpose_batch_sequence Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Force-pushed from 5a45484 to 76eafb7
for more information, see https://pre-commit.ci
/te-ci jax
CI pipeline passes: 38908560 (jobs had to be re-run)
LGTM pending CI, thanks! I checked and MaxText is already setting transpose_batch_sequence to False explicitly, so no changes will be needed there: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/attention_op.py#L1400
@ksivaman just to bring to your attention for release documentation / breaking change documentation |
* Make BSHD default for Unfused DPA, DPA and MHA in TE JAX

  Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove explicit transpose_batch set for BSHD for DPA in JAX quickstart

  Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add warnings in DPA and MHA to warn users of change defaults to BSHD instead of SBHD

  Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Minimize the scope of when to trigger warnings for changed defaults for transpose_batch_sequence

  Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
Maintain consistency across layers by setting BSHD as the default layout for DPA, Unfused DPA, and MHA, matching the already existing BSHD defaults in TransformerLayer and Fused DPA.
Type of change
Changes
MHA, DPA (and Unfused DPA) API changes to accept BSHD as default
TE JAX tuts changed to reflect the same
In the user-facing APIs, DPA and MHA, I have set the default to None, and I set transpose_batch_sequence to False in __post_init__ if it is still None.
I am basically using the None default as a way to tell __post_init__ whether the user explicitly passed the param or not. Three cases can arise (see the sketch below):

- The user passes nothing: the default resolves to False (BSHD) and a warning is emitted.
- The user explicitly passes transpose_batch_sequence=False: BSHD is used, no warning.
- The user explicitly passes transpose_batch_sequence=True: SBHD is used, no warning.
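A rough usage sketch of the three cases; the import path and constructor arguments below follow the Greptile sequence diagram above and are assumptions rather than the exact API:

```python
from transformer_engine.jax.flax import DotProductAttention  # import path assumed

# Case 1: rely on the default -> BSHD layout, a warning is emitted
attn_default = DotProductAttention(num_heads=8)

# Case 2: explicitly pass False -> BSHD layout, no warning
attn_bshd = DotProductAttention(num_heads=8, transpose_batch_sequence=False)

# Case 3: explicitly pass True -> old SBHD layout, no warning
attn_sbhd = DotProductAttention(num_heads=8, transpose_batch_sequence=True)
```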
Checklist: