# Attention Is All You Need!

The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies correlations between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. A typical attention mechanism looks like this, where the pre-softmax operations can be scaling, bias and/or masking, and the post-softmax operation is usually dropout.

<figure align="center">
<img src="attn.png" width="70%">
<figcaption> Figure 1: Dot product attention. </figcaption>
</figure>

[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks: [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). Their APIs are:
- [transformer_engine.pytorch.DotProductAttention](../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)
- [transformer_engine.jax.flax.DotProductAttention](../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)
- [transformer_engine.paddle.DotProductAttention](../api/paddle.rst#transformer_engine.paddle.DotProductAttention)

## 1. Attention Backends

For each supported framework, Transformer Engine provides more than one backend. The framework-native backends provide a robust baseline, while the more fused implementations, such as [flash-attention](https://github.com/Dao-AILab/flash-attention) and [cuDNN attention](https://github.com/NVIDIA/cudnn-frontend), offer higher performance.

A list of the available attention backends is below - note that flash-attention is only available in PyTorch, while cuDNN attention is available in all three frameworks.

| Framework | Backend (Module Name) | Module Location |
| :-------- | :-------------------- | :-------------- |
| PyTorch   | cuDNN attention (`FusedAttention`)<br> flash-attention (`FlashAttention`)<br> PyTorch-native attention (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py)      |
| JAX       | cuDNN attention (`_FusedDotProductAttention`)<br> JAX-native attention (`_UnfusedDotProductAttention`)                                | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py)   |
| PaddlePaddle    | cuDNN attention (`_te_forward`)<br> PaddlePaddle-native attention (`_pd_forward`)                                                           | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |

### 1.1 Flash vs. Non-Flash

The name "flash attention" has become a buzzword in the world of attention implementations. It is important to clarify that, of the backends listed above, flash-attention and cuDNN attention (two of its three sub-backends) are both flash attention, even though their names might suggest otherwise.

The definition of flash attention is based on whether the implementation is using the flash algorithm or not.

It is a well-known challenge to deal with the quadratic time and space complexity of attention: calculating the dot product attention of the query `q`, key `k` and value `v` tensors with sequence length `N` requires `O(N^2)` runtime and on-device memory. As we scale up to longer contexts, training/inference time and memory footprint both increase quadratically with the sequence length.

The standard, non-flash attention algorithm processes the entire `q`, `k`, `v` tensors in one single step and consumes `O(N^2)` memory on the GPU. This was the norm until the more efficient, less memory-demanding flash algorithm came along.
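To get a feel for the quadratic memory growth described above, the following back-of-the-envelope sketch computes the size of the full attention score matrix. The configuration (batch size 1, 16 heads, 2 bytes per element as for BF16/FP16) is an illustrative assumption, not a measured Transformer Engine workload.

```python
def score_matrix_bytes(batch, heads, seqlen, bytes_per_elem=2):
    """Bytes needed to materialize a full [batch, heads, seqlen, seqlen] score matrix."""
    return batch * heads * seqlen * seqlen * bytes_per_elem

# Assumed configuration: batch=1, heads=16, 2 bytes per element (BF16/FP16).
for n in (2048, 8192, 32768):
    gib = score_matrix_bytes(batch=1, heads=16, seqlen=n) / 2**30
    print(f"N={n:>6}: {gib:8.3f} GiB")
```

Quadrupling the sequence length multiplies the size of this matrix by 16, which is what makes materializing it impractical for long contexts.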

The flash algorithm was proposed in [[2]](https://arxiv.org/abs/2205.14135). Compared to the standard, non-flash algorithm, it employs two techniques to improve the memory scaling from quadratic to linear, allowing for a much wider range of sequence lengths in LLMs.

- Tiling. The flash algorithm decomposes the input data into several tiles, with the tile size flexibly determined by the shared memory size on the hardware. It calculates the softmax for each of the tiles and then combines the results together. This reduces the memory footprint as well as the traffic between the global memory and shared memory.

- Recomputation. The flash algorithm only stores the softmax normalization factors (linear in sequence length), instead of the full softmax matrix (quadratic in sequence length). This, again, reduces the number of writes/reads between global memory and shared memory, as well as the amount of global memory required. Some recomputation is needed, however, to reproduce the attention scores in the backward pass. But this is a small price to pay compared to the time and memory savings from the two techniques above.
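The tiling idea above can be sketched in plain Python for a single query row. This is a toy illustration of the running-max/running-sum rescaling that lets softmax be combined across tiles. It is not the flash-attention or cuDNN kernel, and it uses scalar values (head dimension 1) and made-up function names for brevity.

```python
import math

def attention_row_full(scores, values):
    """Reference: softmax over the whole row, then a weighted sum of the values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def attention_row_tiled(scores, values, tile_size):
    """Process the row tile by tile, keeping only a running max, sum and output."""
    running_max = float("-inf")
    running_sum = 0.0
    running_out = 0.0
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        new_max = max(running_max, max(s_tile))
        # Rescale everything accumulated so far to the new maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        running_out *= scale
        for s, v in zip(s_tile, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            running_out += w * v
        running_max = new_max
    return running_out / running_sum

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
values = [1.0, -2.0, 0.5, 4.0, 3.0, -1.0, 2.0]
full = attention_row_full(scores, values)
tiled = attention_row_tiled(scores, values, tile_size=3)
print(math.isclose(full, tiled, rel_tol=1e-9))
```

Only the running maximum and running sum (one pair per query row) need to be kept, which is the linear-size state that the recomputation technique relies on in the backward pass.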

### 1.2 flash-attention

[flash-attention](https://github.com/Dao-AILab/flash-attention) was implemented by the same group of researchers who proposed the flash algorithm [[2]](https://arxiv.org/abs/2205.14135).

It is open-source and has a significant [performance](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance) advantage over the framework-native implementations in the community.

It only supports the PyTorch framework and has been integrated into Transformer Engine in the PyTorch module: `transformer_engine.pytorch.attention.FlashAttention`.

`FlashAttention` is a backend of `transformer_engine.pytorch.DotProductAttention`. It wraps around flash-attention calls and provides a few miscellaneous functionalities, such as converting the `attention_mask` to `cu_seqlens` in the case of a `padding` mask.

Transformer Engine updates its flash-attention dependency (see `flash-attn` in [setup.py](../../setup.py)) as development evolves. As of v1.7, Transformer Engine supports flash-attention 2.0.6+.

### 1.3 cuDNN Attention

[cuDNN attention](https://github.com/NVIDIA/cudnn-frontend) is the backbone of several fused attention backends mentioned above. It is developed at NVIDIA and has competitive (and in many cases, superior) performance compared to flash-attention.

It requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and Transformer Engine has support for it in all three frameworks.

It has gained several sub-backends as cuDNN has evolved. Sub-backends 1 and 2 are both based on the flash algorithm, as flash-attention is.

| Sub-Backend |  Algorithm | Precision | Sequence Length | Architecture | Docs |
| :---------- | :--------- | :-------- | :-------------- | :----------- | :--- |
| 0 | Non-Flash | BF16/FP16       | <=512       | sm80, 90 | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) |
| 1 | Flash     | BF16/FP16       | Any         | sm80+    | [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop),<br>[cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention) |
| 2 | Flash     | FP8             | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: Any | cuDNN pre-9.0: sm90<br>cuDNN 9.0+:  sm90+ | cuDNN 9.0+: [cudnn-frontend](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-fp8) |

To compare cuDNN attention's performance against flash-attention's, you can run the benchmark script after modifying its `ModelConfig`. For example, change into `TransformerEngine/benchmark/attention/`, edit `ModelConfig`, and run `bash ./run.sh`. If the timings are saved to a CSV file such as `timing.csv`, they can be loaded with `pandas.read_csv('timing.csv')` for inspection.


## 2. Backend Selection

Given the various backends and sub-backends, Transformer Engine selects the most appropriate one based on both user input and backend performance heuristics. It first determines whether a backend is eligible based on user input, such as sequence length, number of heads, head size, mask type, and bias type. The runtime environment plays a role here too, since different `flash-attention` or cuDNN versions may support different input parameters. When multiple backends/sub-backends are eligible, the performance heuristics determine which one of them to choose.

Generally, the fused versions of the implementation are more performant than the unfused versions. Also, based on our benchmarks for multiple commonly-used configs, cuDNN attention (sub-backend 1) is faster than `flash-attention` on Hopper architectures and slower on Ampere. Our general selection order is therefore as follows, where "TriDao attention" refers to `flash-attention` (named after its lead author, Tri Dao).

| Framework | Selection Order                                                                                                                              |
| :--------- | :---------------------------------------------------------------------------------------------------------------------------------------------------- |
| PyTorch   | sm90: cuDNN attention > TriDao attention > PyTorch native implementation<br>sm80: TriDao attention > cuDNN attention > PyTorch native implementation |
| JAX       | cuDNN attention > JAX native implementation                                                                                                          |
| Paddle    | cuDNN attention > Paddle native implementation                                                                                                       |

As we monitor the performance of different backends/sub-backends, this order may change.
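The PyTorch rows of the table can be read as a simple "first eligible backend wins" rule. The sketch below only restates the table. The function name and backend labels are hypothetical, and the real selection logic in Transformer Engine also involves the eligibility checks described earlier.

```python
# Preference order per the table above (PyTorch only); labels are illustrative.
SELECTION_ORDER = {
    90: ["cudnn", "flash", "native"],  # sm90 (Hopper)
    80: ["flash", "cudnn", "native"],  # sm80 (Ampere)
}

def pick_backend(sm, eligible):
    """Return the first eligible backend in the preference order for this architecture."""
    for backend in SELECTION_ORDER[sm]:
        if backend in eligible:
            return backend
    raise ValueError("no eligible attention backend")

print(pick_backend(90, {"cudnn", "flash", "native"}))
print(pick_backend(80, {"cudnn", "native"}))
```

In the second call flash-attention is assumed to be ineligible (for example because of an unsupported input), so the next entry in the sm80 order is chosen.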

Usually users do not need to concern themselves with the selection logic. However, if convergence or performance issues arise, or simply out of curiosity, users can set `NVTE_DEBUG=1` to see exactly which backend has been used.
```
        [DotProductAttention]: using flash-attn 2.4.2
        [DotProductAttention]: using cuDNN attention (backend 0)
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: using cuDNN attention (backend 2)
        [DotProductAttention]: fp8_dpa=False, fp8_mha=False, FP8_BWD_ONLY=True
```

If users need to file an issue with Transformer Engine, it's good to run with `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` and include the printed details in the issue as well.
```
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: dtype=torch.bfloat16, b=2, s_q=2048, s_kv=2048, h_q=16, h_kv=2, d=64, qkv_layout='bshd',
                               mask_type='padding', bias_type='post_scale_bias', bias_shape='1hss', dropout=0.1,
                               is_training=True, context_parallel=True, sm=80, cudnn_version=9.0
```

Some environment variables are provided if users need to experiment with different backends:
```
        NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1
        NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1
        NVTE_FUSED_ATTN_BACKEND = 0/1/2 # informs Transformer Engine of user preference for cuDNN sub-backends
```
While `NVTE_FLASH_ATTN` and `NVTE_FUSED_ATTN` forcefully turn a backend on or off, the `NVTE_FUSED_ATTN_BACKEND` environment variable only expresses a user preference. It takes effect only when that sub-backend is an eligible one; if not, Transformer Engine will route the attention calculation to the backend it has determined to be appropriate.
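When experimenting from Python rather than the shell, the same variables can be set through `os.environ`. The helper below is a hypothetical convenience wrapper, not part of Transformer Engine. It assumes the variables are set before Transformer Engine is imported, since this section does not specify when they are read.

```python
import os

def configure_attention_backend(flash=True, fused=True, sub_backend=None, debug=False):
    """Set the NVTE_* variables listed above in the current process environment."""
    os.environ["NVTE_FLASH_ATTN"] = "1" if flash else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if fused else "0"
    if sub_backend is not None:
        if sub_backend not in (0, 1, 2):
            raise ValueError("sub_backend must be 0, 1 or 2")
        # Only a preference: it is honored when that sub-backend is eligible.
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(sub_backend)
    if debug:
        os.environ["NVTE_DEBUG"] = "1"

# Disable flash-attention, prefer cuDNN sub-backend 1, and print the chosen backend.
configure_attention_backend(flash=False, fused=True, sub_backend=1, debug=True)
```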


## 3. Backend Support

Different backends have different support for the commonly-used features in Transformer models. As of v1.7, all backends support self- and cross-attention, dropout, and BF16/FP16 precisions. But they vary in other features.

| Backend | Precision      | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |
| :------ | :------------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |
| cuDNN attention            | BF16/FP16/FP8  |  sm80+ | No  | Yes | No (`bshd`, `sbhd`, `thd`) | Sub-backend 0, 2: yes<br>Sub-backend 1: yes if workspace optimization path |
| TriDao attention           | BF16/FP16      |  sm80+ | Yes | Yes | Yes (`bshd`)                      | Yes if `deterministic=True`                                                                                    |
| Framework-native attention | BF16/FP16/FP32 |  Any   | No unless passed in as a mask  | Yes | No                                  | Yes |

The "workspace optimization" path in the table is a feature in cuDNN sub-backend 1. It trades memory for performance and is turned on by default, when the required workspace size (`batch_size x seqlen_q x seqlen_kv`) is <= 256MB. It provides 20-30% more performance than the non-workspace optimization path, and is deterministic. Users can control it by setting the following environment variable. Please be aware of the Out-Of-Memory risk while doing so.
```
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
#          or when bias gradient needs to be computed
# -     n: enables workspace optimization when required workspace is <= n bytes
# -    -1: enables workspace optimization always
# -     0: disables workspace optimization always
``` 
The non-workspace optimization path for cuDNN sub-backend 1 is non-deterministic, and when `deterministic=True` is set for Transformer Engine's `DotProductAttention` module in PyTorch, it turns on the workspace optimization path as well.
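The rules for `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` listed above can be restated as a small decision function. This is a sketch of the documented behavior only, not the cuDNN or Transformer Engine implementation. The function name is hypothetical, and "256MB" is assumed to mean 256 * 1024 * 1024 bytes.

```python
DEFAULT_LIMIT_BYTES = 256 * 1024 * 1024  # "256MB", assumed here to mean MiB

def workspace_optimization_enabled(required_bytes, limit=None, needs_bias_grad=False):
    """Restate the documented rules; `limit` is the variable's value, or None if unset."""
    if limit is None:
        # Unset: enabled for small workspaces, or when a bias gradient is needed.
        return required_bytes <= DEFAULT_LIMIT_BYTES or needs_bias_grad
    n = int(limit)
    if n == -1:
        return True   # always enabled
    if n == 0:
        return False  # always disabled
    return required_bytes <= n

print(workspace_optimization_enabled(100 * 1024 * 1024))             # unset, small workspace
print(workspace_optimization_enabled(1024**3, limit="0"))            # explicitly disabled
print(workspace_optimization_enabled(1024**3, limit=str(2 * 1024**3)))  # custom limit
```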

### 3.1 QKV Layout

The query (`q`), key (`k`), value (`v`) tensors passed into Transformer Engine may be in various memory layouts. Transformer Engine defines 15 layouts, 3 formats and 5 layout groups to facilitate this calculation. Here, `b` is the batch size, `s` the sequence length, `h` the number of heads, `d` the head dimension, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i in 0,...,b-1`. A few examples of these layouts are,
- `sb3hd`: tensors are sequence first; `q`, `k` and `v` are in one memory space; they are interleaved at the `h * d` dimension; `q, k, v = [qkv[:,:,i,:,:] for i in range(3)]`.
- `bshd_bsh2d`: tensors are batch first; `q`, `k` and `v` are in two memory spaces `q` and `kv`; `k` and `v` are interleaved at the `d` dimension inside the `kv` space; `q` is contiguous and `k, v = [kv[:,:,:,i,:] for i in range(2)]`. The second `s` can be different from the first `s` in the case of cross attention, and same for `h` when MQA/GQA is employed.
- `thd_thd_thd`: tensors have variable sequence lengths in the batch; `q`, `k` and `v` are in three memory spaces; they are not interleaved in any way and are all contiguous.

We group these 15 QKV layouts into 3 QKV formats and 5 QKV layout groups to help simplify the code when multiple layouts share the same properties.

| qkv_layout        | qkv_layout_group=`3hd` | qkv_layout_group=`h3d` | qkv_layout_group=`hd_2hd` | qkv_layout_group=`hd_h2d` | qkv_layout_group=`hd_hd_hd` |
| :--------------- | :-------------------- | :----- | :---------- | :---------- | :-------------- |
| qkv_format=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |
| qkv_format=`bshd` | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |
| qkv_format=`thd`  | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |
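The naming scheme in the table is regular enough that the format and layout group can be derived from the layout string. The helper below is a hypothetical illustration of that naming convention. It is not `_get_qkv_layout` and does not inspect any tensors.

```python
def split_qkv_layout(qkv_layout):
    """Split a layout name such as 'bshd_bsh2d' into (qkv_format, qkv_layout_group)."""
    first = qkv_layout.split("_")[0]
    # The format is the first tensor's dimension order without the interleave count.
    qkv_format = "".join(c for c in first if not c.isdigit())
    # The group drops the batch/sequence/token dimensions and keeps the rest.
    qkv_layout_group = "".join(c for c in qkv_layout if c not in "sbt")
    return qkv_format, qkv_layout_group

for layout in ("sb3hd", "bshd_bsh2d", "thd_thd_thd"):
    print(layout, "->", split_qkv_layout(layout))
```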

Transformer Engine supports all 15 layouts in PyTorch, and 3 layouts, `bs3hd`, `bshd_bs2hd` and `bshd_bshd_bshd`, in JAX and Paddle. In PyTorch, the utility function [transformer_engine.pytorch.attention._get_qkv_layout](../../transformer_engine/pytorch/attention.py) helps users figure out what `qkv_layout` they have.

Transformer Engine, as of v1.7, has the following support matrix for different QKV formats:

| Backend | Supported QKV Formats |
| :--------------- | :-------------------- |
| TriDao attention | `bshd`, `sbhd`, `thd` (`sbhd` requires transpose operations) |
| cuDNN attention  | `bshd`, `sbhd`, `thd`  |
| Framework-native attention | `bshd`, `sbhd` (`sbhd` requires transpose operations) |

In PyTorch, when RoPE is employed, the QKV layout may change. If the initial QKV layout is not in the `hd_hd_hd` QKV layout group, `transformer_engine.pytorch.attention._get_qkv_layout` will convert it to the `hd_hd_hd` group. For example, `sbh3d` will be converted to `sbhd_sbhd_sbhd`.

### 3.2 Attention Mask

Transformer Engine supports 5 mask types: `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), and `arbitrary`. All masks are defined as `True - masking the corresponding element out` and `False - including the element in attention calculation`.

The support matrix for attention mask is:

| Backend          | Supported Mask Types  | Require Mask Tensor |
| :--------------- | :-------------------- | :------------------ |
| TriDao attention | `no_mask`, `causal`, `padding`, `padding_causal` | `no_mask`, `causal`: no<br>`padding`, `padding_causal`: yes if `cu_seqlens` not provided|
| cuDNN attention  | `no_mask`, `causal`, `padding`, `padding_causal` | No |
| Framework-native attention | `no_mask`, `causal`, `arbitrary` | `no_mask`, `causal`: no<br>`arbitrary`: yes |

For `padding` and `padding_causal` mask types, an `attention_mask` tensor should be provided for TriDao attention. For self-attention, `attention_mask` should be one tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, it should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`.
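The mask shapes and the `True` = masked-out convention described above can be illustrated with nested Python lists. In practice these would be boolean framework tensors. The helper name and the sequence lengths below are made up for illustration.

```python
def padding_mask(seqlens, max_seqlen):
    """Build a [batch_size, 1, 1, max_seqlen] nested-list mask; True marks padding."""
    return [[[[pos >= length for pos in range(max_seqlen)]]] for length in seqlens]

# Self-attention: a single mask for two sequences of lengths 3 and 5.
self_attn_mask = padding_mask([3, 5], max_seqlen=5)

# Cross-attention: a list of two masks, one for queries and one for keys/values.
cross_attn_mask = [
    padding_mask([3, 5], max_seqlen=5),  # [batch_size, 1, 1, seqlen_q]
    padding_mask([2, 4], max_seqlen=6),  # [batch_size, 1, 1, seqlen_kv]
]

print(self_attn_mask[0][0][0])
```

The first sequence has 3 valid tokens, so its last two positions are `True` (excluded from the attention calculation).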

Alternatively, users can pass in `cu_seqlens` tensors. In the case where `cu_seqlens` and `attention_mask` are both passed in, Transformer Engine will pick `cu_seqlens` to save extra compute from `get_cu_seqlens()`.

For `qkv_format=thd`, if `max_seqlen_q` and `max_seqlen_kv` are not present, Transformer Engine will extract them from the `q`, `k`, `v` tensors. This may cost a GPU-CPU copy as well as a synchronization operation, so it's recommended that users set `max_seqlen_q` and `max_seqlen_kv` when running with `thd` layouts.
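The relationship between sequence lengths, `cu_seqlens`, and the maximum sequence length can be sketched in plain Python. The sketch assumes `cu_seqlens` holds cumulative sequence lengths with a leading 0 (`batch_size + 1` entries), as in flash-attention's variable-length interface. The helper names are hypothetical and this is not `get_cu_seqlens()` itself.

```python
from itertools import accumulate

def cu_seqlens_from_lengths(seqlens):
    """Cumulative sequence lengths with a leading 0, one entry per sequence boundary."""
    return [0] + list(accumulate(seqlens))

def lengths_from_padding_mask(mask):
    """Recover per-sequence lengths from a [batch, 1, 1, seqlen] mask (True = padding)."""
    return [sum(not masked for masked in sample[0][0]) for sample in mask]

seqlens = [3, 5, 2]
cu_seqlens = cu_seqlens_from_lengths(seqlens)
# Computing the maximum up front avoids deriving it from the tensors at runtime.
max_seqlen = max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))
total_tokens = cu_seqlens[-1]  # the `t` dimension of the `thd` format

print(cu_seqlens, max_seqlen, total_tokens)
```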

As of v1.7, cuDNN attention does not support `arbitrary` masks. However, users can try the `post_scale_bias` path to apply the mask. An example script to convert the mask to a bias and call cuDNN attention is [here](./arbitrary_mask_to_bias.py). This path is more performant than the unfused path.
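The idea behind applying a mask through `post_scale_bias` is that an additive bias of negative infinity at masked positions drives their softmax weight to zero. The following is a minimal plain-Python sketch of that idea with made-up helper names and data. It is not the linked script and does not call cuDNN attention.

```python
import math

def mask_to_bias(mask):
    """Convert a boolean mask (True = masked out) to an additive bias."""
    return [[float("-inf") if masked else 0.0 for masked in row] for row in mask]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

scores = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]       # already scaled q.k^T values
mask = [[False, True, False], [False, False, True]]
bias = mask_to_bias(mask)
probs = [
    softmax([s + b for s, b in zip(s_row, b_row)])
    for s_row, b_row in zip(scores, bias)
]
print(probs)
```

A row that is masked out entirely would produce NaNs in this sketch, which is one reason implementations often use a large finite negative value instead of negative infinity.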

Since v2.1, `flash-attention` changed its implementation for `causal` mask in cross attention (see [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)). Please note that in this case, `flash-attention` uses the bottom right diagonal while cuDNN attention uses the top left.
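The difference between the two alignments only shows up when `seqlen_q != seqlen_kv`. The sketch below builds both variants using this document's convention (`True` = masked out) and prints them as 1 (attend) and 0 (masked) for `seqlen_q=2`, `seqlen_kv=5`. The function names are made up for illustration.

```python
def causal_mask(seqlen_q, seqlen_kv, align="top_left"):
    """Return a [seqlen_q][seqlen_kv] mask where True means 'masked out'."""
    offset = 0 if align == "top_left" else seqlen_kv - seqlen_q
    return [[j > i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]

def show(mask):
    """Render 1 for positions that take part in attention and 0 for masked ones."""
    return ["".join("0" if masked else "1" for masked in row) for row in mask]

print(show(causal_mask(2, 5, "top_left")))      # ['10000', '11000']
print(show(causal_mask(2, 5, "bottom_right")))  # ['11110', '11111']
```

For self-attention (`seqlen_q == seqlen_kv`) the offset is zero and the two alignments coincide.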


### 3.3 Attention Bias

Transformer Engine supports 4 bias types: `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes).

| Backend | Bias Type | Bias Shape | Bias Dtype | Architecture |
| :------ | :-------- | :--------- | :--------- | :----------- |
| TriDao attention           | `no_bias`, `ALiBi` (with slopes) | NA | ALiBi slopes: FP32 | sm80+ |
| cuDNN attention            | `no_bias`, `post_scale_bias`, `ALiBi` (without slopes) | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS for forward and 1HSS for backward | `post_scale_bias`: same as data dtype<br>ALiBi slopes: FP32 | cuDNN 8.9.6+: sm90<br>cuDNN 9.0+: sm80, 90 |
| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias` | `post_scale_bias`: BHSS, 1HSS, B1SS, 11SS | `post_scale_bias`: same as data dtype | sm80+ |

TriDao attention enables `ALiBi` bias when the user passes in an `alibi_slopes` tensor. This can be the default slopes that come with vanilla ALiBi, or custom slopes from the user. On the other hand, cuDNN attention supports ALiBi by taking a boolean flag rather than a slopes tensor. As of cuDNN v8.9.6, it only supports vanilla ALiBi calculations.

The framework-native attention backends do not explicitly support ALiBi. However, users can generate a `post_scale_bias` equivalent to the ALiBi bias. An example of this is `_get_alibi()` in `transformer_engine.pytorch`, and an example of its use can be found in the tests.
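As a rough illustration of what such an ALiBi-equivalent bias looks like, the sketch below computes the vanilla ALiBi slopes for a power-of-two number of heads and expands them into a per-head bias. The function names, the symmetric `-slope * |i - j|` form, and the top-left alignment are assumptions made for illustration. This is not the `_get_alibi()` implementation.

```python
def alibi_slopes(num_heads):
    """Vanilla ALiBi slopes for a power-of-two head count: 2^(-8/n), 2^(-16/n), ..."""
    if num_heads < 1 or num_heads & (num_heads - 1) != 0:
        raise ValueError("this sketch only covers power-of-two head counts")
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(num_heads, seqlen_q, seqlen_kv):
    """Build a [num_heads][seqlen_q][seqlen_kv] additive bias: -slope * |i - j|."""
    return [
        [[-slope * abs(i - j) for j in range(seqlen_kv)] for i in range(seqlen_q)]
        for slope in alibi_slopes(num_heads)
    ]

print(alibi_slopes(8))
bias = alibi_bias(num_heads=2, seqlen_q=2, seqlen_kv=3)
print(bias[0])
```

Since the bias does not depend on the batch index, it corresponds to the `1HSS` shape in the table above once a leading batch dimension of 1 is added.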

### 3.4 FP8 Attention

Transformer Engine supports FP8 attention on the C level () and Python level (). In v1.6, it also added support on the framework level (PyTorch only).

fp8_dpa, fp8_mha
extra states
te.sequencial

### 3.5 Other Features

Users can also mix the forward and backward of cuDNN attention and TriDao attention, by setting `NVTE_FUSED_ATTN_USE_FAv2_BWD=1`. This helps when TriDao attention backward is faster than cuDNN attention backward.