diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 89200f3381..a71e47828b 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1196,36 +1196,31 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); + + if constexpr(Has_sink){ + const int row = tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + const index_t lse_offset = row_offset_lse + col; + if(params.unpadded_lse){ + // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). + const int head_idx = lse_offset / (params.b * params.seqlen_q); + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; + }else{ + // LSE is written as (b, h, seqlen_q). + const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; + } + } + } + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if constexpr (Has_sink) { - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (params.unpadded_lse) { - // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). - if (lse_offset < lse_size) { - const int head_idx = lse_offset / (params.b * params.seqlen_q); - const float lse_logsum_sink = __expf(lse_logsum); - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); - lse_logsum = __logf(lse_logsum_sink + sink_val_exp);; - } - } else { - // LSE is written as (b, h, seqlen_q). - const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; - const float lse_logsum_sink = __expf(lse_logsum); - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); - lse_logsum = __logf(lse_logsum_sink + sink_val_exp); - } - } - } - } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { if (params.unpadded_lse) { const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;