# Attention Is All You Need!

The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. A typical attention mechanism looks like this, where the pre-softmax operations can be scaling, bias and/or masking, and the post-softmax operation is usually dropout.

<figure align="center">
<img src="attn.png" width="70%">
<figcaption> Figure 1: Dot product attention. </figcaption>
</figure>
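
As a rough illustration of Figure 1, the following is a minimal, framework-free sketch of single-head dot product attention on nested Python lists. It is not one of the Transformer Engine backends; the function names are hypothetical, dropout is omitted, and every row is assumed to have at least one unmasked element.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(q, k, v, scale=None, bias=None, mask=None):
    """Single-head attention on nested lists (illustrative only).

    q: [s_q][d], k: [s_kv][d], v: [s_kv][d_v]
    bias: optional [s_q][s_kv] values added after scaling
    mask: optional [s_q][s_kv] booleans, True = mask the element out
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    out = []
    for i, q_i in enumerate(q):
        scores = []
        for j, k_j in enumerate(k):
            # Pre-softmax operations: scaling, bias, masking.
            s = scale * sum(a * b for a, b in zip(q_i, k_j))
            if bias is not None:
                s += bias[i][j]
            if mask is not None and mask[i][j]:
                s = float("-inf")
            scores.append(s)
        p = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(p_j * v_j[c] for p_j, v_j in zip(p, v))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = [[False, True], [False, False]]
out = dot_product_attention(q, k, v, mask=causal)
```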

[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for dot product attention in each framework is,
- [transformer_engine.pytorch.DotProductAttention](../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)
- [transformer_engine.jax.flax.DotProductAttention](../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)
- [transformer_engine.paddle.DotProductAttention](../api/paddle.rst#transformer_engine.paddle.DotProductAttention)

## 1. Attention Backends

Transformer Engine provides multiple backends for each supported framework. While the framework-native implementations provide a robust baseline, the fused, GPU-optimized backends, such as the flash-attention and cuDNN attention backends for PyTorch, offer higher performance.

The available attention backends in Transformer Engine are listed below, where the framework-native implementations are often referred to as "unfused" and the more optimized backends are "fused" or "flash". We will discuss the difference between the latter two in the next sub-section.

| Framework | Backend (Module Name) | Module Location |
| :-------- | :-------------------- | :-------------- |
| PyTorch   | cuDNN attention (`FusedAttention`)<br> flash-attention (`FlashAttention`)<br> PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py)      |
| JAX       | cuDNN attention (`_FusedDotProductAttention`)<br> JAX-native attention (`_UnfusedDotProductAttention`)                                | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py)   |
| PaddlePaddle    | cuDNN attention (`_te_forward`)<br> PaddlePaddle-native attention (`_pd_forward`)                                                           | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |

### 1.1 Flash vs. Non-Flash

The name "flash attention" comes from the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" [[2]](https://arxiv.org/abs/2205.14135). Note, however, that the flash algorithm proposed in that paper is used not only by our flash-attention backend but also by the cuDNN attention backend (in two of its three sub-backends).

The quadratic complexity of attention presents significant challenges when scaling Transformer models to longer sequences. For a sequence length `N`, the standard attention calculation requires `O(N^2)` time and memory. The flash algorithm was proposed to reduce the memory footprint from `O(N^2)` to `O(N)` and to cut the traffic between global memory and shared memory, greatly improving computational efficiency and memory utilization. Compared to the standard, non-flash algorithm, it employs the following two techniques.

- **Tiling:** Instead of processing the query, key, value tensors in one single step, the flash algorithm decomposes the data into tiles, with the tile size determined by the shared memory size on the hardware. It then computes the softmax one tile at a time before combining all the results together in a separate step. This tiling technique significantly reduces the memory footprint as well as the I/O traffic between global memory and shared memory, a performance bottleneck in the non-flash algorithm.

- **Recomputation:** The flash algorithm stores the softmax normalization factors (linear in `N`) instead of the full softmax matrix (quadratic in `N`). The normalization factors are then used to recalculate the attention scores in the backward pass. Despite the increased computation, the savings on reads and writes between global memory and shared memory still help alleviate the pressure on bandwidth and provide higher runtime efficiency. The memory footprint is reduced to `O(N)` as well.
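
The two techniques above can be sketched for a single query row in plain Python. This is an illustrative sketch of the running-softmax formulation, not the `flash-attn` or cuDNN kernels; the function names and the `tile_size` parameter are hypothetical. The tiled version keeps only a running maximum, a running sum and a running output, and returns the log-sum-exp that a backward pass could use for recomputation.

```python
import math


def attention_row_reference(q_i, k, v, scale):
    """Non-flash reference: materializes the full row of scores."""
    scores = [scale * sum(a * b for a, b in zip(q_i, k_j)) for k_j in k]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    return [sum(p_j * v_j[c] for p_j, v_j in zip(p, v)) / l
            for c in range(len(v[0]))]


def attention_row_tiled(q_i, k, v, scale, tile_size):
    """Process K/V tile by tile, keeping only O(1) statistics per row."""
    m = float("-inf")          # running max of the scores seen so far
    l = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(v[0])    # running un-normalized output
    for start in range(0, len(k), tile_size):
        k_tile = k[start:start + tile_size]
        v_tile = v[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q_i, k_j)) for k_j in k_tile]
        m_new = max(m, max(scores))
        # Rescale the old statistics to the new max (0.0 for the first tile).
        correction = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(p_j * v_j[c] for p_j, v_j in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    log_sum_exp = m + math.log(l)  # one scalar per row, i.e. linear in N overall
    return [a / l for a in acc], log_sum_exp


q_i = [0.3, -1.2, 0.7]
k = [[0.1 * i, -0.2 * i, 0.05 * i * i] for i in range(7)]
v = [[float(i), 1.0 - i] for i in range(7)]
reference = attention_row_reference(q_i, k, v, 0.5)
tiled, lse = attention_row_tiled(q_i, k, v, 0.5, tile_size=3)
```

Both functions should agree up to floating-point rounding, regardless of the tile size.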

### 1.2 flash-attention

The flash-attention backend is implemented based on the [`flash-attn`](https://github.com/Dao-AILab/flash-attention) package. Written by the same authors as [[2]](https://arxiv.org/abs/2205.14135), `flash-attn` uses the flash algorithm.

`flash-attn` is open-source and thanks to its [performance improvements](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance) over framework-native implementations, it has been widely adopted in the Transformer community.

`flash-attn` offers only PyTorch interfaces and has been integrated into the `transformer_engine.pytorch.attention.FlashAttention` module in Transformer Engine. `FlashAttention` wraps around `flash-attn` and provides some additional functionality, such as converting an `attention_mask` tensor to the cumulative sequence lengths `cu_seqlens` for the `padding` mask type.

Transformer Engine regularly updates its `flash-attn` dependency (`flash-attn` in [setup.py](../../setup.py)). As of v1.7, Transformer Engine supports `flash-attn` 2.0.6+.

### 1.3 cuDNN Attention

The cuDNN attention backend offers another high-performance attention implementation. Developed at NVIDIA, it requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run. cuDNN attention offers multiple sub-backends, of which sub-backends 1 and 2 are based on the flash algorithm, just like flash-attention.

| Sub-Backend |  Algorithm | Precision | Sequence Length | Architecture | Docs |
| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |
| 0 | Non-Flash | BF16/FP16       | <=512       | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |
| 1 | Flash     | BF16/FP16       | Any         | sm80+    | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),<br>[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |
| 2 | Flash     | FP8             | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: Any | cuDNN pre-9.0: sm90<br>cuDNN 9.0+:  sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |

As of cuDNN 9.0 and `flash-attn` 2.4.2, our cuDNN attention and flash-attention backends have the following few differences.

- flash-attention supports PyTorch only, while cuDNN attention supports all three frameworks (see the table in Attention Backends above).
- flash-attention supports BF16 and FP16, while cuDNN attention also supports FP8 (sub-backend 2).
- flash-attention supports `bshd` and `thd` formats without transposes, `sbhd` with transposes, while cuDNN attention supports `bshd`, `sbhd` and `thd` all without transposes. Please see section 3.1 for details on QKV layouts and formats.
- flash-attention does not support `post_scale_bias`, while cuDNN attention does.
- flash-attention supports sliding window attention and paged attention, while cuDNN attention does not.
- flash-attention uses bottom right diagonal for causal mask in cross attention, while cuDNN attention uses top left diagonal.
- flash-attention tends to have a larger performance advantage on Ampere architectures, while cuDNN attention has a larger one on Hopper architectures.

For performance benchmarking, an example script [benchmarks/attention/benchmark_attention.py](../../benchmarks/attention/benchmark_attention.py) is provided. Users can modify the `ModelConfig` in the script to run the configurations of their interest, for example,

```
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, "no_mask",         "no_bias"), # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal",         "no_bias"), # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal", "post_scale_bias"), # bias
    "test_3": ModelConfig(2, 32,  8, 128, 8192, 8192, 0.0,  "causal",         "no_bias"), # GQA
}
```

The script collects the runtimes of flash-attention and cuDNN attention in PyTorch over `num_iters` iterations for each config. The average times and speedups are listed in a table. If a backend is not eligible (i.e. does not support the specific config), its times and speedup are reported as 0.

```
!cd ../../benchmarks/attention/ && python benchmark_attention.py
```

```
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-d934.qdstrm'
Generated:
    /code/fmha/github3/pr-attn-doc/TransformerEngine/benchmarks/attention/prof_test_0.nsys-rep
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-1ba1.qdstrm'
Generated:
    /code/fmha/github3/pr-attn-doc/TransformerEngine/benchmarks/attention/prof_test_1.nsys-rep
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-4ca7.qdstrm'
Generated:
    /code/fmha/github3/pr-attn-doc/TransformerEngine/benchmarks/attention/prof_test_2.nsys-rep
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-1c15.qdstrm'
Generated:
    /code/fmha/github3/pr-attn-doc/TransformerEngine/benchmarks/attention/prof_test_3.nsys-rep
        cuDNN fwd+bwd (ms)  flash-attn fwd+bwd (ms)  cuDNN vs flash speedup
test
```

## 2. Backend Selection

Transformer Engine selects the appropriate backend (and sub-backend) based on the availability and performance of the backends (and sub-backends).

The backend availability is determined by factors such as user input, software version, and GPU architecture. Examples of these factors include (but are not limited to) the sequence length, number of heads, head size, mask type, bias type, training mode, attention type, MQA/GQA or not, `flash-attn`/cuDNN versions, and GPU architecture during runtime. 

When there are multiple backends available, the performance of the backend becomes the deciding factor in the selection logic. The two general rules are, 1) select the fused implementation over unfused, and 2) select the more performant one based on other offline heuristics which are obtained through benchmarking a range of commonly-used model configs.

The selection order is subject to change as we monitor the different backends' performance.

| Framework | Selection Order                                                                                                                              |
| :-------- | :--------------------- |
| PyTorch   | sm90: cuDNN attention > flash-attention > PyTorch-native attention<br>sm80: flash-attention > cuDNN attention > PyTorch-native attention |
| JAX       | cuDNN attention > JAX-native attention |
| PaddlePaddle    | cuDNN attention > PaddlePaddle-native attention |
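
The selection order in the table can be pictured as a priority list that is filtered by availability. The sketch below only mirrors the table; `select_backend` and its arguments are hypothetical names, and Transformer Engine's actual selection logic considers many more factors.

```python
# Priority lists copied from the table above (subject to change).
PYTORCH_ORDER = {
    90: ["cuDNN attention", "flash-attention", "PyTorch-native attention"],
    80: ["flash-attention", "cuDNN attention", "PyTorch-native attention"],
}
OTHER_ORDER = {
    "jax": ["cuDNN attention", "JAX-native attention"],
    "paddle": ["cuDNN attention", "PaddlePaddle-native attention"],
}


def select_backend(framework, available, sm=None):
    """Return the first backend in the priority list that is available.

    framework: "pytorch", "jax" or "paddle" (hypothetical labels)
    available: set of backend names that support the current configuration
    sm: GPU architecture (80 or 90), only used for PyTorch in this sketch
    """
    order = PYTORCH_ORDER[sm] if framework == "pytorch" else OTHER_ORDER[framework]
    for backend in order:
        if backend in available:
            return backend
    raise RuntimeError("no eligible attention backend")


all_pytorch = {"cuDNN attention", "flash-attention", "PyTorch-native attention"}
choice_sm80 = select_backend("pytorch", all_pytorch, sm=80)
choice_sm90 = select_backend("pytorch", all_pytorch, sm=90)
```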

### 2.1 Debug Information

To find out which backend (or sub-backend) is used during runtime, please run with `NVTE_DEBUG=1`. For example,
```
        [DotProductAttention]: using flash-attn 2.4.2
        [DotProductAttention]: using cuDNN attention (backend 0)
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: using cuDNN attention (backend 2)
        [DotProductAttention]: fp8_dpa=False, fp8_mha=False, FP8_BWD_ONLY=True
```
To file a bug, please run with `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` to provide more details about the model configuration. For example,
```
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: dtype=torch.bfloat16, b=2, s_q=2048, s_kv=2048, h_q=16, h_kv=2, d=64, qkv_layout='bshd', mask_type='padding', bias_type='post_scale_bias', bias_shape='1hss', dropout=0.1, is_training=True, context_parallel=True, sm=80, cudnn_version=9.0
```

### 2.2 User Control

Transformer Engine selects the most appropriate backend (and sub-backend) for the user's model configuration and runtime environment. If there is a performance or convergence issue, a few environment variables are provided for users to experiment with different backends (and sub-backends).

The following two environment variables offer control over flash-attention and cuDNN attention in PyTorch.
```
        NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1
        NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1
```

This variable does *not* offer control but lets users express their preference among the cuDNN attention sub-backends. The preferred sub-backend is only used when it is eligible, i.e. when it supports the provided user input and environment.
```
        NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN attention sub-backend when available
```
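
From Python, these variables can be set through `os.environ`. The sketch below assumes they are read when the backend is selected, so setting them at the top of the script, before `transformer_engine` is imported, is the conservative choice.

```python
import os

# Disable flash-attention and keep cuDNN attention enabled.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

# Express a preference for cuDNN attention sub-backend 1 (used only if eligible).
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

# Print which backend is selected at runtime.
os.environ["NVTE_DEBUG"] = "1"
```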

cuDNN attention sub-backend 1 also offers two paths: workspace optimization path and non-workspace optimization path. The workspace optimization path trades memory for performance, i.e. it requires more global memory to run but provides better performance in some cases. This path is only available on Hopper architectures, and it is turned on when the cuDNN-estimated workspace size (`batch_size x seqlen_q x seqlen_kv`) is <= 256MB. Users can control the limit using the following environment variable. Please be aware of the Out-Of-Memory risks when doing so.
```
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
#          or when bias gradient needs to be computed
# -     n: enables workspace optimization when required workspace is <= n bytes
# -    -1: enables workspace optimization always
# -     0: disables workspace optimization always
```
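
The rules listed above can be restated as a small decision function. This is only a reading aid with hypothetical names, not the code inside cuDNN or Transformer Engine, and it assumes that 256MB means `256 * 1024 * 1024` bytes.

```python
DEFAULT_LIMIT_BYTES = 256 * 1024 * 1024  # assumption: 256MB in binary units


def use_workspace_optimization(required_bytes, limit=None, needs_bias_grad=False):
    """Mirror the documented meaning of CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT.

    limit=None stands for the variable being unset.
    """
    if limit is None:
        return required_bytes <= DEFAULT_LIMIT_BYTES or needs_bias_grad
    if limit == -1:
        return True       # always enabled
    if limit == 0:
        return False      # always disabled
    return required_bytes <= limit


small = use_workspace_optimization(64 * 1024 * 1024)
large = use_workspace_optimization(512 * 1024 * 1024)
forced = use_workspace_optimization(512 * 1024 * 1024, limit=-1)
```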

## 3. Backend Support

Transformer Engine's attention backends support the commonly-used features such as self/cross attention types, dropout and FP16/BF16 precisions, but they have varying support for other features. As of Transformer Engine v1.7, the support matrix is,

| Attention Backend | Precision      | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |
| :---------------- | :------------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |
| cuDNN attention<br>(PyTorch, JAX, PaddlePaddle)            | BF16, FP16, FP8  |  sm80+ | No  | Yes | No (`bshd`, `sbhd`, `thd`) | Sub-backend 0, 2: Yes<br>Sub-backend 1: Yes, if<br>workspace optimization path |
| flash-attention<br>(PyTorch)           | BF16, FP16      |  sm80+ | Yes | Yes | Yes (`bshd`)                      | Yes, if `deterministic=True`                                                                                    |
| Framework-native attention<br>(PyTorch, JAX, PaddlePaddle) | BF16, FP16, FP32 |  Any   | No,<br>unless used as a mask  | Yes | No                                  | Yes |


### 3.1 QKV Layout

Transformer Engine supports various memory layouts for the query, key and value tensors. To help categorize the layouts, Transformer Engine has defined 15 QKV layouts, which can be grouped into 3 QKV formats and 5 QKV layout groups.

| qkv_layout        | qkv_layout_group=`3hd` | qkv_layout_group=`h3d` | qkv_layout_group=`hd_2hd` | qkv_layout_group=`hd_h2d` | qkv_layout_group=`hd_hd_hd` |
| :--------------- | :-------------------- | :----- | :---------- | :---------- | :-------------- |
| qkv_format=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |
| qkv_format=`bshd` | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |
| qkv_format=`thd`  | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |

The notation is as follows: `b` is the batch size, `s` the sequence length, `h` the number of heads, `d` the head dimension, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. The following examples illustrate what the layouts mean.

- `sb3hd`: tensors are sequence first; `q`, `k` and `v` are in one memory space; they are interleaved at the `h * d` dimension; `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`.
- `bshd_bsh2d`: tensors are batch first; `q`, `k` and `v` are in two memory spaces `q` and `kv`; `k` and `v` are interleaved at the `d` dimension inside the `kv` space; `q` is contiguous and `k, v = [kv[:,:,:,i,:] for i in range(2)]`. The second `s` can be different from the first `s` in the case of cross attention, and same for `h` when MQA/GQA is employed.
- `thd_thd_thd`: tensors have variable sequence lengths in the batch; `q`, `k` and `v` are in three memory spaces; they are not interleaved in any way and are all contiguous.
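
To make the interleaving concrete, the sketch below computes flat memory offsets for two of the layouts, assuming contiguous row-major buffers of shapes `[s, b, 3, h, d]` and `[b, s, h, 2, d]`. The helper names are hypothetical.

```python
def offset_sb3hd(s_i, b_i, which, h_i, d_i, b, h, d):
    """Flat offset into a contiguous buffer of shape [s, b, 3, h, d].

    which selects the tensor: 0 = q, 1 = k, 2 = v.
    """
    return (((s_i * b + b_i) * 3 + which) * h + h_i) * d + d_i


def offset_bsh2d(b_i, s_i, h_i, which, d_i, s, h, d):
    """Flat offset into a contiguous kv buffer of shape [b, s, h, 2, d].

    which selects the tensor: 0 = k, 1 = v.
    """
    return (((b_i * s + s_i) * h + h_i) * 2 + which) * d + d_i


b, s, h, d = 2, 4, 3, 8

# sb3hd: for one token, q, k and v are consecutive blocks of h * d elements.
q0 = offset_sb3hd(1, 0, 0, 0, 0, b, h, d)
k0 = offset_sb3hd(1, 0, 1, 0, 0, b, h, d)
v0 = offset_sb3hd(1, 0, 2, 0, 0, b, h, d)
stride_sb3hd = (k0 - q0, v0 - k0)

# bsh2d: for one token and one head, k and v are consecutive blocks of d elements.
k1 = offset_bsh2d(0, 1, 2, 0, 0, s, h, d)
v1 = offset_bsh2d(0, 1, 2, 1, 0, s, h, d)
stride_bsh2d = v1 - k1
```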

We group these 15 QKV layouts into 3 QKV formats and 5 QKV layout groups to help simplify the code when multiple layouts share the same properties.
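
The naming scheme in the table is regular enough that the format and the layout group can be derived from the layout string itself. The helper below is a hypothetical sketch based only on the names in the table; it is not a Transformer Engine function.

```python
def split_qkv_layout(qkv_layout):
    """Derive (qkv_format, qkv_layout_group) from a layout name such as 'sbhd_sb2hd'."""
    parts = qkv_layout.split("_")
    # Format: the first tensor's dimensions with the interleave factor removed.
    qkv_format = "".join(ch for ch in parts[0] if not ch.isdigit())
    # Group: each tensor's name with the leading s/b/t dimensions stripped.
    qkv_layout_group = "_".join(p.lstrip("sbt") for p in parts)
    return qkv_format, qkv_layout_group


examples = {name: split_qkv_layout(name)
            for name in ("sb3hd", "bshd_bsh2d", "thd_thd_thd")}
```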


Transformer Engine supports all 15 layouts in PyTorch, and 3 layouts, `bs3hd`, `bshd_bs2hd` and `bshd_bshd_bshd`, in JAX and PaddlePaddle. In PyTorch, the utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) helps users figure out what `qkv_layout` they have.

Transformer Engine, as of v1.7, has the following support matrix for different QKV formats:

| Backend | Supported QKV Formats |
| :--------------- | :-------------------- |
| flash-attention | `bshd`, `sbhd`, `thd` (`sbhd` requires transpose operations) |
| cuDNN attention  | `bshd`, `sbhd`, `thd`  |
| Framework-native attention | `bshd`, `sbhd` (`sbhd` requires transpose operations) |

In PyTorch, when RoPE is employed, the QKV layout may change. If the initial QKV layout is not in the `hd_hd_hd` QKV layout group, `transformer_engine.pytorch.attention._get_qkv_layout` will convert it to the `hd_hd_hd` group. For example, `sbh3d` will be converted to `sbhd_sbhd_sbhd`.

### 3.2 Attention Mask

Transformer Engine supports 5 mask types: `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), and `arbitrary`. All masks follow the same convention: `True` masks the corresponding element out, and `False` includes the element in the attention calculation.

The support matrix for attention mask is:

| Backend          | Supported Mask Types  | Require Mask Tensor |
| :--------------- | :-------------------- | :------------------ |
| flash-attention  | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: no<br>`padding`, `padding_causal`: yes if `cu_seqlens` not provided|
| cuDNN attention  | `no_mask`, `causal`, `padding`, `padding_causal` | No |
| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: no<br>`arbitrary`: yes |

For `padding` and `padding_causal` mask types, an `attention_mask` tensor should be provided for flash-attention. For self-attention, `attention_mask` should be one tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, it should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`.
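
The shapes and the `True = masked out` convention can be illustrated with nested Python lists standing in for tensors. The helper `padding_mask` is hypothetical and only shows the expected structure; in practice the masks would be boolean tensors of the same shapes.

```python
def padding_mask(seqlens, max_seqlen):
    """Build a [batch_size][1][1][max_seqlen] mask; True marks padded positions."""
    return [[[[pos >= n for pos in range(max_seqlen)]]] for n in seqlens]


# Self-attention: one mask of shape [batch_size, 1, 1, seqlen_q].
self_attn_mask = padding_mask([2, 4], max_seqlen=4)

# Cross-attention: a list of two masks, for the query and for the key/value.
cross_attn_mask = [
    padding_mask([2, 4], max_seqlen=4),  # [batch_size, 1, 1, seqlen_q]
    padding_mask([5, 3], max_seqlen=6),  # [batch_size, 1, 1, seqlen_kv]
]
```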

Alternatively, users can pass in `cu_seqlens` tensors. When both `cu_seqlens` and `attention_mask` are passed in, Transformer Engine uses `cu_seqlens`, which saves the extra compute of `get_cu_seqlens()`.
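
The relationship between a padding mask and cumulative sequence lengths can be sketched as follows. This is not Transformer Engine's `get_cu_seqlens()`; it is a plain-Python illustration with a hypothetical helper name, assuming a mask of shape `[batch_size, 1, 1, seqlen]` where `True` means masked out and `cu_seqlens` has `batch_size + 1` entries starting at 0.

```python
from itertools import accumulate


def cu_seqlens_from_mask(mask):
    """Count unmasked positions per sample and accumulate them."""
    lengths = [sum(1 for masked in sample[0][0] if not masked) for sample in mask]
    return [0] + list(accumulate(lengths))


# Three sequences of lengths 2, 4 and 1, padded to 4 tokens.
mask = [
    [[[False, False, True, True]]],
    [[[False, False, False, False]]],
    [[[False, True, True, True]]],
]
cu_seqlens = cu_seqlens_from_mask(mask)
```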

For `qkv_format=thd`, if `max_seqlen_q` and `max_seqlen_kv` are not present, Transformer Engine will extract them from the `q`, `k`, `v` tensors. This may cost a GPU-CPU copy as well as a synchronization operation, so it's recommended that users set `max_seqlen_q` and `max_seqlen_kv` when running with `thd` layouts.
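
When the sequence lengths are already known on the host, the maximum sequence length can be computed there at no GPU cost. A minimal sketch, with a hypothetical helper name and a plain list standing in for `cu_seqlens`:

```python
def max_seqlen_from_cu_seqlens(cu_seqlens):
    """Largest difference between consecutive cumulative sequence lengths."""
    return max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))


cu_seqlens_q = [0, 2, 6, 7]  # three sequences of lengths 2, 4 and 1
max_seqlen_q = max_seqlen_from_cu_seqlens(cu_seqlens_q)
```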

As of v1.7, cuDNN attention does not support `arbitrary` masks. However, users can apply such a mask through the `post_scale_bias` path. An example script that converts the mask to a bias and calls cuDNN attention is [here](./arbitrary_mask_to_bias.py). This path is more performant than the unfused path.
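
The idea behind the mask-to-bias conversion is that adding negative infinity to a score before the softmax has the same effect as masking it out. The sketch below shows this on nested lists; it is not the linked script, and the helper names are hypothetical.

```python
import math


def mask_to_post_scale_bias(mask):
    """Boolean mask (True = masked out) -> additive bias of the same shape."""
    return [[float("-inf") if masked else 0.0 for masked in row] for row in mask]


def softmax(row):
    """Numerically stable softmax; assumes at least one finite entry."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


mask = [[False, True, False],
        [True, False, False]]
scores = [[0.2, 1.5, -0.3],
          [0.9, 0.1, 0.4]]
bias = mask_to_post_scale_bias(mask)
probs = [softmax([s + b for s, b in zip(s_row, b_row)])
         for s_row, b_row in zip(scores, bias)]
```

Masked positions end up with zero probability. Some implementations prefer a large negative finite value over negative infinity, which avoids NaNs when an entire row is masked.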

Since v2.1, `flash-attn` has changed its implementation of the `causal` mask in cross attention (see [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)). Please note that in this case, flash-attention uses the bottom right diagonal while cuDNN attention uses the top left.
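
The difference between the two alignments only shows up when `seqlen_q != seqlen_kv`. The sketch below builds both masks with the `True = masked out` convention; `causal_mask` and its `alignment` argument are hypothetical names used for illustration.

```python
def causal_mask(s_q, s_kv, alignment):
    """Return an [s_q][s_kv] causal mask, True = masked out.

    alignment: "top_left" or "bottom_right" placement of the diagonal.
    """
    offset = 0 if alignment == "top_left" else s_kv - s_q
    return [[j > i + offset for j in range(s_kv)] for i in range(s_q)]


top_left = causal_mask(2, 5, "top_left")
bottom_right = causal_mask(2, 5, "bottom_right")

for name, m in (("top_left", top_left), ("bottom_right", bottom_right)):
    print(name)
    for row in m:
        # Print 1 for elements that take part in attention, 0 for masked ones.
        print(" ".join("0" if masked else "1" for masked in row))
```

For self-attention, where `seqlen_q == seqlen_kv`, the two alignments produce the same mask.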


### 3.3 Attention Bias

Transformer Engine supports 4 bias types: `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes).

| Backend | Bias Type | Bias Shape | Bias Dtype | Architecture |
| :------ | :-------- | :--------- | :--------- | :----------- |
| flash-attention            | `no_bias`, `ALiBi` (with slopes) | NA | ALiBi slopes: FP32 | sm80+ |
| cuDNN attention            | `no_bias`, `post_scale_bias`, `ALiBi` (without slopes) | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward and 1HSS for backward | `post_scale_bias`: same as data dtype<br>ALiBi slopes: FP32 | cuDNN 8.9.6+: sm90<br>cuDNN 9.0+: sm80, 90 |
| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS | `post_scale_bias`: same as data dtype | sm80+ |

flash-attention enables `ALiBi` bias when the user passes in an `alibi_slopes` tensor. These can be the default slopes that come with vanilla ALiBi, or custom slopes from the user. cuDNN attention, on the other hand, supports ALiBi by taking a boolean flag rather than a slopes tensor. As of cuDNN 8.9.6, it only supports vanilla ALiBi calculations.

The framework-native attention backends do not explicitly support ALiBi. However, users can generate a `post_scale_bias` equivalent to the ALiBi bias, for example with `_get_alibi()` in `transformer_engine.pytorch`. An example can be found in the tests.
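
As an illustration of what such a bias looks like, the sketch below builds vanilla ALiBi slopes and the corresponding bias for self-attention on nested lists. It is not Transformer Engine's `_get_alibi()`; the function names are hypothetical, the slope formula assumes a power-of-two number of heads, and the bias uses the absolute distance between query and key positions.

```python
def alibi_slopes(num_heads):
    """Vanilla ALiBi slopes for a power-of-two head count: 2^(-8/n), 2^(-16/n), ..."""
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]


def alibi_bias(num_heads, seqlen):
    """Return an [h][s][s] bias where each head penalizes distance linearly."""
    return [[[-slope * abs(i - j) for j in range(seqlen)] for i in range(seqlen)]
            for slope in alibi_slopes(num_heads)]


slopes = alibi_slopes(8)
bias = alibi_bias(8, 4)  # could be broadcast to a 1HSS-shaped post_scale_bias
```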

### 3.4 FP8 Attention

Transformer Engine supports FP8 attention on the C level () and Python level (). In v1.6, it also added support on the framework level (PyTorch only).

fp8_dpa, fp8_mha
extra states
te.sequencial

### 3.5 Other Features

Users can also mix the forward and backward of cuDNN attention and flash-attention by setting `NVTE_FUSED_ATTN_USE_FAv2_BWD=1`. This helps when the flash-attention backward pass is faster than the cuDNN attention backward pass.