
Conversation

@sudhakarsingh27
Collaborator

Description

Follow-up to #2584 to add and verify support for "deterministic" FP8 DPA/MHA cuDNN attention kernels.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Plumb the deterministic argument through fused_attn_fp8.cu.
  • Adjust filters in pytorch/attention/dot_product_attention/utils.py to allow FP8 + deterministic kernels on SM100 (see the sketch after this list).
  • Edit tests in test_attention.py to check FP8 with deterministic=True.
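For the utils.py bullet above, a minimal sketch of the kind of architecture gate involved; the function name and structure are hypothetical illustrations, not the actual filter code in transformer_engine/pytorch/attention/dot_product_attention/utils.py:

# Hedged sketch of the SM100+ gate described above; names are hypothetical.
import torch

def fp8_deterministic_allowed(fp8: bool, deterministic: bool) -> bool:
    if not (fp8 and deterministic):
        return True  # no extra restriction applies
    # FP8 deterministic cuDNN attention kernels are only enabled on SM100
    # (Blackwell, compute capability 10.0) and newer.
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (10, 0)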

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 self-assigned this Jan 24, 2026
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Jan 24, 2026

Greptile Summary

Extends deterministic attention support from FP16/BF16 (PR #2584) to FP8 data types on the SM100 (Blackwell) architecture. The implementation threads the deterministic parameter through the backward-pass call chain from the tests to the cuDNN backend, enables FP8 deterministic kernels in the backend selection logic for SM100+, and removes the hardcoded NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 from the tests to allow proper parameterization.

Key changes:

  • Plumbed deterministic parameter through C++ API (fused_attn.cpp) to FP8 CUDA implementation (fused_attn_fp8.cu)
  • Updated backend filter in utils.py to allow FP8 deterministic on SM100+ (previously disabled on all architectures)
  • Modified tests to parameterize the deterministic flag instead of hardcoding non-deterministic behavior (see the sketch after this list)
  • Calls cuDNN's set_deterministic_algorithm() when the cuDNN runtime version is >= 9.19.0
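For the test-side change, a hedged sketch of what the parameterization could look like; run_dpa_fp8_vs_f16 is a hypothetical helper standing in for the real FP8-vs-F16 comparison in test_attention.py:

# Hedged sketch of the parameterization; the helper below is illustrative only.
import pytest

@pytest.mark.parametrize("deterministic", [False, True])
def test_dpa_fp8_vs_f16(deterministic):
    # Previously the module hardcoded NVTE_ALLOW_NONDETERMINISTIC_ALGO=1;
    # passing the flag through lets the deterministic FP8 backward path on
    # SM100+ be exercised as well.
    run_dpa_fp8_vs_f16(deterministic=deterministic)  # hypothetical helper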

The version check at fused_attn_fp8.cu:2213 may need verification: it uses 91900 (cuDNN 9.19.0), while the related PR #2584 mentions a 9.18.1+ requirement.

Confidence Score: 4/5

  • Safe to merge with minor version requirement clarification recommended
  • Implementation correctly threads the deterministic parameter through the call stack and follows the established pattern from PR #2584 ([Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell). Backend filtering logic properly restricts FP8 deterministic to SM100+. However, the version check uses 91900 (cuDNN 9.19.0) instead of the 91810 (9.18.1) mentioned in the PR description, warranting clarification before merge.
  • Check transformer_engine/common/fused_attn/fused_attn_fp8.cu:2213 for correct cuDNN version requirement

Important Files Changed

  • tests/pytorch/attention/test_attention.py: Added a deterministic parameter to the FP8 attention tests, replacing the hardcoded environment variable with parameterized test values
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: Added a deterministic parameter to the FP8 backward implementation; sets the cuDNN deterministic-algorithm flag for versions >= 9.19.0
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Updated the backend selection filter to allow FP8 deterministic kernels on SM100+ (Blackwell)
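For context on how a run opts into the deterministic path these files wire up, a hedged usage sketch; it assumes, as the summary above implies, that setting NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 requests deterministic kernels, and the layer configuration is illustrative only:

# Hedged usage sketch: request deterministic kernels before running attention
# so the backend selection in utils.py sees deterministic=True.
import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # assumption: 0 requests determinism

import transformer_engine.pytorch as te

# Illustrative layer configuration; FP8 execution would additionally run under
# te.fp8_autocast on supported hardware.
attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)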

Sequence Diagram

sequenceDiagram
    participant Test as test_mha_fp8_vs_f16/<br/>test_dpa_fp8_vs_f16
    participant Utils as utils.py:<br/>get_attention_backend
    participant CPP as fused_attn.cpp:<br/>nvte_fused_attn_bwd*
    participant FP8 as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd
    participant Impl as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd_impl_v1
    participant cuDNN as cuDNN Backend

    Note over Test: deterministic param<br/>added to tests
    Test->>Utils: check backend with<br/>deterministic flag
    Note over Utils: Filter: Allow FP8+deterministic<br/>only on SM100+ (arch >= 10.0)
    Utils-->>Test: backend available
    
    Test->>CPP: nvte_fused_attn_bwd_*<br/>(deterministic)
    Note over CPP: QKV packed/<br/>KV packed/<br/>separate paths
    CPP->>FP8: fused_attn_fp8_bwd<br/>(deterministic)
    FP8->>Impl: fused_attn_fp8_bwd_impl_v1<br/>(deterministic)
    
    Note over Impl: Check cuDNN version
    alt cudnn_runtime_version >= 91900
        Impl->>cuDNN: set_deterministic_algorithm<br/>(deterministic)
    else version < 91900
        Note over Impl: deterministic flag ignored
    end
    
    Impl->>cuDNN: execute backward pass
    cuDNN-->>Impl: gradients
    Impl-->>FP8: return
    FP8-->>CPP: return
    CPP-->>Test: return

greptile-apps bot left a comment

5 files reviewed, 1 comment

Comment on lines +2213 to +2215
if (cudnn_runtime_version >= 91900) {
  sdpa_backward_options.set_deterministic_algorithm(deterministic);
}

logic: Version check uses 91900 (cuDNN 9.19.0), but related PR #2584 and description mention 9.18.1+ requirement. Should this be 91810 instead?

Suggested change

- if (cudnn_runtime_version >= 91900) {
+ if (cudnn_runtime_version >= 91810) {
    sdpa_backward_options.set_deterministic_algorithm(deterministic);
  }

Is there a specific reason FP8 requires cuDNN 9.19.0+ while FP16/BF16 only needs 9.18.1+?
