Fix propagation of use_flash for offloaded inference #178

Merged
merged 1 commit into aqlaboratory:main from fix_offload_flash on Jul 28, 2022

Conversation

epenning (Contributor)

This change fixes an issue where _forward_offload() was being passed the use_flash parameter but did not accept it. It also passes use_flash along to _prep_blocks(), which requires that parameter.

Context

When running run_pretrained_openfold.py with offload_inference enabled in the config, the following error occurred:

Traceback (most recent call last):
  File "/.../shared/openfold/run_pretrained_openfold.py", line 591, in <module>
    main(args)
  File "/.../shared/openfold/run_pretrained_openfold.py", line 439, in main
    out = run_model(model, processed_feature_dict, tag, args)
  File "/.../shared/openfold/run_pretrained_openfold.py", line 119, in run_model
    out = model(batch)
  File "/.../shared/openfold/lib/conda/envs/openfold_venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/.../shared/openfold/openfold/model/model.py", line 510, in forward
    _recycle=(num_iters > 1)
  File "/.../shared/openfold/openfold/model/model.py", line 384, in iteration
    _mask_trans=self.config._mask_trans,
TypeError: _forward_offload() got an unexpected keyword argument 'use_flash'
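
For illustration, a minimal sketch of the kind of change involved, with simplified and hypothetical signatures (the actual OpenFold methods take many more arguments):

# Hypothetical, simplified sketch of the fix; not the real OpenFold signatures.
class EvoformerStack:
    def _prep_blocks(self, m, z, use_flash, _mask_trans):
        # Build the list of per-block callables; use_flash is a required
        # parameter here, so every caller must supply it.
        blocks = []
        # ... construct per-block partials that honor use_flash ...
        return blocks

    def _forward_offload(self, m, z, use_flash=False, _mask_trans=True):
        # use_flash is now accepted here and forwarded to _prep_blocks(),
        # instead of raising "TypeError: unexpected keyword argument
        # 'use_flash'" when model.py passes it through during offloaded
        # inference.
        blocks = self._prep_blocks(
            m, z, use_flash=use_flash, _mask_trans=_mask_trans
        )
        for block in blocks:
            m, z = block(m, z)
        return m, z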

@gahdritz (Collaborator) commented Jul 28, 2022

Thanks for the PR! This was actually half-deliberate; FlashAttention doesn't work that well for the long sequence sizes for which the offloading mode is useful. I propose completely removing the use_flash parameter being passed to _forward_offload in model.py and manually passing use_flash=False to _prep_blocks in the same function instead. I'll add a new constraint in the config preventing users from trying to combine the two.

I changed my mind. I'll just add a warning to the config and let users do as they please.
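
For reference, a minimal sketch of the kind of config-time warning described above, assuming hypothetical flag names (the actual config structure and check added to OpenFold are not shown in this thread):

# Hypothetical sketch of a config warning when FlashAttention is combined
# with offloaded inference; flag names here are illustrative only.
import warnings

def warn_flash_with_offloading(offload_inference: bool, use_flash: bool) -> None:
    if offload_inference and use_flash:
        warnings.warn(
            "use_flash is enabled together with offload_inference; "
            "FlashAttention tends not to help at the long sequence lengths "
            "for which offloading is intended."
        )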

@gahdritz gahdritz merged commit 984370c into aqlaboratory:main Jul 28, 2022
@epenning epenning deleted the fix_offload_flash branch July 28, 2022 19:59