# Attention Is All You Need!

The attention mechanism [[1]](https://arxiv.org/abs/1706.03762) is the core idea behind Transformer-based models. It identifies the strongest correlations among words and allows the model to focus on different parts of the sentence. Transformer Engine supports three frameworks, PyTorch, JAX and Paddle, and implements the attention calculation in each of them as [transformer_engine.pytorch.DotProductAttention](../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), [transformer_engine.jax.flax.DotProductAttention](../api/jax.rst#transformer_engine.jax.flax.DotProductAttention), and [transformer_engine.paddle.DotProductAttention](../api/paddle.rst#transformer_engine.paddle.DotProductAttention).

<figure align="center">
<img src="attn.png" width="70%">
<figcaption> Figure 1: A standard dot product attention workflow, where pre-softmax operations include scaling, bias and masking and post-softmax operations include dropout.</figcaption>
</figure>
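
The workflow in Figure 1 can be sketched in a few lines of NumPy. This is a minimal single-head reference under stated assumptions, not any of Transformer Engine's backends: the function name and signature are hypothetical, the bias is applied after scaling, the mask follows the `True` = masked-out convention, and every query row is assumed to keep at least one key.

```python
import numpy as np

def dot_product_attention(q, k, v, bias=None, mask=None, dropout_p=0.0, rng=None):
    """Single-head reference. q: [sq, d]; k, v: [skv, d]; mask: bool [sq, skv]."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale                      # pre-softmax: scaling
    if bias is not None:
        scores = scores + bias                      # pre-softmax: bias (post-scale)
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)    # pre-softmax: masking (True = out)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)     # softmax
    if dropout_p > 0.0:                             # post-softmax: dropout
        if rng is None:
            rng = np.random.default_rng(0)
        keep = rng.random(probs.shape) >= dropout_p
        probs = probs * keep / (1.0 - dropout_p)
    return probs @ v

q = np.zeros((2, 4))                                # zero queries give uniform weights
k = np.ones((3, 4))
v = np.arange(6, dtype=float).reshape(3, 2)         # rows: [0, 1], [2, 3], [4, 5]
out = dot_product_attention(q, k, v)                # each row is the mean of v: [2, 3]

mask = np.array([[False, False, True]] * 2)         # hide the last key from both queries
out_masked = dot_product_attention(q, k, v, mask=mask)  # mean of first two rows: [1, 2]
```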

## Attention Backends

For each supported framework, Transformer Engine provides multiple backends for both performance and coverage purposes. The full list of backends and their locations in Transformer Engine is as follows.

| Framework | Backend (Module Name)                                                                                                                          | Module Location                                                                                              |
| :------------ | :----------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------- |
| PyTorch   | - cuDNN attention (`FusedAttention`)<br> - TriDao attention (`FlashAttention`)<br> - PyTorch native implementation (`UnfusedDotProductAttention`) | [pytorch.attention](../../transformer_engine/pytorch/attention.py)      |
| JAX       | - cuDNN attention (`_FusedDotProductAttention`)<br> - JAX native implementation (`_UnfusedDotProductAttention`)                                | [jax.flax.transformer](../../transformer_engine/jax/flax/transformer.py)   |
| Paddle    | - cuDNN attention (`_te_forward`)<br> - Paddle native implementation (`_pd_forward`)                                                           | [paddle.layer.attention](../../transformer_engine/paddle/layer/attention.py) |


The framework-native backends provide a baseline for testing new features, while the other backends, cuDNN attention and TriDao attention, provide higher performance. TriDao attention is the public flash attention implementation, [flash-attention](https://github.com/Dao-AILab/flash-attention). Its performance has improved as it evolved from `flash-attention` v1 to `flash-attention` v2. Transformer Engine regularly updates its installation [requirement](../../setup.py), and as of v1.7, Transformer Engine uses flash-attention 2.0.6+. `FlashAttention` is a PyTorch module wrapped around `flash-attention` that provides additional functionality, such as converting the attention mask to `cu_seqlens`.

[cuDNN attention](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop), developed at NVIDIA, is another high-performance attention implementation. A few differences between cuDNN attention and TriDao attention are worth noting:
- cuDNN attention has been integrated into all three frameworks in Transformer Engine, while TriDao attention is only supported in PyTorch.
- cuDNN attention is often called "fused attention" in Transformer Engine, but some of its sub-backends are also based on the "flash" algorithm, as `flash-attention` is.

## Backend Selection

Transformer Engine selects the attention backend based on the user input, such as sequence length, number of heads, head dimension, QKV layout, mask type, bias type, GPU architecture and cuDNN/`flash-attention` version. It prefers the fused implementations over the unfused ones, and falls back to the unfused implementations when the fused versions are unavailable. An example of this selection logic is shown below.

| Framework | Selection Order                                                                                                                              |
| :--------- | :---------------------------------------------------------------------------------------------------------------------------------------------------- |
| PyTorch   | sm90: cuDNN attention -> TriDao attention -> PyTorch native implementation<br>sm80: TriDao attention -> cuDNN attention -> PyTorch native implementation |
| JAX       | cuDNN attention -> JAX native implementation                                                                                                          |
| Paddle    | cuDNN attention -> Paddle native implementation                                                                                                       |

This example is for cuDNN 9.0, which is more performant than `flash-attention` on sm90 (Hopper) but less so on sm80 (Ampere); hence the logic in the PyTorch row. As we monitor the performance of different backends, the selection logic may change too. Usually users do not need to concern themselves with backend selection; however, in case of a performance regression, a convergence issue, or pure curiosity, users can turn some of these backends on or off. For example, for PyTorch, users can disable TriDao attention or cuDNN attention by setting
```
NVTE_FLASH_ATTN=0 # disables flash-attention; default = 1
NVTE_FUSED_ATTN=0 # disables cuDNN attention; default = 1
```
in their environment.
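
As an illustration, the selection order in the table and the effect of the two switches can be modelled with a small lookup. This is a toy sketch, not Transformer Engine's actual selection code: `pick_backend`, `SELECTION_ORDER` and the short backend labels are hypothetical names, and the order is the cuDNN 9.0 example from the PyTorch row.

```python
import os

# Preference order copied from the PyTorch row of the table (cuDNN 9.0 example).
SELECTION_ORDER = {
    "sm90": ["cudnn", "flash", "native"],
    "sm80": ["flash", "cudnn", "native"],
}
# Environment switches named in the text; the native backend has no switch here.
ENV_SWITCH = {"flash": "NVTE_FLASH_ATTN", "cudnn": "NVTE_FUSED_ATTN"}

def pick_backend(arch, supported, env=os.environ):
    """Return the first backend that is enabled and supports the input."""
    for backend in SELECTION_ORDER[arch]:
        switch = ENV_SWITCH.get(backend)
        enabled = switch is None or env.get(switch, "1") != "0"  # default = 1
        if enabled and backend in supported:
            return backend
    raise RuntimeError("no attention backend available")

everything = {"cudnn", "flash", "native"}
default_sm90 = pick_backend("sm90", everything, env={})                        # "cudnn"
default_sm80 = pick_backend("sm80", everything, env={})                        # "flash"
no_flash_sm80 = pick_backend("sm80", everything, env={"NVTE_FLASH_ATTN": "0"})  # "cudnn"
```

Passing an explicit `env` dictionary keeps the sketch free of side effects; in a real run the variables would be set in the process environment as described above.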

As cuDNN evolved over time, multiple sub-backends emerged in order to support different sequence lengths and precisions. This table details the three sub-backends in cuDNN attention.

| cuDNN Sub-Backend           | Precision | Sequence Length | Algorithm | Architecture |
| :-------------------------------- | :-------------- | :----------------- | :----------- | :-------------- |
| 0 - `NVTE_F16_max512_seqlen`       | BF16/FP16       | <=512           | Non-Flash | sm80, 90 |
| 1 - `NVTE_F16_arbitrary_seqlen`    | BF16/FP16       | Any             | Flash     | sm80+    |  
| 2 - `NVTE_FP8`   | FP8       | cuDNN pre-9.0: <=512 <br>cuDNN 9.0+: any           | Flash     | cuDNN pre-9.0: sm90 <br>cuDNN 9.0+:  sm90+ |

Users can influence the sub-backend selection by setting
```
NVTE_FUSED_ATTN_BACKEND=0 # or 1, or 2
```
but this selection only takes effect if that sub-backend supports the user's input. If it does not, Transformer Engine routes to a more appropriate sub-backend. To check which backend, or cuDNN sub-backend, is being used at runtime, users can set
```
NVTE_DEBUG=1
```
which prints additional debugging information, for example:
```
        [DotProductAttention]: using flash-attn 2.1.0
        [DotProductAttention]: using cuDNN attention (backend 0)
        [DotProductAttention]: using cuDNN attention (backend 1)
        [DotProductAttention]: using cuDNN attention (backend 2)
```
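
The sub-backend table and the `NVTE_FUSED_ATTN_BACKEND` rule above can be read as an eligibility check followed by a fallback. The sketch below encodes only the rows of the table; the function names are hypothetical, and the fallback to the first eligible sub-backend is an assumption, since the text does not say which one Transformer Engine considers more appropriate.

```python
def eligible_sub_backends(precision, max_seqlen, sm, cudnn_version):
    """Return the cuDNN sub-backend ids whose table row matches the input."""
    eligible = []
    half = precision in ("bf16", "fp16")
    if half and max_seqlen <= 512 and sm in (80, 90):
        eligible.append(0)                  # NVTE_F16_max512_seqlen
    if half and sm >= 80:
        eligible.append(1)                  # NVTE_F16_arbitrary_seqlen
    if precision == "fp8":
        if cudnn_version >= (9, 0):
            if sm >= 90:
                eligible.append(2)          # NVTE_FP8, any sequence length
        elif sm == 90 and max_seqlen <= 512:
            eligible.append(2)              # NVTE_FP8 before cuDNN 9.0
    return eligible

def resolve_sub_backend(requested, eligible):
    """Honor the request only if it is eligible (fallback choice is assumed)."""
    if requested in eligible:
        return requested
    return eligible[0] if eligible else None

short_bf16 = eligible_sub_backends("bf16", 256, 80, (9, 0))    # [0, 1]
long_fp16 = eligible_sub_backends("fp16", 2048, 90, (9, 0))    # [1]
fp8_old = eligible_sub_backends("fp8", 2048, 90, (8, 9))       # []
fp8_new = eligible_sub_backends("fp8", 2048, 90, (9, 0))       # [2]
```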

## Backend Support

All attention backends have support for self-/cross-attention and dropout, but they differ in some other important features. 

| Backend    | Precision      | Architecture | Sliding Window Attention       | MQA/GQA        | Context Parallelism     | Deterministic    |
| :------------------------------ | :-------------- | :------------ | :------------------------------ | :------------------------------ | :----------------------------------- | :----------- |
| cuDNN attention                            | BF16/FP16/FP8  |  sm80+        | No                             | Yes                            | No <br> (`bshd`, `sbhd`, `thd`) | Yes (sub-backend 0, 2),<br>Yes (sub-backend 1 if workspace optimization path in use) |
| TriDao attention                           | BF16/FP16      |  sm80+        | Yes                            | Yes                            | Yes (`bshd`)                      | Yes if `deterministic=True`                                                                                  |
| Framework-native attention  | BF16/FP16/FP32 |  Any          | Yes (PyTorch),<br> No (JAX/Paddle) | Yes | No                                  | Yes |

Please note that the architecture support of cuDNN attention and TriDao attention may vary between versions. When running with a specific version, please check the [cuDNN documentation](https://docs.nvidia.com/deeplearning/cudnn/latest/) and the [flash-attention repository](https://github.com/Dao-AILab/flash-attention).

The "workspace optimization path" in the table is a feature in cuDNN sub-backend 1 that allows users to trade memory for performance. It uses `batch_size x seqlen_q x seqlen_kv` more memory, but provides 20-30% better performance (available on Hopper only). By default, if the extra memory requirement falls under 256MB, this path is turned on, unless users disable it. The following environment variable allows for more fine-grained control: 
```
# CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT
# - unset: enables workspace optimization when required workspace is <= 256MB
#          or when bias gradient needs to be computed
# -     n: enables workspace optimization when required workspace is <= n bytes
# -    -1: enables workspace optimization always
# -     0: disables workspace optimization always
```
If `deterministic=True` is set for Transformer Engine's `DotProductAttention` module, this path is turned on as well. The non-workspace optimization path for cuDNN sub-backend 1 is non-deterministic. When choosing the workspace optimization path, please be aware of the out-of-memory risk due to the increased memory requirement.
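
The rules above can be summarized as a small decision function. This is a sketch of the documented behavior, not the cuDNN frontend's code: the function names are hypothetical, "256MB" is read as 256 MiB, and letting `deterministic=True` take precedence over an explicit limit is an assumption.

```python
DEFAULT_LIMIT_BYTES = 256 * 1024 * 1024  # "256MB" interpreted as MiB (assumption)

def parse_limit(env):
    """Read CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT; None means unset."""
    raw = env.get("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT")
    return None if raw is None else int(raw)

def use_workspace_optimization(required_bytes, limit=None,
                               needs_bias_grad=False, deterministic=False):
    """Decide whether the workspace optimization path would be taken."""
    if deterministic:
        return True       # assumed to override the limit
    if limit is None:     # unset: default threshold, or bias gradient needed
        return required_bytes <= DEFAULT_LIMIT_BYTES or needs_bias_grad
    if limit == -1:       # always enabled
        return True
    if limit == 0:        # always disabled
        return False
    return required_bytes <= limit

small = use_workspace_optimization(100 * 1024 * 1024)                # True
large = use_workspace_optimization(512 * 1024 * 1024)                # False
forced = use_workspace_optimization(512 * 1024 * 1024,
                                    limit=parse_limit({"CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT": "-1"}))  # True
```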

## QKV Layout

The query (`q`), key (`k`) and value (`v`) tensors that users pass into the `DotProductAttention` module may be in varying memory layouts. Transformer Engine categorizes them into 15 QKV layouts, which fall into 3 QKV formats and 5 QKV layout groups. The mapping between them is as follows.

| `qkv_format` (rows) / `qkv_layout_group` (columns) | `3hd`   | `h3d`   | `hd_2hd`     | `hd_h2d`     | `hd_hd_hd`       |
| -------------------------------------------------- | ------- | ------- | ------------ | ------------ | ---------------- |
| `sbhd`                                             | `sb3hd` | `sbh3d` | `sbhd_sb2hd` | `sbhd_sbh2d` | `sbhd_sbhd_sbhd` |
| `bshd`                                             | `bs3hd` | `bsh3d` | `bshd_bs2hd` | `bshd_bsh2d` | `bshd_bshd_bshd` |
| `thd`                                              | `t3hd`  | `th3d`  | `thd_t2hd`   | `thd_th2d`   | `thd_thd_thd`    |

Different backends provide different QKV layout support:
- TriDao attention: `bshd`, `sbhd` and `thd` formats (`sbhd` layouts require transposes)
- cuDNN attention: `bshd`, `sbhd` and `thd` formats
- PyTorch-native attention: `bshd` and `sbhd` (`bshd` layouts require transposes)
- JAX-native attention: `bshd` and `sbhd` (transposes?)
- Paddle-native attention: `bshd` and `sbhd` (transposes?)

When RoPE is employed, QKV layout will be automatically converted to the corresponding `hd_hd_hd` layout. For example, `sbh3d` will be converted to `sbhd_sbhd_sbhd`.
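
The naming scheme in the layout table is regular enough that the format and layout group can be derived from the layout string. The helpers below are hypothetical and only illustrate the mapping and the RoPE conversion described above; they are not Transformer Engine functions.

```python
def qkv_format(qkv_layout):
    """'sb3hd' -> 'sbhd': take the first tensor's name and drop the packing digit."""
    first = qkv_layout.split("_")[0]
    return "".join(c for c in first if not c.isdigit())

def qkv_layout_group(qkv_layout):
    """'sbhd_sb2hd' -> 'hd_2hd': strip the leading batch/sequence dimensions."""
    prefix = qkv_format(qkv_layout)[:-2]          # 'sb', 'bs' or 't'
    return "_".join(part[len(prefix):] for part in qkv_layout.split("_"))

def separate_layout(qkv_layout):
    """Layout after the RoPE conversion: three separate tensors, same format."""
    return "_".join([qkv_format(qkv_layout)] * 3)

fmt = qkv_format("sbh3d")               # 'sbhd'
group = qkv_layout_group("thd_th2d")    # 'hd_h2d'
after_rope = separate_layout("sbh3d")   # 'sbhd_sbhd_sbhd'
```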

## Attention Mask

Transformer Engine supports 5 mask types: `no_mask`, `padding`, `causal`, `padding_causal` (equivalent to `causal_padding`), and `arbitrary`. In all masks, `True` means the corresponding element is masked out, and `False` means the element is included in the attention calculation.

The support matrix for attention mask is:
- TriDao attention: `no_mask`, `causal`, `padding`, `padding_causal`
- cuDNN attention: `no_mask`, `causal`, `padding`, `padding_causal`
- Framework-native attention: `no_mask`, `causal`, `arbitrary` (?)

Note that since version 2.1, TriDao attention has used a different causal mask definition than cuDNN attention (see the [flash-attention 2.1 change notes](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)). TriDao attention aligns the causal diagonal to the bottom right corner of the attention matrix, while cuDNN attention aligns it to the top left corner.
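
The two alignments only differ when `seqlen_q != seqlen_kv`. The following sketch builds both masks with the `True` = masked-out convention used in this document; `causal_mask` and `show` are hypothetical helper names.

```python
def causal_mask(seqlen_q, seqlen_kv, align="top_left"):
    """Boolean mask [seqlen_q][seqlen_kv]; True = masked out."""
    # Bottom-right alignment shifts the diagonal so the last query sees all keys.
    offset = 0 if align == "top_left" else seqlen_kv - seqlen_q
    return [[j > i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]

def show(mask):
    """Render a mask as strings: 'x' = masked out, '.' = attended."""
    return ["".join("x" if m else "." for m in row) for row in mask]

top_left = show(causal_mask(2, 5, "top_left"))          # ['.xxxx', '..xxx']
bottom_right = show(causal_mask(2, 5, "bottom_right"))  # ['....x', '.....']
```

For square attention matrices (`seqlen_q == seqlen_kv`) the offset is zero and both definitions produce the same mask.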

The `no_mask` and `causal` mask types do not require users to pass in the `attention_mask` tensor, but `padding`, `padding_causal` and `arbitrary` types do, for most of the backends.

For `padding` and `padding_causal`, one `attention_mask` tensor of shape `[batch_size, 1, 1, seqlen_q]` should be passed in when the attention type is self-attention, and two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]` are required when it is cross-attention.

This mask is converted to `cu_seqlens_q` (for self-attention) and, additionally, `cu_seqlens_kv` (for cross-attention). Users can also pass in these tensors directly, instead of `attention_mask`.
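
A minimal NumPy sketch of this conversion for the self-attention case follows. It assumes the `cu_seqlens` convention used by `flash-attention`'s variable-length interface (cumulative sequence lengths with a leading zero, `batch_size + 1` entries); the helper names are hypothetical and this is not Transformer Engine's own conversion routine.

```python
import numpy as np

def padding_mask_from_lengths(seqlens, max_seqlen):
    """Build a [batch_size, 1, 1, max_seqlen] mask; True = padded position."""
    positions = np.arange(max_seqlen)
    mask = positions[None, :] >= np.asarray(seqlens)[:, None]
    return mask[:, None, None, :]

def mask_to_cu_seqlens(attention_mask):
    """Count unmasked tokens per sequence and accumulate them."""
    seqlens = (~attention_mask).sum(axis=(1, 2, 3))
    return np.concatenate([[0], np.cumsum(seqlens)]).astype(np.int32)

mask = padding_mask_from_lengths([3, 1, 4], max_seqlen=4)   # shape (3, 1, 1, 4)
cu_seqlens_q = mask_to_cu_seqlens(mask)                     # [0, 3, 4, 8]
max_seqlen_q = int(np.diff(cu_seqlens_q).max())             # 4
```

In this sketch the longest sequence length falls out of `cu_seqlens_q` directly, which is the kind of value that can be passed alongside the cumulative lengths.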

For the `arbitrary` mask type, an `attention_mask` tensor should be passed in with shape `[batch_size, 1, 1, seqlen_q]` (?)

cuDNN attention does not support the `arbitrary` mask type. However, users can use its `post_scale_bias` path to apply an `arbitrary` mask. An example script is [here](./arbitrary_mask_to_bias.py). This path is more performant than the unfused path.
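
The idea behind this path is that adding a very negative bias before softmax has the same effect as masking. The NumPy sketch below illustrates the principle only; it is not the content of the linked script, the helper names are hypothetical, and using `-inf` (rather than a large finite value) assumes that no row is fully masked.

```python
import numpy as np

def mask_to_post_scale_bias(mask):
    """True (masked out) -> -inf, False (kept) -> 0."""
    return np.where(mask, -np.inf, 0.0).astype(np.float32)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

mask = np.array([[False, True, False],
                 [False, False, True]])
scores = np.zeros((2, 3), dtype=np.float32)       # already-scaled q @ k^T
probs = softmax(scores + mask_to_post_scale_bias(mask))
# Masked positions get zero weight; the rest share the probability mass:
# [[0.5, 0.0, 0.5],
#  [0.5, 0.5, 0.0]]
```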

When `max_seqlen_q` and `max_seqlen_kv` are not set, Transformer Engine calculates the maximum from the `query`, `key` and `value` tensors, or from `cu_seqlens_q/kv` (?). This incurs a GPU-CPU memcpy and synchronization. To avoid it, users are encouraged to pass `max_seqlen_q/kv` directly to Transformer Engine.


## Attention Bias

Transformer Engine supports 4 bias types: `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes).

| Backend                    | Bias Type                                                                             | Bias Shape                                                                                           | Bias Dtype                                                       | Architecture            |
| -------------------------- | ------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | ----------------------- |
| TriDao attention           | `no_bias`, `ALiBi` (with/without slopes) need to convert?                     | NA                                                                                               | FP32 for ALiBi slopes                                          | sm80+                   |
| cuDNN attention            | `no_bias`, `post_scale_bias`, `ALiBi` (with/without slopes) need to convert?  | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias` forward,<br>And 1HSS for `post_scale_bias` backward | Data dtype for `post_scale_bias`,<br>And FP32 for ALiBi slopes | sm80, 90 for cuDNN 9.0+ |
| Framework-native attention | `no_bias`, `pre_scale_bias`, `post_scale_bias`, `ALiBi` (with/without slopes) | BHSS, 1HSS, B1SS, 11SS for `post_scale_bias`                                                     | Data dtype for `post_scale_bias`,<br>And FP32 for ALiBi slopes | sm80+                   |
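
To make the distinction between ALiBi slopes and the ALiBi bias concrete, the sketch below computes the default per-head slopes from the ALiBi paper for a power-of-two number of heads, and expands one slope into a bias matrix. The function names are hypothetical and are not Transformer Engine's helpers; the symmetric `|i - j|` distance and the square matrix are simplifying assumptions.

```python
def alibi_slopes(num_heads):
    """Geometric slopes 2^(-8/n), 2^(-16/n), ... for a power-of-two head count."""
    assert num_heads > 0 and num_heads & (num_heads - 1) == 0, "power of two only"
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

def alibi_bias(slope, seqlen):
    """Per-head bias [seqlen][seqlen]: a penalty linear in token distance."""
    return [[-slope * abs(i - j) for j in range(seqlen)] for i in range(seqlen)]

slopes = alibi_slopes(8)            # 1/2, 1/4, ..., 1/256
bias = alibi_bias(slopes[0], 3)     # rows: [0, -0.5, -1], [-0.5, 0, -0.5], [-1, -0.5, 0]
```

A backend that accepts slopes only needs the `num_heads` values, while a backend that accepts a full bias needs the expanded matrix per head.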

ALiBi vs ALiBi slopes
can convert by self? get_alibi(), get_alibi_slopes()
Mask + Bias?

## Other Features

Users can also mix the forward and backward passes of cuDNN attention and TriDao attention by setting `NVTE_FUSED_ATTN_USE_FAv2_BWD=1`. This helps when the TriDao attention backward pass is faster than the cuDNN attention backward pass.


## FP8 Attention
fp8_dpa, fp8_mha
extra states
