
Refactor fused MLP + fused attention loading. Fix for fused MLP requiring Triton even when not used. #85

Conversation

@TheBloke (Contributor) commented May 16, 2023

This is a fix for: #43 (comment)

Changes:

  1. In modeling/_base.py, add new classmethods get_fused_attention_module and get_fused_mlp_module. If called on the base class directly, they log "this class does not support" warnings.
  2. In modeling/llama.py, override the same classmethods, wrapping the imports of FusedLlamaMLPForQuantizedModel and FusedLlamaAttentionForQuantizedModel in try/except blocks.
  3. In modeling/_base.py, add checks for inject_fused_attention and inject_fused_mlp so that get_fused_attention_module and get_fused_mlp_module are only called when the right conditions are met. In particular, get_fused_mlp_module is not called unless use_triton is True. (A rough sketch of the pattern follows the list.)
  4. As a result, FusedLlamaMLPForQuantizedModel is not imported unless the user specifies both use_triton and inject_fused_mlp.
  5. The user therefore does not need Triton installed and can run the CUDA code path without errors.
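The overall shape of the change is roughly as in the sketch below. The class names (BaseGPTQForCausalLM, LlamaGPTQForCausalLM) and import paths are illustrative placeholders, not the exact code in this PR:

```python
# Sketch of the classmethod pattern described above; names and paths are placeholders.
from logging import getLogger

logger = getLogger(__name__)


class BaseGPTQForCausalLM:
    @classmethod
    def get_fused_attention_module(cls):
        logger.warning(f"{cls.__name__} does not support fused attention injection.")
        return None

    @classmethod
    def get_fused_mlp_module(cls):
        logger.warning(f"{cls.__name__} does not support fused MLP injection.")
        return None

    @classmethod
    def resolve_fused_modules(cls, use_triton, inject_fused_attention, inject_fused_mlp):
        # Only look up the fused modules when they were requested; the Triton-backed
        # fused MLP is never touched unless use_triton is True.
        fused_attn = cls.get_fused_attention_module() if inject_fused_attention else None
        fused_mlp = cls.get_fused_mlp_module() if (use_triton and inject_fused_mlp) else None
        return fused_attn, fused_mlp


class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
    @classmethod
    def get_fused_attention_module(cls):
        try:
            from auto_gptq.nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
            return FusedLlamaAttentionForQuantizedModel
        except ImportError:
            logger.warning("Fused attention module could not be imported.")
            return None

    @classmethod
    def get_fused_mlp_module(cls):
        try:
            # This import requires Triton, so it only happens on the Triton path.
            from auto_gptq.nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
            return FusedLlamaMLPForQuantizedModel
        except ImportError:
            logger.warning("Triton is not available; fused MLP injection is disabled.")
            return None
```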

Testing done:

  1. Inference with: CUDA, CUDA + FA, Triton, Triton + FA, Triton + FM, Triton + FA + FM (FA = fused attention, FM = fused MLP)
  2. Quantisation with CUDA

@TheBloke mentioned this pull request May 16, 2023
@LexSong (Contributor) commented May 18, 2023

This patch works. Thanks.

@PanQiWei (Collaborator)

Thank you @TheBloke for creating this PR and solving the problem users hit when trying to inject fused modules without Triton.

However, I think it would be better to fix this at the root cause rather than adding new functions as patches, which would make the code more complex.

In my opinion, a better design pattern for keeping things "automatic" is to set the relevant object to None, or to use a global flag to disable the feature when it is not supported in a given environment, as is done in flash-attention's block.py and text-generation-inference's layers.py.
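For reference, a minimal sketch of that flag-based approach (the import path and the maybe_inject_fused_mlp helper are hypothetical, not code from either of those repositories):

```python
# Hypothetical illustration of the module-level flag pattern; not taken from
# flash-attention, text-generation-inference, or AutoGPTQ.
from logging import getLogger

logger = getLogger(__name__)

try:
    # This import is the only place that requires Triton.
    from auto_gptq.nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
    TRITON_AVAILABLE = True
except ImportError:
    FusedLlamaMLPForQuantizedModel = None
    TRITON_AVAILABLE = False


def maybe_inject_fused_mlp(model, use_triton: bool, inject_fused_mlp: bool):
    """Inject the fused MLP only if it was requested and Triton is available."""
    if not (use_triton and inject_fused_mlp):
        return model
    if not TRITON_AVAILABLE:
        logger.warning("Triton is not installed; skipping fused MLP injection.")
        return model
    # ... perform the actual injection with FusedLlamaMLPForQuantizedModel here ...
    return model
```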

I will open a new PR that fixes the Triton problem in a different way, so this PR can be closed.

Once again, I really appreciate your contributions; they are truly making this project better and better. ❤️‍🔥

@PanQiWei (Collaborator)

I've opened #92 to fix the import error when Triton is not installed, and optimized the code so that Triton integration is more automatic.

@TheBloke (Contributor, Author)

Looks good! Will test it in a minute

@TheBloke closed this May 20, 2023