Add flash-attention patch for falcon-7b #3580

Merged: 5 commits merged into main from falcon7b_flash_attn_patch on Jul 19, 2023

Conversation

andreaskoepf (Collaborator) commented on Jul 17, 2023:

Enable the use_flash_attention configuration flag for Falcon models. When use_flash_attention is set to true, the FalconAttention.forward() method is replaced with a variant that uses Tri Dao's flash-attention instead of PyTorch's scaled_dot_product_attention function.
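
For context, the core of the patch is the kernel substitution sketched below: the replacement forward calls Tri Dao's flash-attention where the stock implementation calls torch.nn.functional.scaled_dot_product_attention. This is a minimal illustration only, assuming flash-attn 2.x and an fp16-capable CUDA GPU; the actual patch in patching_utils.py additionally has to unpack Falcon's fused QKV projection and handle the KV cache, attention masking and position embeddings.

```python
# Minimal sketch of the kernel substitution, not the exact code in patching_utils.py.
# Assumes flash-attn 2.x and a CUDA GPU (flash attention requires fp16/bf16 CUDA tensors).
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 71, 64  # falcon-7b: 71 heads of dim 64

q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# What the unpatched forward effectively computes (SDPA expects (batch, nheads, seqlen, headdim)):
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

# What the patched forward calls instead (flash_attn_func expects (batch, seqlen, nheads, headdim)):
out = flash_attn_func(q, k, v, causal=True)

print((out - ref).abs().max())  # only small fp16 numerical differences expected
```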

At the moment the patch only works for falcon-7b, but technically it will also work for falcon-40b with the right configuration. The Falcon model situation is currently a bit messy: the Falcon model was recently added to Hugging Face transformers (see PR transformers#24523), but the Falcon models on the Hugging Face Hub still use the code that is shipped together with the weights (a PR to change this was reverted). Falcon-7b and 40b use slightly different code (which was unified in the HF transformers implementation, where it can be controlled via a configuration member called new_decoder_architecture, see configuration_falcon.py#L65-L67). The HF Falcon implementation also uses different names in its configuration class, e.g. compare the new configuration_falcon.py with the old configuration_RW.py.

Model configurations compatible with the HF Falcon implementation can be found here:
7B: config.json
40B: config.json
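
As a rough illustration of the new_decoder_architecture switch mentioned above (assuming a transformers version that already ships the Falcon port from transformers#24523; the values mirror the linked config.json files):

```python
# Hedged sketch: the unified HF implementation selects the 7B- vs 40B-style
# attention layout via new_decoder_architecture. Values mirror the linked
# config.json files; this is illustrative, not a training config.
from transformers import FalconConfig

falcon_7b = FalconConfig(
    hidden_size=4544,
    num_hidden_layers=32,
    num_attention_heads=71,
    new_decoder_architecture=False,  # 7B keeps the older decoder layout
    multi_query=True,                # single shared KV head
)

falcon_40b = FalconConfig(
    hidden_size=8192,
    num_hidden_layers=60,
    num_attention_heads=128,
    new_decoder_architecture=True,   # 40B uses the newer decoder layout
    num_kv_heads=8,                  # grouped KV heads
)

print(falcon_7b.new_decoder_architecture, falcon_40b.new_decoder_architecture)  # False True
```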

jordiclive (Collaborator) left a comment:

LGTM, minor clarification

Resolved review threads:
model/model_training/models/__init__.py
model/model_training/models/patching_utils.py
model/pyproject.toml

jordiclive (Collaborator) left a comment:

LGTM

andreaskoepf merged commit 1e6e569 into main on Jul 19, 2023
1 check passed
andreaskoepf deleted the falcon7b_flash_attn_patch branch on July 19, 2023 at 14:18