Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate FlashAttention into Megatron-LM #267

Merged
merged 2 commits into from Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Expand Up @@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |

## FlashAttention

Usage: `--use-flash-attn`. Support attention head dimensions at most 128.

[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
memory-efficient algorithm to compute exact attention. It speeds up model
training and reduces memory requirement.

To install FlashAttention:
```sh
pip install flash-attn
```

## GPT-3 Example

Expand Down
3 changes: 3 additions & 0 deletions megatron/arguments.py
Expand Up @@ -612,6 +612,9 @@ def _add_training_args(parser):
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
Expand Down
90 changes: 85 additions & 5 deletions megatron/model/transformer.py
Expand Up @@ -15,6 +15,16 @@
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

try:
from einops import rearrange
except ImportError:
rearrange = None

try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None


""" We use the following notation throughout this file:
h: hidden size
Expand Down Expand Up @@ -306,6 +316,48 @@ def forward(self, query_layer, key_layer,
return context_layer


class FlashSelfAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
'e.g., with pip install flash-attn')
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout

def forward(self, q, k, v):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda
batch_size, seqlen = q.shape[0], q.shape[1]
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=q.device)
output = flash_attn_unpadded_func(
q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=self.causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output


class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.

Expand All @@ -323,6 +375,19 @@ def __init__(self, init_method,
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.sequence_parallel = args.sequence_parallel

self.use_flash_attn = args.use_flash_attn
if self.use_flash_attn:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
'pip install flash-attn')
assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
'self-attention for now')
assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
'supports causal mask for now')
if rearrange is None:
raise ImportError('einops is not installed, please install with pip install einops')

projection_size = args.kv_channels * args.num_attention_heads

Expand Down Expand Up @@ -365,6 +430,11 @@ def __init__(self, init_method,
self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'

if self.use_flash_attn:
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=args.attention_dropout
)

# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
Expand Down Expand Up @@ -487,12 +557,22 @@ def forward(self, hidden_states, attention_mask,
# core attention computation
# ==================================

if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask)
if not self.use_flash_attn:
if self.checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer, key_layer, value_layer, attention_mask)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask)
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
context_layer = self.core_attention_flash(q, k, v)
else:
context_layer = self.core_attention_flash(q, k, v)
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

# =================
# Output. [sq, b, h]
Expand Down