Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -1196,36 +1196,31 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); }
SumOp<float> sum_op;
lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);

if constexpr(Has_sink){
const int row = tidx % kRowsPerLoadTranspose;
const int col = tidx / kRowsPerLoadTranspose;
if (row < params.num_splits && col < kBlockM) {
const index_t lse_offset = row_offset_lse + col;
if(params.unpadded_lse){
// LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q).
const int head_idx = lse_offset / (params.b * params.seqlen_q);
const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast<float *>(params.learnable_sink_ptr)[head_idx] - lse_max);
lse_sum += sink_val_exp;
}else{
// LSE is written as (b, h, seqlen_q).
const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q;
const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast<float *>(params.learnable_sink_ptr)[head_idx] - lse_max);
lse_sum += sink_val_exp;
}
}
}

// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }

if constexpr (Has_sink) {
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
const int col = tidx / kRowsPerLoadTranspose;
if (row < params.num_splits && col < kBlockM) {
const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
if (params.unpadded_lse) {
// LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q).
if (lse_offset < lse_size) {
const int head_idx = lse_offset / (params.b * params.seqlen_q);
const float lse_logsum_sink = __expf(lse_logsum);
const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast<float *>(params.learnable_sink_ptr)[head_idx]);
lse_logsum = __logf(lse_logsum_sink + sink_val_exp);;
}
} else {
// LSE is written as (b, h, seqlen_q).
const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q;
const float lse_logsum_sink = __expf(lse_logsum);
const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast<float *>(params.learnable_sink_ptr)[head_idx]);
lse_logsum = __logf(lse_logsum_sink + sink_val_exp);
}
}
}
}
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
if (params.unpadded_lse) {
const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
Expand Down