Update flash attention section in memory_optimizations.rst (NVIDIA#9188)
* Update flash attention section in memory_optimizations.rst

Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* update changes based on comments

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
cyanguwa authored and BoxiangW committed Jun 5, 2024
1 parent aa61248 commit 1d55c54
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions docs/source/features/memory_optimizations.rst
@@ -11,14 +11,26 @@ Flash Attention
Overview
^^^^^^^^

Flash Attention is a method designed to enhance the efficiency of Transformer models, which are widely utilized in applications such as Natural Language Processing (NLP). Traditional Transformers are slow and consume a lot of memory, especially with long sequences, due to the quadratic time and memory complexity of self-attention. Flash Attention is an IO-aware exact attention algorithm that leverages tiling to minimize the number of memory reads/writes between the GPU's high-bandwidth memory (HBM) and on-chip SRAM. This approach is designed to be more efficient in terms of IO complexity compared to standard attention mechanisms.
Flash attention is an algorithm designed to improve the efficiency of the attention mechanism in transformer models such as GPT and BERT. The attention mechanism has quadratic time and memory complexity in sequence length and can present significant runtime and memory challenges for longer sequences.

Compared to the standard, non-flash algorithm, flash attention applies two techniques to lower the memory requirement and improve compute efficiency.

The tiling technique decomposes the inputs based on the shared memory size and calculates the softmax one tile at a time. Instead of working on the entire query, key, and value tensors at once, it makes several passes over these tensors and then combines the results in a subsequent step.

The recomputation technique stores the softmax normalization factors (linear in sequence length) instead of the softmax results (quadratic in sequence length), and uses these normalization factors to recompute the attention scores. This reduces the amount of data written to global memory, lowering both the memory requirement and the I/O traffic between global memory and shared memory.
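
To make these two techniques concrete, below is a minimal NumPy sketch of single-head attention computed one key/value tile at a time with an online softmax. It is purely illustrative and is not the NeMo or Transformer Engine implementation: it ignores batching, masking, dropout, and the GPU memory hierarchy, but it shows how only running maxima and normalization factors (linear in sequence length) need to be kept across tiles.

.. code-block:: python

   import numpy as np

   def tiled_attention(q, k, v, tile_size=128):
       """Single-head attention, one key/value tile at a time.

       q: (seq_q, d), k: (seq_k, d), v: (seq_k, d_v).
       The full (seq_q, seq_k) score matrix is never materialized; only
       per-row running maxima and softmax normalization factors are kept.
       """
       seq_q, d = q.shape
       out = np.zeros((seq_q, v.shape[1]))
       running_max = np.full((seq_q, 1), -np.inf)
       running_sum = np.zeros((seq_q, 1))  # softmax normalization factors

       for start in range(0, k.shape[0], tile_size):
           k_tile = k[start:start + tile_size]
           v_tile = v[start:start + tile_size]
           scores = q @ k_tile.T / np.sqrt(d)          # (seq_q, tile)

           new_max = np.maximum(running_max, scores.max(axis=1, keepdims=True))
           correction = np.exp(running_max - new_max)  # rescale earlier tiles
           probs = np.exp(scores - new_max)

           out = out * correction + probs @ v_tile
           running_sum = running_sum * correction + probs.sum(axis=1, keepdims=True)
           running_max = new_max

       return out / running_sum

   # Agrees with the standard softmax(Q K^T / sqrt(d)) V up to rounding error.
   rng = np.random.default_rng(0)
   q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
   scores = q @ k.T / np.sqrt(64)
   weights = np.exp(scores - scores.max(axis=1, keepdims=True))
   reference = weights / weights.sum(axis=1, keepdims=True) @ v
   assert np.allclose(tiled_attention(q, k, v), reference)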

Flash attention lowers the memory footprint from quadratic to linear in sequence length and improves compute efficiency, greatly extending the range of sequence lengths that large language models can handle.

The flash attention algorithm was first proposed `here <https://arxiv.org/pdf/2205.14135>`_. Two of its implementations are `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ by Tri Dao *et al.*, and `fused flash attention <https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-897/developer-guide/index.html#flash-fused-multi-head-att-fprop>`_ by NVIDIA cuDNN.

Turn Flash Attention On and Off
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the NeMo Framework, Flash Attention is supported through the Transformer Engine with the inclusion of Flash Attention 2. By default, Flash Attention is enabled, but the Transformer Engine may switch to a different kernel if the tensor dimensions are not optimal for Flash Attention. Users can completely disable Flash Attention by setting the environment variable ``NVTE_FLASH_ATTN=0``.
In the NeMo framework, flash attention is supported through `Transformer Engine <https://github.com/NVIDIA/TransformerEngine/tree/main>`_, including both of the implementations mentioned above. Transformer Engine selects the appropriate implementation based on input information such as sequence length, number of heads and head dimension. When both implementations are applicable, Transformer Engine prefers cuDNN flash attention on Hopper+ architectures and Tri Dao flash attention on Ampere architectures.

To disable Tri Dao flash attention, set the environment variable ``NVTE_FLASH_ATTN=0``. To disable cuDNN flash attention, set ``NVTE_FUSED_ATTN=0``.
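
For example, because these flags are typically read when Transformer Engine selects its attention backend, a reasonable approach is to set them in the launch environment or at the very top of the training script, before the model is built (illustrative sketch only):

.. code-block:: python

   import os

   # Set these before Transformer Engine / NeMo builds the attention layers,
   # e.g. in the launch environment or at the very top of the training script.
   os.environ["NVTE_FLASH_ATTN"] = "0"    # disable Tri Dao flash attention
   # os.environ["NVTE_FUSED_ATTN"] = "0"  # uncomment to disable cuDNN flash attention instead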

For more details on the supported Dot Attention backend, please refer to the Transformer Engine source code available at `Transformer Engine's Attention Mechanism <https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py>`_.
For more details on the Dot Product Attention backends supported in Transformer Engine, please refer to the source code at `Transformer Engine's Attention Mechanism <https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py>`_.

Activation Recomputation
------------------------
@@ -28,15 +40,15 @@ Overview

Full Activation Recomputation
"""""""""""""""""""""""""""""
This method recalculates all the intermediate activations during the backward pass of a model's training, instead of storing them during the forward pass. This technique maximizes memory efficiency at the cost of computational overhead, as each activation is recomputed when needed.
The full activation recomputation method recalculates all the intermediate activations during the backward pass of a model's training, instead of storing them during the forward pass. This technique maximizes memory efficiency at the cost of computational overhead, as each activation is recomputed when needed.
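
The sketch below illustrates the idea with ``torch.utils.checkpoint`` (it is not NeMo's implementation, and the module sizes are arbitrary): each block saves only its input during the forward pass and recomputes its intermediate activations during the backward pass.

.. code-block:: python

   import torch
   from torch.utils.checkpoint import checkpoint

   class FullyRecomputedStack(torch.nn.Module):
       """Runs every block under activation checkpointing."""

       def __init__(self, blocks):
           super().__init__()
           self.blocks = torch.nn.ModuleList(blocks)

       def forward(self, x):
           for block in self.blocks:
               # Only the block input is saved; everything inside the block
               # is recomputed when its gradients are needed.
               x = checkpoint(block, x, use_reentrant=False)
           return x

   blocks = [torch.nn.Sequential(torch.nn.Linear(512, 2048),
                                 torch.nn.GELU(),
                                 torch.nn.Linear(2048, 512))
             for _ in range(4)]
   model = FullyRecomputedStack(blocks)
   loss = model(torch.randn(8, 512)).sum()
   loss.backward()  # intermediate activations are recomputed here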

Partial Activation Recomputation
""""""""""""""""""""""""""""""""
This method recomputes only a subset of layers during the backward phase. It is a trade-off between the full recomputation and no recomputation, balancing memory savings with computational efficiency.
The partial activation recomputation method recomputes only a subset of layers during the backward phase. It is a trade-off between the full recomputation and no recomputation, balancing memory savings with computational efficiency.
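
Continuing the sketch above, a partial scheme might checkpoint only some of the blocks; the ``num_checkpointed_blocks`` knob below is purely illustrative and is not a NeMo configuration option.

.. code-block:: python

   from torch.utils.checkpoint import checkpoint

   def forward_partially_recomputed(blocks, x, num_checkpointed_blocks=2):
       """Recompute the first few blocks; store activations for the rest."""
       for i, block in enumerate(blocks):
           if i < num_checkpointed_blocks:
               x = checkpoint(block, x, use_reentrant=False)  # recomputed in backward
           else:
               x = block(x)  # activations kept in memory
       return x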

Selective Activation Recomputation
""""""""""""""""""""""""""""""""""
This method reduces memory footprint of activations significantly via smart activation checkpointing. This approach involves selectively storing only crucial activations and recomputing the others as needed. It is particularly useful in large models to minimize memory usage while controlling the computational cost.
The selective activation recomputation method significantly reduces the memory footprint of activations via smart activation checkpointing. This approach selectively stores only crucial activations and recomputes the others as needed. It is particularly useful in large models to minimize memory usage while controlling the computational cost.
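
As an illustration of the idea (again using ``torch.utils.checkpoint`` rather than NeMo's actual implementation), the sketch below stores the activations of the linear projections but recomputes the core attention, whose activations grow quadratically with sequence length while being cheap to recompute, in the spirit of the paper referenced below.

.. code-block:: python

   import torch
   from torch.utils.checkpoint import checkpoint

   class SelectivelyRecomputedAttention(torch.nn.Module):
       """Stores the cheap linear-projection activations, but recomputes the
       core attention (softmax(QK^T)V), whose activations are quadratic in
       sequence length. Illustrative only."""

       def __init__(self, dim=512, heads=8):
           super().__init__()
           self.qkv = torch.nn.Linear(dim, 3 * dim)
           self.proj = torch.nn.Linear(dim, dim)
           self.heads = heads

       def core_attention(self, q, k, v):
           # Only this part is checkpointed and recomputed in the backward pass.
           scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
           return torch.softmax(scores, dim=-1) @ v

       def forward(self, x):  # x: (batch, seq, dim)
           b, s, d = x.shape
           q, k, v = self.qkv(x).chunk(3, dim=-1)
           split = (b, s, self.heads, d // self.heads)
           q, k, v = (t.reshape(split).transpose(1, 2) for t in (q, k, v))
           ctx = checkpoint(self.core_attention, q, k, v, use_reentrant=False)
           return x + self.proj(ctx.transpose(1, 2).reshape(b, s, d))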

Refer to "Reducing Activation Recomputation in Large Transformer Models" for more details: https://arxiv.org/abs/2205.05198.

