I just found that in the GPTAttention plugin, a distinct context_fmha kernel (mentioned in #457) is used only when mEnableContextFMHA && !isCrossAttention() && !isRelativePosition() in cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp. Will the setting "context_fmha" (as opposed to the FMHA runner) still save memory if I also apply "cross_attention"?
Another question about the combined setting of "context_fmha" and "cross_attention":
In enqueueContext() in /cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp, we have const size_t q_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
so q_buf_2_size should be 0 because of mEnableContextFMHA, and thus T* q_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, q_buf_2_size)); is actually nullptr, right? Then how does the GEMM operation mCublasWrapper->stridedBatchedGemm() work with q_buf_2 in this case?
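For reference, here is a minimal, self-contained sketch of the workspace-carving pattern being discussed; the helper, sizes, and flag values below are illustrative assumptions, not the actual TensorRT-LLM nextWorkspacePtr. It shows that a zero-byte request from such a helper would typically return a non-null cursor that aliases the next carved buffer, which is why handing it to a GEMM is still problematic.

```cpp
// Illustrative sketch only (NOT the TensorRT-LLM implementation): a typical
// workspace-carving helper returns the current cursor and advances it by the
// aligned request size. A zero-byte request returns a valid, non-null address
// that aliases whatever buffer is carved next.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int8_t* nextWorkspacePtrSketch(int8_t* base, size_t& offset, size_t size, size_t alignment = 256)
{
    int8_t* ptr = base + offset;                                // current cursor, never nullptr
    offset += (size + alignment - 1) / alignment * alignment;   // advance by aligned size
    return ptr;
}

int main()
{
    alignas(256) static int8_t workspace[1 << 20];
    size_t offset = 0;

    bool const enableContextFMHA = true;                        // hypothetical flag value
    size_t const q_buf_2_size = enableContextFMHA ? 0 : 4096;   // mirrors the quoted expression

    int8_t* q_buf_2 = nextWorkspacePtrSketch(workspace, offset, q_buf_2_size);
    int8_t* k_buf_2 = nextWorkspacePtrSketch(workspace, offset, 4096);

    // With q_buf_2_size == 0, q_buf_2 is non-null but overlaps k_buf_2, so a
    // GEMM reading/writing through it would touch the wrong buffer.
    std::printf("q_buf_2 aliases k_buf_2: %s\n", q_buf_2 == k_buf_2 ? "yes" : "no");
    return 0;
}
```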
Hi @xinji1, I would refer you to the discussion under #444, where I explained the current progress regarding cross-attention / relative position bias + fused MHA.
Yes, using FMHA gives both faster speed and lower memory consumption.
First let me divide FMHA usage into different categories:
1. Encoder-only model's self-attention, like BERT; OR encoder-decoder model's encoder self-attention w/o relative bias, like BART's encoder; OR encoder-decoder model's decoder self-attention w/o relative bias during the context phase, like BART's decoder's context phase.
2. Decoder-only model's self-attention during the context phase, like GPT's context phase.
3. Encoder-decoder model's self-attention w/ relative bias, whether in the encoder or the decoder, like T5's encoder self-attention and T5's decoder self-attention during the context phase.
4. Encoder-decoder model's decoder cross-attention during the context phase, whether w/ or w/o relative bias, like BART's or T5's decoder cross-attention during the context phase.
For (1) and (2), FMHA is supported; (3) and (4) are not enabled yet, so the unfused path is used until they are ready.
Back to your second question: we have noticed this and have already applied a fix that will be out this week on the main dev branch. Before the fix, users are not expected to set the context_fmha flag for enc-dec models; it will report an error otherwise. After the fix, if a user accidentally sets the context_fmha flag, it will be automatically turned off based on whether cross attention or relative position bias is present.
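In other words, the fix gates fused context MHA on the same condition quoted in the question. Here is a minimal sketch of that behaviour; the struct and function names are illustrative assumptions, not the actual plugin code.

```cpp
// Minimal sketch of the described behaviour (names are illustrative, not the
// exact TensorRT-LLM code): even when context_fmha is requested, fused context
// MHA is disabled if cross attention or relative position bias is in use, and
// those models fall back to the unfused context path.
#include <cstdio>

struct AttentionConfigSketch
{
    bool enableContextFMHA;     // what the user requested via the context_fmha flag
    bool crossAttention;        // e.g. enc-dec decoder cross-attention
    bool relativePositionBias;  // e.g. T5-style relative attention bias
};

bool useFusedContextMHA(AttentionConfigSketch const& cfg)
{
    // Fused context MHA is taken only when requested AND neither unsupported
    // feature is present; otherwise the unfused path is used.
    return cfg.enableContextFMHA && !cfg.crossAttention && !cfg.relativePositionBias;
}

int main()
{
    AttentionConfigSketch t5Decoder{/*enableContextFMHA=*/true,
                                    /*crossAttention=*/true,
                                    /*relativePositionBias=*/true};
    std::printf("T5 decoder uses fused context MHA: %s\n",
                useFusedContextMHA(t5Decoder) ? "yes" : "no");  // prints "no"
    return 0;
}
```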