diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
index a12f1108e4..05357e673b 100644
--- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
+++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst
@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
   ``state_dict`` contains elements corresponding to only the current
   partition, or to the entire model.
 
+smdistributed.modelparallel.torch.nn.FlashAttentionLayer
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
+
+   This class supports
+   `FlashAttention `_
+   for PyTorch 2.0.
+   It takes the packed ``qkv`` matrix through its ``forward`` method,
+   computes the attention scores and probabilities,
+   and then performs the matrix multiplication with the value layer.
+
+   Through this class, the smp library supports
+   custom attention masks such as Attention with
+   Linear Biases (ALiBi). To activate them, set both
+   ``triton_flash_attention`` and ``use_alibi`` to ``True``.
+
+   Note that the Triton flash attention kernel does not support dropout
+   on the attention probabilities. It uses a standard lower triangular
+   causal mask when causal mode is enabled. It also runs only
+   on P4d and P4de instances, with fp16 or bf16.
+
+   This class computes the scale factor to apply when computing attention scores.
+   By default, ``scale`` is set to ``None`` and the factor is calculated automatically.
+   When ``scale_attention_scores`` is ``True`` (the default), you must pass a value
+   to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
+   you must pass a value to ``layer_idx``. If both options are enabled, the factors
+   are multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
+   This calculation is bypassed if you specify a custom scaling factor through
+   ``scale``. In other words, if you specify a value for ``scale``, the set of
+   parameters (``scale_attention_scores``, ``attention_head_size``,
+   ``scale_attn_by_layer_idx``, ``layer_idx``) is overridden and ignored.
+
+   **Parameters**
+
+   * ``attention_dropout_prob`` (float): (default: 0.1) The dropout probability
+     to apply to the attention probabilities.
+   * ``attention_head_size`` (int): Required when ``scale_attention_scores`` is ``True``.
+     When ``scale_attention_scores`` is enabled, this contributes
+     ``1/sqrt(attention_head_size)`` to the scale factor.
+   * ``scale_attention_scores`` (boolean): (default: True) Determines whether
+     to multiply the scale factor by ``1/sqrt(attention_head_size)``.
+   * ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
+     The layer index to use when scaling attention by layer index.
+     It contributes ``1/(layer_idx + 1)`` to the scale factor.
+   * ``scale_attn_by_layer_idx`` (boolean): (default: False) Determines whether
+     to multiply the scale factor by ``1/(layer_idx + 1)``.
+   * ``scale`` (float): (default: None) If passed, this scale factor is applied
+     directly, bypassing all of the previous scaling arguments.
+   * ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
+     implementation of flash attention is used. This is required to support
+     Attention with Linear Biases (ALiBi) (see the next argument). Note that this
+     version of the kernel doesn't support dropout.
+   * ``use_alibi`` (bool): (default: False) If ``True``, enables Attention with
+     Linear Biases (ALiBi) using the provided attention mask.
+
+   .. method:: forward(self, qkv, attn_mask=None, causal=False)
+
+      Returns a single ``torch.Tensor`` of shape ``(batch_size x num_heads x seq_len x head_size)``
+      that contains the output of the attention computation.
+
+      **Parameters**
+
+      * ``qkv``: ``torch.Tensor`` of shape ``(batch_size x seqlen x 3 x num_heads x head_size)``.
+      * ``attn_mask``: ``torch.Tensor`` of shape ``(batch_size x 1 x 1 x seqlen)``.
+        By default it is ``None``, and using this mask requires ``triton_flash_attention``
+        and ``use_alibi`` to be set to ``True``. See how to generate the mask in the
+        following code snippet.
+      * ``causal``: When ``True``, the standard lower triangular causal mask is used.
+        The default is ``False``.
+
+      When using ALiBi, prepare an attention mask as in the following example.
+
+      .. code:: python
+
+         import torch
+
+         def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
+                                      num_attention_heads, alibi_bias_max=8):
+
+             device, dtype = attention_mask.device, attention_mask.dtype
+             alibi_attention_mask = torch.zeros(
+                 1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
+             )
+
+             # Linear position offsets from the current token: 1 - seq_length, ..., 0.
+             alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
+                 1, 1, 1, seq_length
+             )
+             # Per-head slopes: 1 / 2**m, with m spaced up to alibi_bias_max.
+             m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
+             m.mul_(alibi_bias_max / num_attention_heads)
+             alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))
+
+             alibi_attention_mask.add_(alibi_bias)
+             alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
+             # Mask out padded positions with -inf, in place.
+             if attention_mask is not None and attention_mask.bool().any():
+                 alibi_attention_mask.masked_fill_(
+                     attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
+                 )
+
+             return alibi_attention_mask
+
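+      The following is a rough usage sketch of the layer with ALiBi enabled. The
+      tensor sizes are illustrative only, and the snippet assumes it runs inside a
+      SageMaker model parallel training job on a P4d or P4de instance, with the
+      library initialized through ``smp.init()`` and the ``generate_alibi_attn_mask``
+      helper above in scope.
+
+      .. code:: python
+
+         import torch
+         import smdistributed.modelparallel.torch as smp
+
+         smp.init()
+
+         # Illustrative sizes; adjust to your model.
+         batch_size, seq_length, num_heads, head_size = 2, 1024, 16, 64
+         device, dtype = torch.device("cuda"), torch.float16
+
+         # ALiBi requires the Triton kernel, which does not support dropout.
+         flash_attn = smp.nn.FlashAttentionLayer(
+             attention_dropout_prob=0.0,
+             attention_head_size=head_size,   # required because scale_attention_scores=True
+             triton_flash_attention=True,
+             use_alibi=True,
+         )
+
+         # Packed query/key/value projections.
+         qkv = torch.randn(
+             batch_size, seq_length, 3, num_heads, head_size, device=device, dtype=dtype
+         )
+
+         # Padding mask of shape (batch_size, 1, 1, seqlen); nonzero entries are masked out.
+         padding_mask = torch.zeros(batch_size, 1, 1, seq_length, device=device, dtype=dtype)
+
+         attn_mask = generate_alibi_attn_mask(
+             padding_mask, batch_size, seq_length, num_heads
+         )
+
+         # Output shape: (batch_size, num_heads, seq_length, head_size).
+         context_layer = flash_attn(qkv, attn_mask=attn_mask, causal=True)
+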
 smdistributed.modelparallel.torch Context Managers and Util Functions
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^