# Attention Is All You Need!

The core idea behind Transformer models is the attention mechanism, which calculates the correlations between words, assigns different levels of importance to different words, and allows the model to focus on different parts of the sentence. Transformer Engine implements the attention mechanism for three frameworks:
- [transformer_engine.pytorch.DotProductAttention](../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)
- [transformer_engine.jax.flax.DotProductAttention](../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)
- [transformer_engine.paddle.DotProductAttention](../api/paddle.rst#transformer_engine.paddle.DotProductAttention)

<figure align="center">
<img src="dot_product_attention.png" width="15%">
<figcaption> Figure 1: Dot product attention <a href="https://arxiv.org/abs/1706.03762">[1]</a>. </figcaption>
</figure>

## 1. Supported Backends
Multiple backends are provided for dot product attention calculation in Transformer Engine.

| Framework | Available Backends                                                                                                                          | Files                                                                                              |
| --------- | ------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------- |
| PyTorch   | - cuDNN attention (`FusedAttention`)<br> - TriDao attention (`FlashAttention`)<br> - PyTorch native implementation (`UnfusedDotProductAttention`) | [pytorch.DotProductAttention](../../transformer_engine/pytorch/attention.py)      |
| JAX       | - cuDNN attention (`_FusedDotProductAttention`)<br> - JAX native implementation (`_UnfusedDotProductAttention`)                                | [jax.flax.DotProductAttention](../../transformer_engine/jax/flax/transformer.py)   |
| Paddle    | - cuDNN attention (`_te_forward`)<br> - Paddle native implementation (`_pd_forward`)                                                           | [paddle.DotProductAttention](../../transformer_engine/paddle/layer/attention.py) |

While framework-native implementations provide a robust baseline and wide coverage for various use cases, the other backends aim to improve performance. For example, [TriDao attention](https://github.com/Dao-AILab/flash-attention) and [cuDNN attention](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop) both provide significant speedup compared to the native PyTorch implementation ([here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance) and [here](https://github.com/NVIDIA/cudnn-frontend/tree/main/benchmark)). 

Transformer Engine selects the appropriate backend based on user input, such as sequence length, number of heads, head dimension, QKV layout, attention mask type, bias type, and GPU architecture. With performance in mind, Transformer Engine prefers the fused versions of attention over the unfused. When support is not available, Transformer Engine moves further down the selection chain.

| Framework | Backend Selection Order                                                                                                                              |
| --------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
| PyTorch   | sm90: cuDNN attention > TriDao attention > PyTorch native implementation<br>sm80: TriDao attention > cuDNN attention > PyTorch native implementation |
| JAX       | cuDNN attention > JAX native implementation                                                                                                          |
| Paddle    | cuDNN attention > Paddle native implementation                                                                                                       |
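The PyTorch selection order in the table can be read as a simple fallback chain. The sketch below is illustrative only: `select_backend`, `PREFERENCE` and the `supported` set are hypothetical names rather than Transformer Engine APIs, and the real selection logic considers many more inputs than a set of names.

```python
# Illustrative fallback chain for the PyTorch selection order in the table above.
# `select_backend`, `PREFERENCE` and `supported` are hypothetical names;
# they are not Transformer Engine APIs.

PREFERENCE = {
    "sm90": ["cuDNN attention", "TriDao attention", "PyTorch native implementation"],
    "sm80": ["TriDao attention", "cuDNN attention", "PyTorch native implementation"],
}


def select_backend(arch, supported):
    """Return the first backend in the preference order that supports the inputs."""
    for backend in PREFERENCE[arch]:
        if backend in supported:
            return backend
    raise ValueError(f"no backend supports the given inputs on {arch}")


# When every backend supports the inputs, the choice depends on the architecture.
all_backends = {"cuDNN attention", "TriDao attention", "PyTorch native implementation"}
print(select_backend("sm90", all_backends))
print(select_backend("sm80", all_backends))

# When the fused backends are unavailable, the chain falls through to the native one.
print(select_backend("sm90", {"PyTorch native implementation"}))
```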

For PyTorch, two environment variables control which backends are enabled.
```
export NVTE_FLASH_ATTN=0  # disables TriDao attention; default is 1
export NVTE_FUSED_ATTN=0  # disables cuDNN attention; default is 1
```
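The same variables can also be set from Python. This is a minimal sketch that assumes the variables should be in place before Transformer Engine is imported; the exact point at which they are read is not stated in this document, so setting them first is the cautious choice.

```python
import os

# Assumption: set the variables before importing Transformer Engine, so they are
# visible no matter when the library reads them.
os.environ["NVTE_FLASH_ATTN"] = "0"  # disable TriDao attention
os.environ["NVTE_FUSED_ATTN"] = "1"  # keep cuDNN attention enabled (the default)

# import transformer_engine.pytorch as te  # import only after the environment is set
```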

cuDNN attention also has different versions of the attention kernels to support different precision, sequence length, etc.

| cuDNN Backend           | Precision | Sequence Length | Algorithm | Architecture |
| ----------------------- | --------- | --------------- | --------- | ------------ |
| 0 - `max512_seqlen`       | BF16/FP16       | <=512           | Non-Flash | sm80, 90 |
| 1 - `arbitrary_seqlen`    | BF16/FP16       | Any             | Flash     | sm80+    |  
| 2 - `FP8` + cuDNN pre-9.0 | FP8       | <=512           | Flash     | sm89, 90+ |
| 2 - `FP8` + cuDNN 9.0+    | FP8       | Any             | Flash     | sm89, 90+ |
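The precision and sequence-length columns of the table can be encoded as a small lookup. The helper below, `candidate_cudnn_backends`, is a hypothetical name used for illustration; it ignores the architecture column and every other input that the real selection considers.

```python
# Encodes the precision and sequence-length columns of the table above.
# `candidate_cudnn_backends` is a hypothetical helper, not a Transformer Engine API.

def candidate_cudnn_backends(precision, seqlen, cudnn_major):
    """Return the cuDNN backend numbers whose precision and length limits match."""
    candidates = []
    if precision in ("BF16", "FP16"):
        if seqlen <= 512:
            candidates.append(0)  # max512_seqlen
        candidates.append(1)  # arbitrary_seqlen
    elif precision == "FP8":
        # Before cuDNN 9.0 the FP8 backend is limited to sequences of up to 512.
        if cudnn_major >= 9 or seqlen <= 512:
            candidates.append(2)  # FP8
    return candidates


print(candidate_cudnn_backends("BF16", 256, 9))
print(candidate_cudnn_backends("FP8", 2048, 8))
```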

To control which cuDNN backend is selected, users can set:
```
export NVTE_FUSED_ATTN_BACKEND=0  # selects `max512_seqlen`
export NVTE_FUSED_ATTN_BACKEND=1  # selects `arbitrary_seqlen`
export NVTE_FUSED_ATTN_BACKEND=2  # selects `FP8`
```
Note that these selections are only user preferences. If the selected backend does not support the input combination, Transformer Engine will select a different, more appropriate backend. This applies to `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, and `NVTE_FUSED_ATTN_BACKEND`. To find out which backend is used at runtime, users can set `NVTE_DEBUG=1`, and debugging information will be printed as below:
```
[DotProductAttention]: using flash-attn 1.4.0 # TriDao attention is selected
[DotProductAttention]: using cuDNN attention (backend 0) # cuDNN attention `max512_seqlen` backend is selected
[DotProductAttention]: using cuDNN attention (backend 1) # cuDNN attention `arbitrary_seqlen` backend is selected
[DotProductAttention]: using cuDNN attention (backend 2) # cuDNN attention `FP8` backend is selected
```
This is for PyTorch only; the same debugging information will be added to JAX and Paddle.

## 2. Support Matrix

Some basic feature differences between the supported backends are listed below.

| Framework | Backend                                                           | Precision      | Attention Type | Dropout | Architecture |
| --------- | ----------------------------------------------------------------- | -------------- | -------------- | ------- | ------------ |
| PyTorch   | cuDNN attention (`FusedAttention`)                           | BF16/FP16/FP8  | Self/Cross     | Yes     | sm80+        |
|           | TriDao attention (`FlashAttention`)                          | BF16/FP16      | Self/Cross     | Yes     | sm80+        |
|           | PyTorch native implementation (`UnfusedDotProductAttention`) | BF16/FP16/FP32 | Self/Cross     | Yes     | Any          |
| JAX       | cuDNN attention (`_FusedDotProductAttention`)                | BF16/FP16/FP8  | Self/Cross     | Yes     | sm80+        |
|           | JAX native implementation (`_UnfusedDotProductAttention`)    | BF16/FP16/FP32 | Self/Cross     | Yes     | Any          |
| Paddle    | cuDNN attention (`_te_forward`)                              | BF16/FP16/FP8  | Self/Cross     | Yes     | sm80+        |
|           | Paddle native implementation (`_pd_forward`)                 | BF16/FP16/FP32 | Self/Cross     | Yes     | Any          |


### 2.1 QKV Layout

The query, key and value tensors can come in different layouts, and Transformer Engine categorizes them into 15 layouts, 3 formats and 5 layout groups.

| qkv_layout          | qkv_layout_group=`3hd` | `h3d`   | `hd_2hd`     | `hd_h2d`     | `hd_hd_hd`       |
| ------------------- | ------------------------ | --------- | -------------- | -------------- | ------------------ |
| qkv_format=`sbhd` | `sb3hd`                | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |
| `bshd`            | `bs3hd`                | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |
| `thd`             | `t3hd`                 | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |

Different backends provide different QKV layout support:
- TriDao attention: `bshd`, `sbhd` and `thd` formats (`sbhd` layouts require transposes)
- cuDNN attention: `bshd`, `sbhd` and `thd` formats
- PyTorch-native attention: `bshd` and `sbhd` (`bshd` layouts require transposes)
- JAX-native attention: `bshd` and `sbhd` (transposes?)
- Paddle-native attention: `bshd` and `sbhd` (transposes?)

When RoPE is employed, QKV layout will be automatically converted to the corresponding `hd_hd_hd` layout. For example, `sbh3d` will be converted to `sbhd_sbhd_sbhd`.
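The naming scheme in the table is regular enough that the format, the layout group and the separated `hd_hd_hd` layout can all be derived from the layout string. The helpers below are a sketch with hypothetical names, written against the 15 layouts listed above; they are not presented as Transformer Engine's own functions.

```python
# Hypothetical helpers that derive the table's rows and columns from a layout string.

def qkv_format(qkv_layout):
    """Drop the digits from the first tensor's layout: 'sbh3d' -> 'sbhd'."""
    first = qkv_layout.split("_")[0]
    return "".join(c for c in first if c.isalpha())


def qkv_layout_group(qkv_layout):
    """Drop the leading batch/sequence/token dimensions: 'sbhd_sb2hd' -> 'hd_2hd'."""
    return "_".join(part.lstrip("sbt") for part in qkv_layout.split("_"))


def separated_layout(qkv_layout):
    """Layout with separate Q, K and V tensors in the same format, as used with RoPE."""
    return "_".join([qkv_format(qkv_layout)] * 3)


print(qkv_format("sbh3d"), qkv_layout_group("sbh3d"), separated_layout("sbh3d"))
```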

### 2.2 Attention Mask

Transformer Engine supports 5 mask types: `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), and `arbitrary`. All masks are defined as `True - masking the corresponding element out` and `False - including the element in attention calculation`.

The support matrix for attention mask is:
- TriDao attention: `no_mask`, `causal`, `padding`, `padding_causal`
- cuDNN attention: `no_mask`, `causal`, `padding`, `padding_causal`
- Framework-native attention: `no_mask`, `causal`, `arbitrary` (?)

Note that since version 2.1, TriDao attention has used a different causal mask definition from cuDNN attention (see [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)). TriDao attention aligns the causal mask to the bottom-right corner of the attention matrix, whereas cuDNN attention aligns it to the top-left corner. The two definitions coincide when `seqlen_q` equals `seqlen_kv`.
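The difference between the two causal mask definitions only shows up when `seqlen_q` and `seqlen_kv` differ. The sketch below builds both masks as nested lists, following the convention above that `True` masks an element out; `causal_mask` is a hypothetical helper for illustration.

```python
# Hypothetical helper contrasting top-left and bottom-right causal masks.

def causal_mask(seqlen_q, seqlen_kv, alignment):
    """Return a [seqlen_q][seqlen_kv] mask; True masks the element out."""
    # Bottom-right alignment shifts the diagonal so that the last query row
    # attends to every key; top-left alignment anchors it at element (0, 0).
    offset = seqlen_kv - seqlen_q if alignment == "bottom_right" else 0
    return [[j > i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]


def show(mask):
    """Print 1 for elements that take part in attention and 0 for masked ones."""
    for row in mask:
        print(" ".join("0" if masked else "1" for masked in row))


show(causal_mask(2, 5, "top_left"))
print()
show(causal_mask(2, 5, "bottom_right"))
```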

The `no_mask` and `causal` mask types do not require users to pass in the `attention_mask` tensor, but `padding`, `padding_causal` and `arbitrary` types do, for most of the backends.

For `padding` and `padding_causal`, one `attention_mask` tensor of shape `[batch_size, 1, 1, seqlen_q]` should be passed in for self-attention, and two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]` are required for cross-attention. For example, (?)
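As a rough illustration of the shape and the `True` = masked-out convention, the sketch below builds a padding mask from per-sequence lengths. Nested Python lists stand in for a boolean tensor here, and `padding_mask` is a hypothetical helper rather than a Transformer Engine function.

```python
# Hypothetical helper; nested lists stand in for a boolean tensor.

def padding_mask(seqlens, max_seqlen):
    """Build a [batch_size, 1, 1, max_seqlen] mask; True marks padded positions."""
    return [[[[pos >= length for pos in range(max_seqlen)]]] for length in seqlens]


# Two sequences of lengths 2 and 4, padded to a length of 4.
mask = padding_mask([2, 4], max_seqlen=4)
print(mask[0][0][0])  # [False, False, True, True]
print(mask[1][0][0])  # [False, False, False, False]
```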

This mask is converted to `cu_seqlens_q` (if self-attention) and `cu_seqlens_kv` (if cross attention). Users can pass in these tensors directly as well, instead of `attention_mask`. Example (?)
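A minimal sketch of the conversion follows. It assumes the common convention that `cu_seqlens` holds cumulative sequence lengths with a leading zero, so it has `batch_size + 1` entries. `cu_seqlens_from_mask` is a hypothetical name, and nested lists again stand in for tensors.

```python
from itertools import accumulate

# Hypothetical helper; assumes cu_seqlens = [0, len_0, len_0 + len_1, ...].

def cu_seqlens_from_mask(mask):
    """Cumulative sequence lengths from a [batch_size, 1, 1, seqlen] padding mask."""
    # False entries are the positions that take part in attention.
    seqlens = [sample[0][0].count(False) for sample in mask]
    return [0] + list(accumulate(seqlens))


# Padding mask for sequences of lengths 2, 4 and 1, padded to a length of 4.
example_mask = [
    [[[False, False, True, True]]],
    [[[False, False, False, False]]],
    [[[False, True, True, True]]],
]
print(cu_seqlens_from_mask(example_mask))  # [0, 2, 6, 7]
```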

For the `arbitrary` mask type, an `attention_mask` tensor should be passed in with shape `[batch_size, 1, 1, seqlen_q]` (?) Example.

cuDNN attention does not support the `arbitrary` mask type. However, users can use its `post_scale_bias` path to apply an `arbitrary` mask. An example script is [here](./arbitrary_mask_to_bias.py). This path is more performant than the unfused path.
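The idea behind the bias path can be sketched independently of the linked script: a boolean mask becomes an additive bias that is `-inf` where the mask is `True` and `0` elsewhere, so masked positions receive zero weight after the softmax. The code below is an illustration on a single `[seqlen_q][seqlen_kv]` slice with hypothetical helper names; it is not the content of `arbitrary_mask_to_bias.py`.

```python
import math

# Hypothetical helpers illustrating mask-as-bias on one [seqlen_q][seqlen_kv] slice.

def mask_to_post_scale_bias(mask):
    """Turn a boolean mask (True = masked out) into an additive bias of the same shape."""
    return [[float("-inf") if masked else 0.0 for masked in row] for row in mask]


def softmax_with_bias(scores, bias):
    """Softmax over one row of already-scaled scores after adding the bias."""
    # Assumes at least one position in the row is unmasked.
    shifted = [s + b for s, b in zip(scores, bias)]
    peak = max(shifted)
    exps = [math.exp(x - peak) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


bias = mask_to_post_scale_bias([[False, True], [False, False]])
print(softmax_with_bias([1.0, 2.0], bias[0]))  # [1.0, 0.0]
```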

When `max_seqlen_q` and `max_seqlen_kv` are not set, Transformer Engine calculates the maximum based on the `query`, `key` and `value` tensors, or `cu_seqlens_q/kv` (?). This incurs a GPU-to-CPU memcpy and a synchronization. To avoid this overhead, users are encouraged to pass `max_seqlen_q/kv` directly to Transformer Engine.
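When the sequence lengths are already known on the host, for instance from the data loader, the maximum can be computed there without touching the GPU. The sketch below assumes a host-side list of cumulative lengths with a leading zero; `max_seqlen_from_cu_seqlens` is a hypothetical helper.

```python
# Hypothetical helper operating on a host-side Python list, not a GPU tensor.

def max_seqlen_from_cu_seqlens(cu_seqlens):
    """Longest sequence, taken from the differences of consecutive cumulative lengths."""
    return max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))


print(max_seqlen_from_cu_seqlens([0, 2, 6, 7]))  # 4
```

The result can then be supplied as `max_seqlen_q` or `max_seqlen_kv`.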


### 2.3 Attention Bias

Transformer Engine supports 4 bias types: `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes).

| Backend                    | Bias Type                                                                             | Bias Shape                                                                                           | Bias Dtype                                                       | Architecture            |
| -------------------------- | ------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | ----------------------- |
| TriDao attention           | `no_bias`, `ALiBi` (with/without slopes) need to convert?                             | NA                                                                                                   | FP32 for ALiBi slopes                                            | sm80+                   |
| cuDNN attention            | `no_bias`, `post_scale_bias`, `ALiBi` (with/without slopes) need to convert?          | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias` forward,<br>and 1HSS for `post_scale_bias` backward     | Data dtype for `post_scale_bias`,<br>and FP32 for ALiBi slopes   | sm80, 90 for cuDNN 9.0+ |
| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias`, `ALiBi` (with/without slopes)         | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias`                                                         | Data dtype for `post_scale_bias`,<br>and FP32 for ALiBi slopes   | sm80+                   |

ALiBi vs ALiBi slopes
can convert by self? get_alibi(), get_alibi_slopes()
Mask + Bias?
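To make the distinction between ALiBi slopes and the ALiBi bias concrete: the slopes are one scalar per head, while the bias is a full `[num_heads, seqlen_q, seqlen_kv]` tensor derived from them. The sketch below uses the slope formula from the ALiBi paper for power-of-two head counts and a symmetric `-slope * |i - j|` bias. Both helper names are hypothetical, and the code makes no claim about how Transformer Engine computes or aligns its own bias.

```python
# Hypothetical helpers based on the ALiBi paper, restricted to power-of-two head counts.

def alibi_slopes(num_heads):
    """Per-head slopes: a geometric sequence starting at 2 ** (-8 / num_heads)."""
    if num_heads <= 0 or num_heads & (num_heads - 1) != 0:
        raise ValueError("this sketch only covers power-of-two head counts")
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias(slopes, seqlen_q, seqlen_kv):
    """Full bias of shape [num_heads][seqlen_q][seqlen_kv]: -slope * |i - j|."""
    return [
        [[-slope * abs(i - j) for j in range(seqlen_kv)] for i in range(seqlen_q)]
        for slope in slopes
    ]


slopes = alibi_slopes(8)           # 8 scalars: 1/2, 1/4, ..., 1/256
bias = alibi_bias(slopes, 2, 3)    # 8 x 2 x 3 values
print(slopes[0], slopes[-1])
print(bias[0][1])
```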

### 2.4 Other Features

| Backend                    | Sliding Window Attention       | MQA/GQA                        | Context Parallelism                 | Determinism                                                                                                      |
| -------------------------- | ------------------------------ | ------------------------------ | ----------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| TriDao attention           | Yes                            | Yes                            | Yes (`bshd`)                        | Yes when `determinism=True`                                                                                      |
| cuDNN attention            | No                             | Yes                            | No (`bshd`, `sbhd`, `thd`)          | `max512_seqlen`: yes<br>`arbitrary_seqlen`: yes if workspace optimization path<br>`FP8`: yes                     |
| Framework-native attention | Yes (PyTorch),<br> No (JAX/Paddle) | Yes (PyTorch),<br> No (JAX/Paddle) | No                                  | Yes                                                                                                              |

The `arbitrary_seqlen` backend of cuDNN attention supports two paths: a workspace optimization path and a non-workspace optimization path. The workspace optimization path uses more memory (an extra workspace of size `[b, 1, s, s]`) but provides better runtime performance. By default, it is enabled when the required extra space is <= 256MB and disabled otherwise. Users can set the `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` environment variable to choose a particular path. Be mindful of the OOM risk.
```
    # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
    # - unset:       enables workspace optimization when required workspace is <= 256MB
    #                or when bias gradient needs to be computed
    # - n:           enables workspace optimization when required workspace is <= n bytes
    # - -1:          enables workspace optimization always
    # - 0:           disables workspace optimization always
```
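The rules listed above can be restated as a small decision function. This is a sketch only: `workspace_optimization_enabled` is a hypothetical name, 256MB is read here as 256 * 2**20 bytes, and the 2-byte element size in the usage example is an assumption made to get a concrete number.

```python
# Hypothetical restatement of the CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT rules above.

DEFAULT_LIMIT_BYTES = 256 * 1024 * 1024  # assumption: 256MB means 256 * 2**20 bytes


def workspace_optimization_enabled(required_bytes, limit=None, needs_bias_grad=False):
    """`limit` mirrors the environment variable; None means the variable is unset."""
    if limit is None:
        return required_bytes <= DEFAULT_LIMIT_BYTES or needs_bias_grad
    if limit == -1:
        return True   # always enabled
    if limit == 0:
        return False  # always disabled
    return required_bytes <= limit


# Workspace of size [b, 1, s, s]; the element size of 2 bytes is an assumption.
b, s, bytes_per_element = 4, 8192, 2
required = b * 1 * s * s * bytes_per_element  # 512 * 2**20 bytes

print(workspace_optimization_enabled(required))            # False: above the default limit
print(workspace_optimization_enabled(required, limit=-1))  # True: forced on
```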

Users can also combine the two fused backends by setting `NVTE_FUSED_ATTN_USE_FAv2_BWD=1`, which pairs the cuDNN attention forward pass with the TriDao attention backward pass. This helps when TriDao attention backward is faster than cuDNN attention backward.


### 2.5 FP8 attention
fp8_dpa, fp8_mha
extra states
