Add flash-attention patch for falcon-7b #3580
Merged
Enable the `use_flash_attention` configuration flag for Falcon models. When `use_flash_attention` is set to `true`, the `FalconAttention.forward()` method is replaced with a variant that uses Tri Dao's flash_attention instead of PyTorch's `scaled_dot_product_attention` function.
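For illustration, here is a minimal sketch of the substitution the patch performs at the attention call site. This is not the code merged in this PR; it assumes flash-attn 2.x's `flash_attn_func` and a CUDA device, and the helper name `attention_via_flash_attn` is made up for this example:

```python
# Sketch only: how flash_attn_func can stand in for PyTorch's
# scaled_dot_product_attention. Assumes flash-attn 2.x is installed
# (pip install flash-attn) and tensors live on a CUDA device.
import torch
from flash_attn import flash_attn_func  # Tri Dao's FlashAttention kernel


def attention_via_flash_attn(query, key, value):
    """Causal attention with the same tensor layout PyTorch's
    scaled_dot_product_attention uses: (batch, heads, seq_len, head_dim).

    flash_attn_func expects (batch, seq_len, heads, head_dim) and
    fp16/bf16 inputs, so transpose and cast around the kernel call.
    """
    q = query.transpose(1, 2).to(torch.bfloat16)
    k = key.transpose(1, 2).to(torch.bfloat16)
    v = value.transpose(1, 2).to(torch.bfloat16)
    out = flash_attn_func(q, k, v, causal=True)  # causal mask, as in decoding
    return out.transpose(1, 2).to(query.dtype)


if __name__ == "__main__":
    # falcon-7b dimensions: 71 attention heads with head_dim 64
    q = torch.randn(1, 71, 128, 64, device="cuda")
    ref = torch.nn.functional.scaled_dot_product_attention(q, q, q, is_causal=True)
    fa = attention_via_flash_attn(q, q, q)
    print((ref - fa).abs().max())  # small difference, up to bf16 rounding
```

The actual patch additionally has to reuse the module's own QKV projection and rotary embeddings before the kernel call; the sketch only shows the kernel substitution itself.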
At the moment the patch works only for falcon-7b, but technically it will also work for falcon-40b with the right configuration. The Falcon model situation is currently a bit messy: the Falcon model was recently added to Hugging Face transformers (see PR transformers#24523), but the Falcon models on the Hugging Face Hub still use the code that ships together with the weights (a PR to change this was reverted). Falcon-7b and falcon-40b use slightly different code, which was unified in the HF transformers implementation and can there be controlled via a configuration member called `new_decoder_architecture` (see configuration_falcon.py#L65-L67). The HF Falcon implementation also uses different names in the configuration class, e.g. compare the new configuration_falcon.py with the old configuration_RW.py.

Model configurations compatible with the HF Falcon implementation can be found here:
7B: config.json
40B: config.json
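To see which configuration class a given checkpoint actually resolves to, one can inspect the configs directly. A small sketch, assuming network access to the Hub; whether `new_decoder_architecture` is present depends on the checkpoint revision (the old configuration_RW.py configs do not define it):

```python
from transformers import AutoConfig

# The hub checkpoints still ship their own code, hence trust_remote_code=True.
cfg_7b = AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
cfg_40b = AutoConfig.from_pretrained("tiiuae/falcon-40b", trust_remote_code=True)

# The unified HF implementation keys the 7b/40b code paths off a single flag;
# old-style RW configs fall back to the default below.
for name, cfg in (("7b", cfg_7b), ("40b", cfg_40b)):
    flag = getattr(cfg, "new_decoder_architecture", "<not present in old config>")
    print(name, type(cfg).__name__, flag)
```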