diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 7968368d92..08792e0b31 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 7968368d9217436d1e8a6ecf587ac0a0d580fcd0 +Subproject commit 08792e0b31b936b9e7baa05fb8b03dce8c21241a diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 2424cb7955..9a611aa9ef 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ): md_name = "mha_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) @@ -807,7 +809,8 @@ def mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
@@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> dict[str, Any]: md_name = "mha_varlen_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return gen_mha_varlen_bwd_fake_tensors_common( q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv @@ -1157,7 +1162,8 @@ def mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
@@ -1567,7 +1573,8 @@ def _flash_attn_backward_fake( rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: batch_size = q.size(0) seqlen_q = q.size(1) @@ -1606,7 +1613,8 @@ def _flash_attn_backward( rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: # rtna & rtz are deprecated in gfx950 if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0: @@ -1723,7 +1731,8 @@ def can_impl_fmha_v3_bwd_gfx950(): alibi_slopes, rng_state, None, - sink_ptr, + sink, + d_sink, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -1782,7 +1791,7 @@ def forward( how_v3_bf16_cvt=how_v3_bf16_cvt, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - sink_ptr=sink_ptr, + sink_ptr=sink_ptr, # fwd kernel still uses sink_ptr naming ) if is_grad: assert return_lse @@ -1840,7 +1849,8 @@ def backward(ctx, dout, *args): rng_state, ctx.is_v3_atomic_fp32, ctx.how_v3_bf16_cvt, - sink_ptr=None, + sink=None, + d_sink=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] @@ -1863,7 +1873,10 @@ def backward(ctx, dout, *args): # 15 how_v3_bf16_cvt # 16 cu_seqlens_q # 17 cu_seqlens_kv - # Need to return exactly 17 gradient entries. + # 18 sink_ptr (fwd-only sink scores; not differentiable via autograd. + # bwd sink gradient d_sink is computed inside mha_bwd kernel, + # not returned here as a positional gradient.) + # Need to return exactly 18 gradient entries. 
return ( dq, # q dk, # k @@ -1882,7 +1895,7 @@ def backward(ctx, dout, *args): None, # how_v3_bf16_cvt None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sink_ptr + None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately) ) @@ -2170,7 +2183,8 @@ def _flash_attn_varlen_backward( zero_tensors: bool = False, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, - sink_ptr: Optional[Tensor] = None, + sink: Optional[Tensor] = None, + d_sink: Optional[Tensor] = None, ) -> torch.Tensor: _, nhead_q, hdim_q = q.shape @@ -2332,7 +2346,8 @@ def can_impl_fmha_v3_bwd_gfx950(): None, cu_seqlens_q_padded, cu_seqlens_k_padded, - sink_ptr=sink_ptr, + sink=sink, + d_sink=d_sink, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -2487,7 +2502,8 @@ def backward(ctx, dout, *args): how_v3_bf16_cvt=ctx.how_v3_bf16_cvt, cu_seqlens_q_padded=ctx.cu_seqlens_q_padded, cu_seqlens_k_padded=ctx.cu_seqlens_k_padded, - sink_ptr=None, + sink=None, + d_sink=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] @@ -2508,7 +2524,10 @@ def backward(ctx, dout, *args): # out, # is_grad_enabled, # cu_seqlens_q_padded, cu_seqlens_k_padded, - # is_v3_atomic_fp32, how_v3_bf16_cvt + # is_v3_atomic_fp32, how_v3_bf16_cvt, + # sink_ptr (fwd-only sink scores; not differentiable via autograd. + # bwd sink gradient d_sink is computed inside mha_varlen_bwd kernel, + # not returned here as a positional gradient.) # We only have gradients for q,k,v (dq,dk,dv) and possibly bias (dbias). Others are None. 
return ( dq, # q @@ -2536,7 +2555,7 @@ def backward(ctx, dout, *args): None, # cu_seqlens_k_padded None, # is_v3_atomic_fp32 None, # how_v3_bf16_cvt - None, # sink_ptr + None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately) ) diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 5a4d261eba..105fb35401 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* dv_ptr */ a.dv_ptr, /* dbias_ptr */ a.dbias_ptr, /* dq_acc_ptr */ a.dq_acc_ptr, + /* sink_ptr */ a.sink_ptr, + /* d_sink_ptr */ a.d_sink_ptr, /* seqstart_q_ptr */ a.seqstart_q_ptr, /* seqstart_k_ptr */ a.seqstart_k_ptr, @@ -208,7 +210,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* nhead_stride_randval*/ a.nhead_stride_randval, /* nhead_stride_do */ a.nhead_stride_do, /* nhead_stride_lsed */ a.nhead_stride_lsed, - /* nhead_stride_dq_acc*/ a.nhead_stride_dq_acc, + /* nhead_stride_dq_acc*/ static_cast(a.nhead_stride_dq_acc), /* nhead_stride_dq */ a.nhead_stride_dq, /* nhead_stride_dk */ a.nhead_stride_dk, /* nhead_stride_dv */ a.nhead_stride_dv, @@ -222,7 +224,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* batch_stride_randval*/ a.batch_stride_randval, /* batch_stride_do */ a.batch_stride_do, /* batch_stride_lsed */ a.batch_stride_lsed, - /* batch_stride_dq_acc*/ a.batch_stride_dq_acc, + /* batch_stride_dq_acc*/ static_cast(a.batch_stride_dq_acc), /* batch_stride_dq */ a.batch_stride_dq, /* batch_stride_dk */ a.batch_stride_dk, /* batch_stride_dv */ a.batch_stride_dv, diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index f8460678b1..0afaeadd22 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -47,6 +47,8 @@ struct mha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink + void* 
d_sink_ptr = nullptr; // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad // Usage notes for sequence length pointer parameters: // // [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index c2d4aa000a..37ca2680a3 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -818,7 +818,9 @@ namespace py = pybind11; py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); + py::arg("gen") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MHA_FWD_ASM_PYBIND \ m.def("fmha_v3_fwd", \ @@ -950,7 +952,9 @@ namespace py = pybind11; py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ - py::arg("cu_seqlens_k_padded") = std::nullopt); + py::arg("cu_seqlens_k_padded") = std::nullopt, \ + py::arg("sink") = std::nullopt, \ + py::arg("d_sink") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ m.def("ck_moe_stage1", \ diff --git a/csrc/include/torch/mha_bwd.h b/csrc/include/torch/mha_bwd.h index 5b1ea2c098..ce8fcf8ca1 100644 --- a/csrc/include/torch/mha_bwd.h +++ b/csrc/include/torch/mha_bwd.h @@ -24,6 +24,8 @@ std::vector mha_bwd(const at::Tensor& dout, // [b, sq, hq, d] std::optional bias_, // [sq, sk] std::optional alibi_slopes, // [hq] or [b, hq] std::optional rng_state, - std::optional gen); + std::optional gen, + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink); // [hq] sink gradient output (float) } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_varlen_bwd.h b/csrc/include/torch/mha_varlen_bwd.h index ac78ec2fb3..e5ba6f0754 100644 --- a/csrc/include/torch/mha_varlen_bwd.h +++ b/csrc/include/torch/mha_varlen_bwd.h @@ -30,7 +30,9 @@ mha_varlen_bwd(const 
at::Tensor& dout, // [total_q, hq, d] std::optional rng_state, std::optional gen, std::optional cu_seqlens_q_padded, // [b+1] - std::optional cu_seqlens_k_padded // [b+1] + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink, // [b, hq] log-space sink scores (float) + std::optional d_sink // [hq] sink gradient output (float) ); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index 425817f2c2..2615ed667f 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] std::optional bias_, // [sq, sk] std::optional alibi_slopes_, // [hq] or [b, hq] std::optional rng_state_, - std::optional gen_) + std::optional gen_, + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -131,37 +133,19 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] alibi_slopes_.has_value() ? 
bias_type = bias_enum::alibi : bias_enum::no_bias; bool has_dbias = dbias_.has_value(); auto opts = q.options(); - const fmha_bwd_traits traits{ - seqlen_q, - seqlen_k, - batch_size, - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k - head_size_q, - head_size_v, - num_heads, - num_heads_k, - q_dtype_str, - false, // is_group_mode - mask.type, - bias_type, - has_dbias, - p_dropout > 0, - false, // is_store_randval - deterministic, - }; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto stream = at::hip::getCurrentHIPStream(); auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - const fmha_bwd_launcher launcher(traits); - const ck_tile::index_t nsplits = launcher.dq_acc_splits; - at::Tensor dq_accum; - if (launcher.needs_zero_dq_acc) - dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); - else - dq_accum = torch::empty({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); + // nsplits: deterministic mode splits dK into ceil(seqlen_k/16) pieces for atomic-free accumulation. + constexpr ck_tile::index_t kN0 = 16; + const ck_tile::index_t nsplits = deterministic + ? ck_tile::integer_divide_ceil(seqlen_k, kN0) + : 1; + // Always zero dq_accum: the dq_dk_dv kernel writes via atomicAdd regardless of + // deterministic mode, so an uninitialized accumulator would corrupt dQ. + at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA @@ -198,6 +182,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] hipLaunchKernelGGL( aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, stream, philox_args, reinterpret_cast(rng_state.data_ptr())); + } else { + // No dropout: allocate a dummy tensor so data_ptr() is always valid. 
+ rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); } if (seqlen_q > 0) { @@ -299,6 +286,29 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] nhead_stride_dbias = dbias.stride(2); } + void* sink_data_ptr = nullptr; + void* d_sink_data_ptr = nullptr; + if (sink_.has_value() && sink_.value().defined()) { + const auto& sink = sink_.value(); + CHECK_DEVICE(sink); + TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must be float32"); + TORCH_CHECK(sink.is_contiguous(), "sink must be contiguous"); + TORCH_CHECK(sink.dim() == 2 && sink.size(0) == batch_size && sink.size(1) == num_heads, + "sink must have shape [batch_size, num_heads]"); + sink_data_ptr = sink.data_ptr(); + } + if (d_sink_.has_value() && d_sink_.value().defined()) { + TORCH_CHECK(sink_data_ptr != nullptr, + "d_sink requires sink to also be provided"); + const auto& d_sink = d_sink_.value(); + CHECK_DEVICE(d_sink); + TORCH_CHECK(d_sink.dtype() == torch::kFloat32, "d_sink must be float32"); + TORCH_CHECK(d_sink.is_contiguous(), "d_sink must be contiguous"); + TORCH_CHECK(d_sink.dim() == 1 && d_sink.size(0) == num_heads, + "d_sink must have shape [num_heads]"); + d_sink_data_ptr = d_sink.data_ptr(); + } + return mha_bwd_args{false, // use_v3 false, // is_v3_atomic_fp32 false, // how_v3_bf16_cvt @@ -329,6 +339,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv_expanded.data_ptr(), dbias_ptr, dq_accum.data_ptr(), + sink_data_ptr, // sink_ptr [b, hq] + d_sink_data_ptr, // d_sink_ptr [hq] nullptr, // seqstart_q_ptr nullptr, // seqstart_k_ptr nullptr, // seqlen_q_ptr diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index 19175dec56..90f7b57b43 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -36,8 +36,9 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] std::optional rng_state_, std::optional gen_, std::optional cu_seqlens_q_padded, // [b+1] - std::optional 
cu_seqlens_k_padded // [b+1] - ) + std::optional cu_seqlens_k_padded, // [b+1] + std::optional sink_, // [b, hq] log-space sink scores (float) + std::optional d_sink_) // [hq] sink gradient output (float) { if (is_causal) { window_size_right = 0; } @@ -149,37 +150,19 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] bias_enum bias_type = alibi_slopes_.has_value() ? bias_enum::alibi : bias_enum::no_bias; auto opts = q.options(); - const fmha_bwd_traits traits{ - total_q, - total_k, - batch_size, - max_seqlen_q, - max_seqlen_k, - head_size_q, - head_size_v, - num_heads, - num_heads_k, - q_dtype_str, - true, // is_group_mode - mask.type, - bias_type, - false, // has_dbias - p_dropout > 0, - false, // is_store_randval - deterministic, - }; - const fmha_bwd_launcher launcher(traits); - const ck_tile::index_t nsplits = launcher.dq_acc_splits; + // nsplits: deterministic mode splits dK into ceil(max_seqlen_k/16) pieces for atomic-free accumulation. + constexpr ck_tile::index_t kN0 = 16; + const ck_tile::index_t nsplits = deterministic + ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) + : 1; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto stream = at::hip::getCurrentHIPStream(); auto softmax_d = torch::empty({batch_size, num_heads, total_q}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - if (launcher.needs_zero_dq_acc) - dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); - else - dq_accum = torch::empty({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); + // Always zero dq_accum: the dq_dk_dv kernel writes via atomicAdd regardless of + // deterministic mode, so an uninitialized accumulator would corrupt dQ. 
+ at::Tensor dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size_q}, opts.dtype(at::kFloat)); at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA @@ -311,6 +294,29 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] seqstart_q_ptr = cu_seqlens_q.data_ptr(); } + void* sink_data_ptr = nullptr; + void* d_sink_data_ptr = nullptr; + if (sink_.has_value() && sink_.value().defined()) { + const auto& sink = sink_.value(); + CHECK_DEVICE(sink); + TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must be float32"); + TORCH_CHECK(sink.is_contiguous(), "sink must be contiguous"); + TORCH_CHECK(sink.dim() == 2 && sink.size(0) == batch_size && sink.size(1) == num_heads, + "sink must have shape [batch_size, num_heads]"); + sink_data_ptr = sink.data_ptr(); + } + if (d_sink_.has_value() && d_sink_.value().defined()) { + TORCH_CHECK(sink_data_ptr != nullptr, + "d_sink requires sink to also be provided"); + const auto& d_sink = d_sink_.value(); + CHECK_DEVICE(d_sink); + TORCH_CHECK(d_sink.dtype() == torch::kFloat32, "d_sink must be float32"); + TORCH_CHECK(d_sink.is_contiguous(), "d_sink must be contiguous"); + TORCH_CHECK(d_sink.dim() == 1 && d_sink.size(0) == num_heads, + "d_sink must have shape [num_heads]"); + d_sink_data_ptr = d_sink.data_ptr(); + } + return mha_bwd_args{false, // use_v3 false, // is_v3_atomic_fp32 false, // how_v3_bf16_cvt @@ -341,6 +347,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv_expanded.data_ptr(), nullptr, // dbias dq_accum.data_ptr(), // dq_acc + sink_data_ptr, // sink_ptr [b, hq] + d_sink_data_ptr, // d_sink_ptr [hq] seqstart_q_ptr, // seqstart_q_ptr (physical cumulative) seqstart_k_ptr, // seqstart_k_ptr (physical cumulative) nullptr, // seqlen_q_ptr (per-sequence logical) diff --git a/csrc/py_itfs_cu/asm_mha_bwd.cu b/csrc/py_itfs_cu/asm_mha_bwd.cu index 52d898ad67..415640c694 100644 --- a/csrc/py_itfs_cu/asm_mha_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_bwd.cu @@ -277,6 +277,8 @@ 
std::vector<at::Tensor> fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h
a/op_tests/test_mha.py b/op_tests/test_mha.py index 7c15eff72b..47b5342cfe 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -860,3 +860,225 @@ def test_flash_attn_seq_padding( df = pd.DataFrame(collected) aiter.logger.info(f"mha summary:\n{df}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_bwd with sink / d_sink) +# +# Reference formula (derived from kernel block_fmha_bwd_dot_do_o.hpp): +# D[b, h, q] = sum_j(dout[b, q, h, j] * out[b, q, h, j]) * p_undrop +# P_sink[b, h, q] = exp(sink[b, h] - lse_fwd[b, h, q]) +# d_sink[h] = sum_{b, q} (-P_sink[b, h, q] * D[b, h, q]) +# --------------------------------------------------------------------------- + + +def _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device +): + """Return (q, k, v, dout) in BSHD layout, requires_grad=True.""" + q = torch.randn( + batch, seqlen_q, nhead, hdim, device=device, dtype=dtype + ).requires_grad_(True) + k = torch.randn( + batch, seqlen_k, nhead_k, hdim, device=device, dtype=dtype + ).requires_grad_(True) + v = torch.randn( + batch, seqlen_k, nhead_k, hdim_v, device=device, dtype=dtype + ).requires_grad_(True) + dout = torch.randn(batch, seqlen_q, nhead, hdim_v, device=device, dtype=dtype) + return q, k, v, dout + + +def _sink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def _sink_reference_d_sink(dout, out, lse, sink, p_undrop=1.0): + """ + Pure-PyTorch reference for d_sink. 
+ + dout : [B, Sq, H, Dv] + out : [B, Sq, H, Dv] + lse : [B, H, Sq] (forward LSE without sink) + sink : [B, H] + returns d_sink : [H] + """ + D_bsh = (dout.float() * out.float()).sum(dim=-1) * p_undrop # [B, Sq, H] + D_bhs = D_bsh.permute(0, 2, 1) # [B, H, Sq] + sink_bhs = sink.unsqueeze(-1) # [B, H, 1] + p_sink = torch.exp(sink_bhs.float() - lse.float()) # [B, H, Sq] + d_sink = (-p_sink * D_bhs).sum(dim=(0, 2)) # [H] + return d_sink.float() + + +_SINK_DTYPES = [dtypes.fp16, dtypes.bf16] +_SINK_CAUSALS = [False, True] +_SINK_CONFIGS = [ + # (batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim) + (2, 128, 128, 4, 4, 64), + (1, 64, 64, 6, 2, 128), +] + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_sink_dsink( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): + """Verify that mha_bwd correctly accumulates d_sink.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, softmax_d = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert d_sink.abs().max() > 0, "d_sink was not updated by mha_bwd" + + d_sink_ref = _sink_reference_d_sink(dout, out, lse, sink) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg=f"d_sink mismatch for dtype={dtype}, 
causal={causal}, B={batch}, Sq={seqlen_q}, H={nhead}", + ) + + +@pytest.mark.parametrize("causal", _SINK_CAUSALS) +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +@pytest.mark.parametrize("batch,seqlen_q,seqlen_k,nhead,nhead_k,hdim", _SINK_CONFIGS) +def test_mha_bwd_with_sink_dq_dk_dv( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, dtype, causal +): + """Verify that passing sink/d_sink does not corrupt the dQ, dK, dV outputs.""" + device = torch.device("cuda") + hdim_v = hdim + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen_q, seqlen_k, nhead, nhead_k, hdim, hdim_v, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, causal) + + common_bwd_args = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + deterministic=False, + ) + + dq_base, dk_base, dv_base, _ = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common_bwd_args + ) + + sink_small = torch.full((batch, nhead), -1000.0, device=device, dtype=torch.float32) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq_sink, dk_sink, dv_sink, _ = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common_bwd_args, + sink=sink_small, + d_sink=d_sink, + ) + + rtol, atol = (0.01, 0.01) if dtype == dtypes.fp16 else (0.02, 0.02) + torch.testing.assert_close( + dq_sink, dq_base, rtol=rtol, atol=atol, msg="dQ mismatch with small sink" + ) + torch.testing.assert_close( + dk_sink, dk_base, rtol=rtol, atol=atol, msg="dK mismatch with small sink" + ) + torch.testing.assert_close( + dv_sink, dv_base, rtol=rtol, atol=atol, msg="dV mismatch with small sink" + ) + + +@pytest.mark.parametrize("dtype", _SINK_DTYPES) +def test_mha_bwd_sink_null_gives_same_as_no_sink(dtype): + """Passing sink=None must give identical output to omitting sink entirely.""" + device = torch.device("cuda") + batch, seqlen, 
nhead, hdim = 2, 64, 4, 64 + softmax_scale = hdim**-0.5 + + q, k, v, dout = _sink_make_qkvo( + batch, seqlen, seqlen, nhead, nhead, hdim, hdim, dtype, device + ) + out, lse = _sink_run_fwd(q.detach(), k.detach(), v.detach(), softmax_scale, False) + + common = dict( + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + ) + + dq1, dk1, dv1, d1 = aiter.mha_bwd( + dout, q.detach(), k.detach(), v.detach(), out, lse, **common + ) + dq2, dk2, dv2, d2 = aiter.mha_bwd( + dout, + q.detach(), + k.detach(), + v.detach(), + out, + lse, + **common, + sink=None, + d_sink=None, + ) + + torch.testing.assert_close(dq1, dq2, msg="dQ differs with sink=None vs omitted") + torch.testing.assert_close(dk1, dk2, msg="dK differs with sink=None vs omitted") + torch.testing.assert_close(dv1, dv2, msg="dV differs with sink=None vs omitted") + torch.testing.assert_close( + d1, d2, msg="softmax_d differs with sink=None vs omitted" + ) diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 5027753822..a523580184 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -1101,3 +1101,217 @@ def varlen_flash_attn_seq_padding_benchmark( df_padding = pd.DataFrame(padding_collected) aiter.logger.info(f"mha_varlen_seq_padding summary:\n{df_padding}") + + +# --------------------------------------------------------------------------- +# Sink backward tests (mha_varlen_bwd with sink / d_sink) +# --------------------------------------------------------------------------- + + +def _vsink_run_fwd(q, k, v, softmax_scale, causal): + """Run mha_fwd and return (out, lse).""" + out, lse, _, _ = aiter.mha_fwd( + q, + k, + v, + dropout_p=0.0, + softmax_scale=softmax_scale, + is_causal=causal, + window_size_left=-1, + window_size_right=0 if causal else -1, + sink_size=0, + return_softmax_lse=True, + return_dropout_randval=False, + ) + return out, lse + + +def 
_vsink_reference_d_sink_varlen(dout, out, lse_group, sink, seqlens_q): + """ + Reference d_sink for varlen mode. + + dout : [total_q, H, Dv] + out : [total_q, H, Dv] + lse_group : [H, total_q] – group-mode LSE (flattened across batches) + sink : [B, H] + seqlens_q : list of per-batch sequence lengths + returns d_sink : [H] + """ + nhead = sink.shape[1] + d_sink = torch.zeros(nhead, device=sink.device, dtype=torch.float32) + + offset = 0 + for b, sq in enumerate(seqlens_q): + dout_b = dout[offset : offset + sq].float() + out_b = out[offset : offset + sq].float() + lse_b = lse_group[:, offset : offset + sq] + + D_qh = (dout_b * out_b).sum(dim=-1) + D_hq = D_qh.permute(1, 0) + p_sink = torch.exp(sink[b].float().unsqueeze(-1) - lse_b) + d_sink += (-p_sink * D_hq).sum(dim=-1) + offset += sq + + return d_sink + + +_VSINK_DTYPES = [dtypes.fp16, dtypes.bf16] + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_dsink(dtype): + """Numerical correctness test: mha_varlen_bwd with sink/d_sink (equal-length sequences).""" + device = torch.device("cuda") + batch, seqlen, nhead, hdim = 2, 64, 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + seqlens_q = [seqlen] * batch + + cu_seqlens_q = torch.tensor( + [0, seqlen, seqlen * 2], device=device, dtype=torch.int32 + ) + cu_seqlens_k = cu_seqlens_q.clone() + total_q = seqlen * batch + total_k = seqlen * batch + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, device=device, dtype=dtype) + + q_b = q.view(batch, seqlen, nhead, hdim) + k_b = k.view(batch, seqlen, nhead, hdim) + v_b = v.view(batch, seqlen, nhead, hdim_v) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + + out = out_b.view(total_q, nhead, hdim_v) + lse = lse_b.permute(1, 0, 2).reshape(nhead, total_q).contiguous() 
+ + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink contains non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen d_sink mismatch vs reference", + ) + + +@pytest.mark.parametrize("dtype", _VSINK_DTYPES) +def test_mha_varlen_bwd_sink_variable_lengths(dtype): + """Varlen sink test with variable-length sequences per batch entry.""" + device = torch.device("cuda") + nhead, hdim = 4, 64 + hdim_v = hdim + softmax_scale = hdim**-0.5 + + seqlens_q = [48, 80] + seqlens_k = [48, 80] + batch = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_sq = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_sk = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_k), 0).tolist()), + device=device, + dtype=torch.int32, + ) + + q = torch.randn(total_q, nhead, hdim, device=device, dtype=dtype) + k = torch.randn(total_k, nhead, hdim, device=device, dtype=dtype) + v = torch.randn(total_k, nhead, hdim_v, device=device, dtype=dtype) + dout = torch.randn(total_q, nhead, hdim_v, 
device=device, dtype=dtype) + + out_parts, lse_parts = [], [] + offset_q, offset_k = 0, 0 + for sq, sk in zip(seqlens_q, seqlens_k): + q_b = q[offset_q : offset_q + sq].unsqueeze(0) + k_b = k[offset_k : offset_k + sk].unsqueeze(0) + v_b = v[offset_k : offset_k + sk].unsqueeze(0) + out_b, lse_b = _vsink_run_fwd(q_b, k_b, v_b, softmax_scale, causal=False) + out_parts.append(out_b.squeeze(0)) + lse_parts.append(lse_b.squeeze(0).permute(1, 0)) + offset_q += sq + offset_k += sk + + out = torch.cat(out_parts, dim=0) + lse = torch.cat(lse_parts, dim=0).permute(1, 0).contiguous() + + sink = torch.empty(batch, nhead, device=device, dtype=torch.float32).uniform_( + 30.0, 60.0 + ) + d_sink = torch.zeros(nhead, device=device, dtype=torch.float32) + + dq, dk, dv, _ = aiter.mha_varlen_bwd( + dout, + q, + k, + v, + out, + lse, + cu_seqlens_q=cu_sq, + cu_seqlens_k=cu_sk, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + zero_tensors=False, + is_causal=False, + window_size_left=-1, + window_size_right=-1, + deterministic=False, + sink=sink, + d_sink=d_sink, + ) + + assert torch.isfinite(d_sink).all(), f"d_sink has non-finite values: {d_sink}" + assert d_sink.abs().max() > 0, "mha_varlen_bwd did not update d_sink" + + d_sink_ref = _vsink_reference_d_sink_varlen(dout, out, lse, sink, seqlens_q) + torch.testing.assert_close( + d_sink, + d_sink_ref, + rtol=0.02, + atol=0.5, + msg="varlen variable-length d_sink mismatch", + )