Skip to content

Commit

Permalink
fix Is_equal_qk error (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy committed Jan 11, 2024
1 parent fd6890c commit 5fc132a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
const bool is_equal_qk = (params.cu_seqlens_q == nullptr) && (params.cu_seqlens_k == nullptr) && (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
Expand Down

0 comments on commit 5fc132a

Please sign in to comment.