Will the setting "context_fmha" support "cross_attention" ? #461

Closed
xinji1 opened this issue Nov 24, 2023 · 3 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

@xinji1

xinji1 commented Nov 24, 2023

I just found that in the GPTAttention plugin, the dedicated context FMHA kernel (mentioned in #457) is used only when mEnableContextFMHA && !isCrossAttention() && !isRelativePosition() (see cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp). Will the "context_fmha" setting (not the FMHA runner) still save memory if I use "cross_attention"?
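
For reference, here is a minimal, self-contained sketch of the dispatch condition I am referring to. This is not the actual plugin code; the struct and method names are made up for illustration:

```cpp
#include <iostream>

// Hypothetical stand-in for the plugin's attention configuration.
struct AttentionConfigSketch
{
    bool enableContextFMHA;
    bool crossAttention;
    bool relativePosition;

    // The fused context-MHA kernel is only taken for plain self-attention
    // without relative position bias; everything else falls back to the
    // unfused path.
    bool useFusedContextMHA() const
    {
        return enableContextFMHA && !crossAttention && !relativePosition;
    }
};

int main()
{
    AttentionConfigSketch selfAttn{true, false, false};  // e.g. GPT self-attention, context phase
    AttentionConfigSketch crossAttn{true, true, false};  // e.g. enc-dec decoder cross-attention

    std::cout << selfAttn.useFusedContextMHA() << "\n";   // 1 -> fused context FMHA kernel
    std::cout << crossAttn.useFusedContextMHA() << "\n";  // 0 -> unfused fallback path
    return 0;
}
```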

Another question about the combination of "context_fmha" and "cross_attention":
In enqueueContext() in cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp, we have
const size_t q_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;

Because of mEnableContextFMHA, q_buf_2_size is 0, so T* q_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, q_buf_2_size)); is effectively nullptr, right? In that case, how does the GEMM call mCublasWrapper->stridedBatchedGemm() operate on q_buf_2_?
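
To illustrate the concern, here is a minimal sketch of the usual workspace-carving pattern. nextWorkspacePtrSketch is my own stand-in, not the real TensorRT-LLM helper; in particular, its behavior for a zero-sized request (returning nullptr, no alignment) is an assumption for illustration:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

// Carve the next buffer out of a shared workspace. The real helper also
// aligns the offset; returning nullptr for size == 0 is an assumption here.
static std::int8_t* nextWorkspacePtrSketch(std::int8_t* base, std::size_t& offset, std::size_t size)
{
    std::int8_t* ptr = (size == 0) ? nullptr : base + offset;
    offset += size;
    return ptr;
}

int main()
{
    alignas(16) std::int8_t workspace[1024];
    std::size_t offset = 0;

    const bool enableContextFMHA = true;
    // Mirrors the quoted expression: the q_buf_2 region shrinks to zero bytes
    // once context FMHA is enabled.
    const std::size_t q_buf_2_size = enableContextFMHA ? 0 : 512;

    std::int8_t* q_buf_2 = nextWorkspacePtrSketch(workspace, offset, q_buf_2_size);
    // With a zero-sized q_buf_2, the unfused stridedBatchedGemm path has no
    // valid buffer to write into, which is exactly the concern above.
    std::cout << "q_buf_2 usable: " << (q_buf_2 != nullptr) << "\n";  // 0
    return 0;
}
```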

@symphonylyh self-assigned this Nov 24, 2023
@juney-nvidia added the triaged label Nov 24, 2023
@symphonylyh
Collaborator

Hi @xinji1, please refer to the discussion under #444, where I explained the current status of cross-attention / relative position bias support with fused MHA.

Yes, using FMHA gives both faster speed and lower memory consumption.

First let me divide FMHA usage into different categories:

  1. encoder-only models' self-attention (e.g., BERT); encoder-decoder models' encoder self-attention without relative bias (e.g., BART's encoder); and encoder-decoder models' decoder self-attention without relative bias during the context phase (e.g., BART's decoder in the context phase)
  2. decoder-only models' self-attention during the context phase (e.g., GPT's context phase)
  3. encoder-decoder models' self-attention with relative bias, for either the encoder or the decoder (e.g., T5's encoder self-attention and T5's decoder self-attention during the context phase)
  4. encoder-decoder models' decoder cross-attention during the context phase, with or without relative bias (e.g., BART's or T5's decoder cross-attention during the context phase)

(1) and (2) are supported; (3) and (4) are not enabled yet, so the unfused path is used until they are ready.

Back to your second question: we have noticed this and have already applied a fix, which will land on the main dev branch this week. Before the fix, users are not expected to set the context_fmha flag for enc-dec models; an error is reported otherwise. After the fix, if a user accidentally sets the context_fmha flag, it is automatically turned off when cross-attention or relative position bias is detected.
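
For illustration only, here is a minimal sketch of what such an auto-disable check could look like; the names are hypothetical and the actual fix on main may be implemented differently:

```cpp
#include <iostream>

// Hypothetical plugin flags; not the real TensorRT-LLM types.
struct PluginFlagsSketch
{
    bool contextFMHA;
    bool crossAttention;
    bool relativePositionBias;
};

// Turn off context_fmha whenever the unfused path is required.
void sanitizeContextFMHA(PluginFlagsSketch& flags)
{
    if (flags.contextFMHA && (flags.crossAttention || flags.relativePositionBias))
    {
        std::cerr << "context_fmha is not yet supported with cross-attention or "
                     "relative position bias; disabling it.\n";
        flags.contextFMHA = false;
    }
}

int main()
{
    PluginFlagsSketch flags{true, true, false};  // e.g. enc-dec decoder cross-attention
    sanitizeContextFMHA(flags);
    std::cout << "context_fmha = " << flags.contextFMHA << "\n";  // 0
    return 0;
}
```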

@juney-nvidia
Collaborator

Closed since this has already been fixed on the main branch.

@xinji1
Author

xinji1 commented Nov 25, 2023

Got it, thanks!
