Add flash-attention patch for falcon-7b (#3580)
Enable the `use_flash_attention` configuration flag for Falcon models. When `use_flash_attention` is set to `true`, the [FalconAttention.forward()](https://github.com/huggingface/transformers/blob/c965d302791cf935d6ea7776428749be678cf509/src/transformers/models/falcon/modeling_falcon.py#L281) method is replaced with a variant that uses Tri Dao's flash_attention instead of PyTorch's `scaled_dot_product_attention` function. At the moment the patch works only for falcon-7b, but technically it will also work for falcon-40b with the right configuration.

The Falcon model situation is currently a bit messy: the Falcon model was recently added to Hugging Face Transformers (see [PR transformers#24523](huggingface/transformers#24523)), but the Falcon models on the Hugging Face Hub still use the code that ships together with the weights (a PR to change this [was reverted](https://huggingface.co/tiiuae/falcon-7b/discussions/66)). Falcon-7b and falcon-40b each use slightly different code; this was unified in the HF Transformers implementation, where it is controlled via a configuration member called `new_decoder_architecture` (see [configuration_falcon.py#L65-L67](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/configuration_falcon.py#L65-L67)). The HF Falcon implementation also uses different names in the configuration class; compare the new [configuration_falcon.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/configuration_falcon.py) with the old [configuration_RW.py](https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py).

Model configurations compatible with the HF Falcon implementation can be found here:

- 7B: [config.json](https://huggingface.co/tiiuae/falcon-7b/blob/4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76/config.json)
- 40B: [config.json](https://huggingface.co/tiiuae/falcon-40b/blob/f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232/config.json)
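For illustration, here is a minimal sketch of how the patched forward could be applied to a loaded model. The function `falcon_forward_with_flash_attn` is the one added by this commit (shown further below); the `patch_falcon_attention` wrapper, the `.model_patching` module path, and the way the shared `FlashSelfAttention` instance is passed are assumptions made for this sketch, not necessarily how this repository wires it up:

```python
import functools

from flash_attn.modules.mha import FlashSelfAttention
from transformers.models.falcon.modeling_falcon import FalconAttention

from .model_patching import falcon_forward_with_flash_attn  # hypothetical module path


def patch_falcon_attention(model) -> None:
    """Monkey-patch every FalconAttention module to use flash attention (illustrative sketch)."""
    flash_attn = FlashSelfAttention(causal=True)
    for module in model.modules():
        if isinstance(module, FalconAttention):
            # Bind the patched forward to this module; the shared FlashSelfAttention
            # instance becomes the `flash_attn` argument of the replacement forward.
            module.forward = functools.partial(falcon_forward_with_flash_attn, module, flash_attn)
```

With something like this in place, a configuration flag such as `use_flash_attention: true` would simply gate the call to `patch_falcon_attention(model)` after the model is loaded.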
1 parent aa42ed4 · commit 1e6e569

Showing 10 changed files with 269 additions and 68 deletions.
@@ -0,0 +1,50 @@

{
  "bf16": {
    "enabled": "auto"
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "warmup_type": "linear",
      "total_num_steps": "auto"
    }
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
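This new file is a DeepSpeed ZeRO stage-3 configuration with CPU offload for both optimizer state and parameters; the `"auto"` values are filled in at runtime by the Hugging Face Trainer's DeepSpeed integration. The file path and the training entry point are not visible in this extract, so the following is only an illustrative sketch of how such a config is typically hooked up via `TrainingArguments`:

```python
from transformers import TrainingArguments

# Illustrative only -- the actual script and argument values in this repository may
# differ; the config filename below is a placeholder.
training_args = TrainingArguments(
    output_dir="output",
    bf16=True,                       # resolves "bf16": {"enabled": "auto"}
    learning_rate=1e-5,              # resolves the AdamW "lr": "auto"
    warmup_steps=100,                # resolves "warmup_num_steps": "auto"
    per_device_train_batch_size=4,   # resolves "train_micro_batch_size_per_gpu": "auto"
    gradient_accumulation_steps=8,   # resolves "gradient_accumulation_steps": "auto"
    deepspeed="zero3_offload.json",  # placeholder path to the config shown above
)
```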
@@ -0,0 +1,78 @@

from typing import Optional, Tuple

import torch
import torch.nn as nn

from .patching_utils import compute_flash_attention


def falcon_forward_with_flash_attn(
    self,
    flash_attn: nn.Module,  # flash_attn.modules.mha.FlashSelfAttention
    hidden_states: torch.Tensor,
    alibi: Optional[torch.Tensor],
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    head_mask, alibi & output_attentions are not supported.

    Reference to the original `FalconAttention.forward()` method which this patch replaces:
    https://github.com/huggingface/transformers/blob/c965d302791cf935d6ea7776428749be678cf509/src/transformers/models/falcon/modeling_falcon.py#L281
    """

    assert head_mask is None  # not supported
    assert alibi is None  # not supported
    assert not output_attentions  # not supported

    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
    num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, query_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads,
        query_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        # - key: [batch_size * self.num_heads, kv_length, head_dim]
        # - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=1)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
    key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
    value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

    # bring q/k/v into [batch_size, kv_length, num_heads, head_dim] layout; with
    # multi-query attention (num_kv_heads == 1) the single key/value head is
    # broadcast to all query heads via expand()
    q = query_layer_.permute(0, 2, 1, 3)
    k = key_layer_.permute(0, 2, 1, 3).expand(q.shape)
    v = value_layer_.permute(0, 2, 1, 3).expand(q.shape)

    if attention_mask is not None:
        # keep only the last query row of the expanded mask, yielding a per-key-position
        # mask of shape [batch_size, kv_length]
        attention_mask = attention_mask[:, 0, -1]

    flash_attn.train(self.training)
    attn_output = compute_flash_attention(flash_attn, q, k, v, attention_mask=attention_mask)
    attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

    output_tensor = self.dense(attn_output)

    return output_tensor, present
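The `compute_flash_attention` helper imported from `.patching_utils` is not included in this excerpt. As a rough, hypothetical sketch of the padded (no-padding) path only, assuming `flash_attn`'s `FlashSelfAttention` module, such a helper could look like the following; the real helper in this commit may additionally unpad variable-length sequences based on `attention_mask`:

```python
from typing import Optional

import torch
from flash_attn.modules.mha import FlashSelfAttention


def compute_flash_attention_sketch(
    flash_attn_module: FlashSelfAttention,
    q: torch.Tensor,  # [batch_size, seq_length, num_heads, head_dim], fp16/bf16 on GPU
    k: torch.Tensor,
    v: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Pack q/k/v into the [batch_size, seq_length, 3, num_heads, head_dim] layout
    # expected by FlashSelfAttention.
    qkv = torch.stack([q, k, v], dim=2)
    if attention_mask is not None:
        # A full implementation would unpad the batch here (e.g. via
        # flash_attn.bert_padding) and call the variable-length kernel;
        # this sketch ignores padding entirely.
        pass
    # Returns [batch_size, seq_length, num_heads, head_dim], matching the reshape
    # performed by the caller above.
    return flash_attn_module(qkv, causal=True)
```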