# Attention Is All You Need!

The core idea behind Transformer models is the attention mechanism [[1]](https://arxiv.org/abs/1706.03762). It identifies the correlation between words, focuses on important parts of the sentence, and allows models to more accurately capture patterns in the data and make predictions. [Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports three frameworks: [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). Their dot product attention interfaces are, respectively, [transformer_engine.pytorch.DotProductAttention](../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), [transformer_engine.jax.flax.DotProductAttention](../api/jax.rst#transformer_engine.jax.flax.DotProductAttention) and [transformer_engine.paddle.DotProductAttention](../api/paddle.rst#transformer_engine.paddle.DotProductAttention).

<figure align="center">
<img src="attn.png" width="70%">
<figcaption> Figure 1: The dot product attention workflow, where pre-softmax operations include scaling, bias, masking and post-softmax operations include dropout (all optional). </figcaption>
</figure>

## 1. Attention Backends

Transformer Engine provides multiple backends for each of the supported frameworks. While the framework-native implementations offer a robust baseline, other backends provide more computational performance. The full list of backends and their definitions are as follows.

| Framework | Backend (Module Name) | Module Location |
| :-------- | :-------------------- | :-------------- |
| PyTorch   | cuDNN attention (`FusedAttention`)<br> TriDao attention (`FlashAttention`)<br> Native PyTorch implementation (`UnfusedDotProductAttention`) | [transformer_engine.pytorch.attention](../../transformer_engine/pytorch/attention.py)      |
| JAX       | cuDNN attention (`_FusedDotProductAttention`)<br> Native JAX implementation (`_UnfusedDotProductAttention`)                                | [transformer_engine.jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py)   |
| Paddle    | cuDNN attention (`_te_forward`)<br> Native Paddle implementation (`_pd_forward`)                                                           | [transformer_engine.paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |

### 1.1 Flash vs Non-Flash

The attention calculation has time and memory complexity that is quadratic in the sequence length, which makes it hard for Transformer models to scale up to longer contexts. The flash attention algorithm [[2]](https://arxiv.org/abs/2205.14135) was proposed to improve the compute efficiency and reduce the memory requirement of attention. Compared to the standard, non-flash algorithm, it employs two distinct techniques:
- Tiling: Instead of loading entire tensors at once, flash attention makes several passes over the input. It decomposes the inputs based on the shared memory size, computes the softmax one block at a time, and combines the results in a subsequent step.
- Recomputation: The standard, non-flash algorithm stores the softmax matrix (quadratic in sequence length) to global memory, while flash attention only saves the softmax normalization factors (linear in sequence length) from the forward pass. During the backward pass, it recomputes the softmax matrix block by block using these normalization factors. Even though there is added computation, it happens on-chip, and the read/write savings still outweigh the extra recomputation.

As a result, the flash algorithm runs faster and requires less memory than the non-flash attention algorithm [[2]](https://arxiv.org/abs/2205.14135).
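To make the tiling idea concrete, here is a minimal pure-Python sketch for a single query row. It visits keys and values one block at a time and keeps only a running maximum, a running normalization sum and an unnormalized accumulator, then checks the result against the standard algorithm. The function names and the `block_size` parameter are illustrative assumptions, the `1/sqrt(d)` scaling is omitted for brevity, and this is not how any Transformer Engine backend is implemented.

```python
import math

def attention_row_standard(q, keys, values):
    """Reference: softmax(q . K^T) V for one query row, materializing all scores."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def attention_row_tiled(q, keys, values, block_size=2):
    """Same result, but keys/values are visited one block at a time."""
    dim = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # softmax normalization factor so far
    acc = [0.0] * dim        # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale the partial results computed under the old maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.3, -1.2]
keys = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0], [0.0, 1.0], [2.0, -1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
reference = attention_row_standard(q, keys, values)
tiled = attention_row_tiled(q, keys, values, block_size=2)
```

Only `running_max` and `running_sum` need to survive the forward pass for each row, which is why the saved state grows linearly rather than quadratically with sequence length.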

### 1.2 Public Flash vs cuDNN Flash

The popular public implementation, [flash-attention](https://github.com/Dao-AILab/flash-attention), has evolved significantly from version 1 to version 2. It only supports PyTorch, and Transformer Engine has integrated it into its PyTorch attention module `transformer_engine.pytorch.attention.FlashAttention`. `FlashAttention` is a wrapper around `flash-attention`, and it also provides a few miscellaneous functionalities such as converting an `attention_mask` tensor to cumulative sequence lengths `cu_seqlens`. As of v1.7, Transformer Engine supports `flash-attention` 2.0.6+.

Another high-performance attention implementation is [cuDNN attention](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop), developed at NVIDIA. Even though it is named "fused attention" at the framework level, its sub-backends 1 and 2 (listed below) are also based on the flash algorithm, as `flash-attention` is. cuDNN attention requires [cuDNN](https://developer.nvidia.com/cudnn) and [cudnn-frontend](../../3rdparty/cudnn-frontend) to run, and it has been integrated into Transformer Engine to support all three frameworks (see [Attention Backends](#1.-attention-backends)).

| cuDNN Sub-Backend |  Name | Precision | Sequence Length | Algorithm | Architecture |
| :---------------- | :------------- | :-------------- | :---------- | :----------- | :-------------- |
| 0 | `NVTE_F16_max512_seqlen`       | BF16/FP16       | <=512       | Non-Flash | sm80, 90 |
| 1 | `NVTE_F16_arbitrary_seqlen`    | BF16/FP16       | Any         | Flash     | sm80+    |  
| 2 | `NVTE_FP8`                     | FP8             | cuDNN pre-9.0: <=512<br>cuDNN 9.0+: any | Flash     | cuDNN pre-9.0: sm90<br>cuDNN 9.0+:  sm90+ |
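
The table can be read as a set of eligibility rules. The sketch below transcribes those rows into a small Python function; the function name, its arguments and the string values for `precision` are hypothetical and exist only for illustration. It is not Transformer Engine's actual selection code, which considers more inputs than the table lists.

```python
def eligible_cudnn_sub_backends(precision, max_seqlen, sm, cudnn_version):
    """Return the sub-backend ids whose table row matches the configuration.

    precision: "bf16", "fp16" or "fp8" (illustrative strings)
    sm: compute capability as an integer, e.g. 80 or 90
    cudnn_version: (major, minor) tuple
    """
    eligible = []
    half = precision in ("bf16", "fp16")
    if half and max_seqlen <= 512 and sm in (80, 90):
        eligible.append(0)  # NVTE_F16_max512_seqlen
    if half and sm >= 80:
        eligible.append(1)  # NVTE_F16_arbitrary_seqlen
    if precision == "fp8":
        if cudnn_version >= (9, 0):
            if sm >= 90:
                eligible.append(2)  # NVTE_FP8, any sequence length
        elif sm == 90 and max_seqlen <= 512:
            eligible.append(2)  # NVTE_FP8, cuDNN pre-9.0 limits
    return eligible

long_bf16 = eligible_cudnn_sub_backends("bf16", 2048, 80, (9, 0))
short_fp16 = eligible_cudnn_sub_backends("fp16", 512, 90, (8, 9))
```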

## 2. Backend Selection

Transformer Engine selects the most appropriate attention backend based on user input and backend performance. User input determines whether a backend is eligible, and backend performance heuristics determine which backend to choose when multiple backends are eligible.

To see which input parameters are at play, users can set `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` for their run. Some example outputs are:
```
        [DotProductAttention]: using flash-attn 2.1.0
        [DotProductAttention]: using cuDNN attention (backend 0)
        [DotProductAttention]: dtype=torch.bfloat16, b=2, sq=2048, skv=2048, hq=16, hkv=2, d=64, qkv_layout='bshd', mask_type='padding', bias_type='post_scale_bias', bias_shape='1hss', dropout=0.1, is_training=True, context_parallel=True, sm=80, cudnn_version=9.0
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: dtype=torch.bfloat16, b=2, sq=2048, skv=2048, hq=16, hkv=2, d=64, qkv_layout='bshd', mask_type='padding', bias_type='post_scale_bias', bias_shape='1hss', dropout=0.1, is_training=True, context_parallel=True, sm=80, cudnn_version=9.0
        [DotProductAttention]: using cuDNN attention (backend 2)
        [DotProductAttention]: fp8_dpa=False, fp8_mha=False, FP8_BWD_ONLY=True...
```

On the performance side, fused implementations are usually faster than unfused ones, and our general selection order is as follows.

| Framework | Selection Order                                                                                                                              |
| :--------- | :---------------------------------------------------------------------------------------------------------------------------------------------------- |
| PyTorch   | sm90: cuDNN attention > TriDao attention > PyTorch native implementation<br>sm80: TriDao attention > cuDNN attention > PyTorch native implementation |
| JAX       | cuDNN attention > JAX native implementation                                                                                                          |
| Paddle    | cuDNN attention > Paddle native implementation                                                                                                       |

This order may change as we monitor different backends' performance across various model configurations and on different GPU architectures.
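
The two-stage logic described in this section (filter by eligibility, then pick by preference) can be sketched as follows. The dictionary transcribes the selection-order table; the backend labels (`"cudnn"`, `"flash"`, `"native"`) and the `pick_backend` helper are hypothetical names for illustration, not part of Transformer Engine.

```python
# Preference order transcribed from the selection-order table.
SELECTION_ORDER = {
    ("pytorch", 90): ["cudnn", "flash", "native"],
    ("pytorch", 80): ["flash", "cudnn", "native"],
    ("jax", None): ["cudnn", "native"],
    ("paddle", None): ["cudnn", "native"],
}

def pick_backend(framework, sm, eligible):
    """Return the first backend in the preference order that is eligible."""
    key = (framework, sm) if framework == "pytorch" else (framework, None)
    order = SELECTION_ORDER.get(key)
    if order is None:
        raise ValueError(f"no selection order listed for {framework} on sm{sm}")
    for backend in order:
        if backend in eligible:
            return backend
    raise ValueError("no eligible attention backend")

choice_sm90 = pick_backend("pytorch", 90, {"cudnn", "flash", "native"})
choice_sm80 = pick_backend("pytorch", 80, {"cudnn", "flash", "native"})
```

With all three backends eligible, the sketch returns `"cudnn"` on sm90 and `"flash"` on sm80, matching the table.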

Users do not need to interfere with the backend selection in Transformer Engine, but if there is a performance regression or convergence issue (or simply out of curiosity), these environment variables can be used to turn TriDao attention and cuDNN attention on or off (in PyTorch).
```
        NVTE_FLASH_ATTN=0 # disables flash-attention; default = 1
        NVTE_FUSED_ATTN=0 # disables cuDNN attention; default = 1
```

Users can also *influence* the cuDNN sub-backend selection by setting
```
        NVTE_FUSED_ATTN_BACKEND=0 # or 1, 2
```
Note that this request only takes effect if the requested sub-backend supports the user input and runtime environment. If not, Transformer Engine will still route to the appropriate sub-backend as it would without the request.
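
When launching from Python rather than a shell, the same variables can be set through `os.environ`. The sketch below assumes they should be in place before Transformer Engine is imported and used; the `attention_env` helper is a hypothetical convenience function, not a Transformer Engine API.

```python
import os

def attention_env(flash=True, fused=True, fused_sub_backend=None):
    """Build the environment variables described above as a dict of strings."""
    env = {
        "NVTE_FLASH_ATTN": "1" if flash else "0",
        "NVTE_FUSED_ATTN": "1" if fused else "0",
    }
    if fused_sub_backend is not None:
        if fused_sub_backend not in (0, 1, 2):
            raise ValueError("cuDNN sub-backend must be 0, 1 or 2")
        env["NVTE_FUSED_ATTN_BACKEND"] = str(fused_sub_backend)
    return env

# Example: disable flash-attention and request cuDNN sub-backend 1.
# Assumption: this runs before `import transformer_engine`.
os.environ.update(attention_env(flash=False, fused_sub_backend=1))
```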

## 3. Backend Support

### 3.1 Basic Features

All attention backends support basic features such as self-/cross-attention and dropout, and they differ in a few other key features. 

| Backend | Precision      | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Deterministic |
| :------ | :------------- | :----------- | :----------------------- | :------ | :------------------ | :------------ |
| cuDNN attention            | BF16/FP16/FP8  |  sm80+ | No  | Yes | No (`bshd`, `sbhd`, `thd`) | Sub-backend 0, 2: yes<br>Sub-backend 1: yes if workspace optimization path |
| TriDao attention           | BF16/FP16      |  sm80+ | Yes | Yes | Yes (`bshd`)                      | Yes if `deterministic=True`                                                                                    |
| Framework-native attention | BF16/FP16/FP32 |  Any   | No unless passed in as a mask  | Yes | No                                  | Yes |

For architecture-specific support, please refer to the [cuDNN documentation](https://docs.nvidia.com/deeplearning/cudnn/latest/) for cuDNN attention and the [flash-attention repository](https://github.com/Dao-AILab/flash-attention) for `flash-attention`.

The "workspace optimization path" in the table is a feature in cuDNN sub-backend 1 that allows users to trade memory for performance. It uses `batch_size x seqlen_q x seqlen_kv` more memory, but provides 20-30% better performance (available on Hopper only). By default, if the extra memory requirement falls under 256MB, this path is turned on, unless users disable it. The following environment variable allows for more fine-grained control: 
```
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
#          or when bias gradient needs to be computed
# -     n: enables workspace optimization when required workspace is <= n bytes
# -    -1: enables workspace optimization always
# -     0: disables workspace optimization always
```
If `deterministic=True` is set for Transformer Engine's `DotProductAttention` module, this path will be turned on as well. The non-workspace optimization path for cuDNN sub-backend 1 is non-deterministic. When choosing the workspace optimization path, please be aware of the Out-Of-Memory risk, due to the increased memory requirement.
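
The rules for `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` listed above can be summarized in a short decision function. This is only a sketch of the documented behavior: the function and argument names are hypothetical, "256MB" is assumed to mean 256 MiB, and the `deterministic=True` case is not modeled.

```python
DEFAULT_LIMIT_BYTES = 256 * 1024 * 1024  # assumption: "256MB" means 256 MiB

def workspace_optimization_enabled(required_bytes, limit=None, needs_bias_grad=False):
    """Mirror the documented rules; limit=None stands for an unset variable."""
    if limit is None:
        # Unset: enabled under the default limit, or when a bias gradient is needed.
        return required_bytes <= DEFAULT_LIMIT_BYTES or needs_bias_grad
    if limit == -1:
        return True   # always enabled
    if limit == 0:
        return False  # always disabled
    return required_bytes <= limit

small = workspace_optimization_enabled(64 * 1024 * 1024)
large = workspace_optimization_enabled(512 * 1024 * 1024)
```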

### 3.2 QKV Layout

The query (`q`), key (`k`), value (`v`) tensors that users pass into the `DotProductAttention` module may be in varying memory layouts, and Transformer Engine categorizes them as 15 QKV layouts, 3 QKV formats, or 5 QKV layout groups. Their mapping relationship is as follows.

| qkv_layout          | qkv_layout_group=`3hd` | `h3d`   | `hd_2hd`     | `hd_h2d`     | `hd_hd_hd`       |
| ------------------- | ------------------------ | --------- | -------------- | -------------- | ------------------ |
| qkv_format=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |
| `bshd`            | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |
| `thd`             | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |

Different backends provide different QKV layout support:
- TriDao attention: `bshd`, `sbhd` and `thd` formats (`sbhd` layouts require transposes)
- cuDNN attention: `bshd`, `sbhd` and `thd` formats
- PyTorch-native attention: `bshd` and `sbhd` (`bshd` layouts require transposes)
- JAX-native attention: `bshd` and `sbhd` (transposes?)
- Paddle-native attention: `bshd` and `sbhd` (transposes?)

When RoPE is employed, QKV layout will be automatically converted to the corresponding `hd_hd_hd` layout. For example, `sbh3d` will be converted to `sbhd_sbhd_sbhd`.
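
The naming scheme in the table is regular enough to derive programmatically. The following sketch (helper names are illustrative, not Transformer Engine functions) builds a layout name from a format and a layout group, splits a layout back into both, and applies the RoPE conversion described above.

```python
QKV_FORMATS = ("sbhd", "bshd", "thd")
QKV_LAYOUT_GROUPS = ("3hd", "h3d", "hd_2hd", "hd_h2d", "hd_hd_hd")

def make_layout(qkv_format, layout_group):
    """Prefix every part of the group with the batch/sequence dims of the format."""
    prefix = qkv_format[:-2]  # "sb", "bs" or "t"
    return "_".join(prefix + part for part in layout_group.split("_"))

def split_layout(qkv_layout):
    """Recover (qkv_format, qkv_layout_group) from a layout name."""
    parts = qkv_layout.split("_")
    qkv_format = "".join(c for c in parts[0] if not c.isdigit())
    prefix = qkv_format[:-2]
    layout_group = "_".join(p[len(prefix):] for p in parts)
    return qkv_format, layout_group

def layout_after_rope(qkv_layout):
    """Map any layout to the `hd_hd_hd` layout of the same format."""
    qkv_format, _ = split_layout(qkv_layout)
    return "_".join([qkv_format] * 3)

all_layouts = [make_layout(f, g) for f in QKV_FORMATS for g in QKV_LAYOUT_GROUPS]
```

`all_layouts` contains the 15 names from the table, and `layout_after_rope("sbh3d")` yields `sbhd_sbhd_sbhd`, as in the example above.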

### 3.3 Attention Mask

Transformer Engine supports 5 mask types: `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), and `arbitrary`. All masks are defined as `True - masking the corresponding element out` and `False - including the element in attention calculation`.

The support matrix for attention mask is:
- TriDao attention: `no_mask`, `causal`, `padding`, `padding_causal`
- cuDNN attention: `no_mask`, `causal`, `padding`, `padding_causal`
- Framework-native attention: `no_mask`, `causal`, `arbitrary` (?)

Note that TriDao attention since 2.1 has employed a different causal mask definition than cuDNN attention (see [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)). TriDao attention employs bottom right diagonal and cuDNN attention employs top left diagonal.
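
The two causal alignments only differ when `seqlen_q != seqlen_kv`. The sketch below builds both masks as nested Python lists, following the convention above that `True` means masked out. The `causal_mask` helper and its `align` argument are illustrative names only.

```python
def causal_mask(seqlen_q, seqlen_kv, align="top_left"):
    """Boolean causal mask of shape [seqlen_q][seqlen_kv]; True = masked out."""
    if align not in ("top_left", "bottom_right"):
        raise ValueError("align must be 'top_left' or 'bottom_right'")
    # Bottom-right alignment shifts the diagonal so that the last query
    # row can attend to every key position.
    offset = 0 if align == "top_left" else seqlen_kv - seqlen_q
    return [[j > i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]

top_left = causal_mask(2, 5, "top_left")
bottom_right = causal_mask(2, 5, "bottom_right")
# top_left:     [[F, T, T, T, T], [F, F, T, T, T]]
# bottom_right: [[F, F, F, F, T], [F, F, F, F, F]]
```

For square attention (`seqlen_q == seqlen_kv`) both alignments produce the same mask.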

The `no_mask` and `causal` mask types do not require users to pass in the `attention_mask` tensor, but `padding`, `padding_causal` and `arbitrary` types do, for most of the backends.

For `padding` and `padding_causal`, one `attention_mask` tensor should be passed in with shape `[batch_size, 1, 1, seqlen_q]` when the attention type is self-attention, and two tensors are required for cross-attention, with shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`. For example, (?)

This mask is converted to `cu_seqlens_q` (if self-attention) and `cu_seqlens_kv` (if cross attention). Users can pass in these tensors directly as well, instead of `attention_mask`. Example (?)
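
As an illustration of the conversion, the sketch below builds a padding mask of shape `[batch_size, 1, 1, seqlen]` as nested Python lists (`True` marks padding) and derives cumulative sequence lengths from it. It assumes the `flash-attention` convention that `cu_seqlens` has `batch_size + 1` entries and starts at 0; the helper names are hypothetical, and this is not Transformer Engine's implementation, which operates on GPU tensors.

```python
from itertools import accumulate

def padding_mask(seqlens, max_seqlen):
    """Build a [batch_size, 1, 1, max_seqlen] mask; True marks padding."""
    return [[[[pos >= n for pos in range(max_seqlen)]]] for n in seqlens]

def mask_to_cu_seqlens(mask):
    """Count the unmasked positions per sequence and prefix-sum them."""
    seqlens = [sum(1 for masked in sample[0][0] if not masked) for sample in mask]
    return [0] + list(accumulate(seqlens))

seqlens = [3, 1, 4]
mask = padding_mask(seqlens, max_seqlen=4)
cu_seqlens = mask_to_cu_seqlens(mask)  # [0, 3, 4, 8]
max_seqlen = max(seqlens)              # known on the host, no device sync needed
```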

For the `arbitrary` mask type, an `attention_mask` tensor should be passed in with shape `[batch_size, 1, 1, seqlen_q]` (?) Example.

cuDNN attention does not support the `arbitrary` mask type. However, users can use its `post_scale_bias` path to apply an `arbitrary` mask. An example script is [here](./arbitrary_mask_to_bias.py). This path is more performant than the unfused path.
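
The idea behind the mask-to-bias path is that adding a very negative value to a score before the softmax has the same effect as masking it out. The sketch below demonstrates this on one row in pure Python. It is not the referenced script, the helper names are illustrative, and it uses `-inf` for clarity, whereas a practical implementation may prefer a large finite negative value to avoid NaNs when an entire row is masked.

```python
import math

def mask_to_bias(mask):
    """Map True (masked out) to -inf and False (kept) to 0.0."""
    return [[-math.inf if masked else 0.0 for masked in row] for row in mask]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0]      # already scaled, hence "post-scale" bias
mask = [[False, True, False]]  # mask out the middle element
bias = mask_to_bias(mask)
probs = softmax([s + b for s, b in zip(scores, bias[0])])
```

The masked position receives exactly zero probability, and the remaining probabilities are renormalized over the kept positions.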

When `max_seqlen_q` and `max_seqlen_kv` are not set, Transformer Engine calculates the max based on the `query`, `key` and `value` tensors, or `cu_seqlens_q/kv` (?). This incurs a GPU-CPU memcpy and synchronization. To avoid it, users are encouraged to pass `max_seqlen_q/kv` directly to Transformer Engine.


### 3.4 Attention Bias

Transformer Engine supports 4 bias types: `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes).

| Backend                    | Bias Type | Bias Shape | Bias Dtype | Architecture |
| -------------------------- | --------- | ---------- | ---------- | ------------ |
| TriDao attention           | `no_bias`, `ALiBi` (with/without slopes) need to convert? | NA | FP32 for ALiBi slopes | sm80+ |
| cuDNN attention            | `no_bias`, `post_scale_bias`, `ALiBi` (with/without slopes) need to convert? | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias` forward,<br>and 1HSS for `post_scale_bias` backward | Data dtype for `post_scale_bias`,<br>and FP32 for ALiBi slopes | sm80, 90 for cuDNN 9.0+ |
| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias`, `ALiBi` (with/without slopes) | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias` | Data dtype for `post_scale_bias`,<br>and FP32 for ALiBi slopes | sm80+ |

ALiBi vs ALiBi slopes
can convert by self? get_alibi(), get_alibi_slopes()
Mask + Bias?

### 3.5 FP8 Attention
fp8_dpa, fp8_mha
extra states


### 3.6 Other Features

Users can also mix the forward and backward of cuDNN attention and TriDao attention, by setting `NVTE_FUSED_ATTN_USE_FAv2_BWD=1`. This helps when TriDao attention backward is faster than cuDNN attention backward.
