2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 105 files
55 changes: 37 additions & 18 deletions aiter/ops/mha.py
@@ -624,7 +624,8 @@ def cmdGenFunc_mha_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
):
md_name = "mha_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -775,7 +776,8 @@ def gen_mha_bwd_fake_tensors(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv)

@@ -807,7 +809,8 @@ def mha_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
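For context, the rename means every bwd entry point now accepts an optional `sink`/`d_sink` pair instead of the single `sink_ptr`. A minimal allocation sketch (shapes follow the header comments in `csrc/include/torch/mha_bwd.h` further down; sizes are illustrative):

```python
import torch

b, hq = 4, 16  # illustrative batch size and query-head count
# sink: per-(batch, head) log-space sink scores read by the bwd kernel
sink = torch.zeros(b, hq, dtype=torch.float32, device="cuda")
# d_sink: per-head output buffer the kernel fills with the sink gradient
d_sink = torch.zeros(hq, dtype=torch.float32, device="cuda")
```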


@@ -889,7 +892,8 @@ def cmdGenFunc_mha_varlen_bwd(
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> dict[str, Any]:
md_name = "mha_varlen_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -1117,7 +1121,8 @@ def gen_mha_varlen_bwd_fake_tensors(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return gen_mha_varlen_bwd_fake_tensors_common(
q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv
@@ -1157,7 +1162,8 @@ def mha_varlen_bwd(
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -1567,7 +1573,8 @@ def _flash_attn_backward_fake(
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> torch.Tensor:
batch_size = q.size(0)
seqlen_q = q.size(1)
@@ -1606,7 +1613,8 @@ def _flash_attn_backward(
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> torch.Tensor:
# rtna & rtz are deprecated in gfx950
if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0:
@@ -1723,7 +1731,8 @@ def can_impl_fmha_v3_bwd_gfx950():
alibi_slopes,
rng_state,
None,
sink_ptr,
sink,
d_sink,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
@@ -1782,7 +1791,7 @@ def forward(
how_v3_bf16_cvt=how_v3_bf16_cvt,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
sink_ptr=sink_ptr,
sink_ptr=sink_ptr, # fwd kernel still uses sink_ptr naming
)
if is_grad:
assert return_lse
@@ -1840,7 +1849,8 @@ def backward(ctx, dout, *args):
rng_state,
ctx.is_v3_atomic_fp32,
ctx.how_v3_bf16_cvt,
sink_ptr=None,
sink=None,
d_sink=None,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
@@ -1863,7 +1873,10 @@ def backward(ctx, dout, *args):
# 15 how_v3_bf16_cvt
# 16 cu_seqlens_q
# 17 cu_seqlens_kv
# Need to return exactly 17 gradient entries.
# 18 sink_ptr (fwd-only sink scores; not differentiated through autograd.
#    The bwd kernel computes the sink gradient d_sink internally, so no
#    positional gradient is returned for it here.)
# Need to return exactly 18 gradient entries.
return (
dq, # q
dk, # k
@@ -1882,7 +1895,7 @@ def backward(ctx, dout, *args):
None, # how_v3_bf16_cvt
None, # cu_seqlens_q
None, # cu_seqlens_kv
None, # sink_ptr
None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately)
)
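To make the 18-entry bookkeeping concrete, here is a toy `autograd.Function` illustrating the same contract: `backward` returns exactly one gradient per `forward` input, in order, with `None` for non-differentiable inputs such as `sink_ptr`:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # two forward inputs ...
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):  # ... so exactly two entries come back
        return grad_out * ctx.factor, None  # None: factor is not differentiable
```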


@@ -2170,7 +2183,8 @@ def _flash_attn_varlen_backward(
zero_tensors: bool = False,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
sink: Optional[Tensor] = None,
d_sink: Optional[Tensor] = None,
) -> torch.Tensor:

_, nhead_q, hdim_q = q.shape
@@ -2332,7 +2346,8 @@ def can_impl_fmha_v3_bwd_gfx950():
None,
cu_seqlens_q_padded,
cu_seqlens_k_padded,
sink_ptr=sink_ptr,
sink=sink,
d_sink=d_sink,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
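As a reminder for the varlen path: sequences are addressed through cumulative-offset tensors rather than a batch dimension. A small sketch of the `cu_seqlens` encoding these wrappers assume:

```python
import torch

# Ragged batch with lengths (3, 5, 2) packed into one [total_tokens, nhead, hdim]
# tensor; cu_seqlens holds the running start offsets of each sequence.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)  # tensor([0, 3, 8, 10])
```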
@@ -2487,7 +2502,8 @@ def backward(ctx, dout, *args):
how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
cu_seqlens_q_padded=ctx.cu_seqlens_q_padded,
cu_seqlens_k_padded=ctx.cu_seqlens_k_padded,
sink_ptr=None,
sink=None,
d_sink=None,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
@@ -2508,7 +2524,10 @@ def backward(ctx, dout, *args):
# out,
# is_grad_enabled,
# cu_seqlens_q_padded, cu_seqlens_k_padded,
# is_v3_atomic_fp32, how_v3_bf16_cvt
# is_v3_atomic_fp32, how_v3_bf16_cvt,
# sink_ptr (fwd-only sink scores; not differentiated through autograd.
#    The bwd kernel computes the sink gradient d_sink internally, so no
#    positional gradient is returned for it here.)
# We only have gradients for q,k,v (dq,dk,dv) and possibly bias (dbias). Others are None.
return (
dq, # q
Expand Down Expand Up @@ -2536,7 +2555,7 @@ def backward(ctx, dout, *args):
None, # cu_seqlens_k_padded
None, # is_v3_atomic_fp32
None, # how_v3_bf16_cvt
None, # sink_ptr
None, # sink_ptr (not differentiable; bwd uses sink/d_sink args separately)
)


6 changes: 4 additions & 2 deletions csrc/cpp_itfs/mha_bwd.cu
@@ -168,6 +168,8 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
/* dv_ptr */ a.dv_ptr,
/* dbias_ptr */ a.dbias_ptr,
/* dq_acc_ptr */ a.dq_acc_ptr,
/* sink_ptr */ a.sink_ptr,
/* d_sink_ptr */ a.d_sink_ptr,

/* seqstart_q_ptr */ a.seqstart_q_ptr,
/* seqstart_k_ptr */ a.seqstart_k_ptr,
@@ -208,7 +210,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
/* nhead_stride_randval*/ a.nhead_stride_randval,
/* nhead_stride_do */ a.nhead_stride_do,
/* nhead_stride_lsed */ a.nhead_stride_lsed,
/* nhead_stride_dq_acc*/ a.nhead_stride_dq_acc,
/* nhead_stride_dq_acc*/ static_cast<ck_tile::long_index_t>(a.nhead_stride_dq_acc),
/* nhead_stride_dq */ a.nhead_stride_dq,
/* nhead_stride_dk */ a.nhead_stride_dk,
/* nhead_stride_dv */ a.nhead_stride_dv,
@@ -222,7 +224,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
/* batch_stride_randval*/ a.batch_stride_randval,
/* batch_stride_do */ a.batch_stride_do,
/* batch_stride_lsed */ a.batch_stride_lsed,
/* batch_stride_dq_acc*/ a.batch_stride_dq_acc,
/* batch_stride_dq_acc*/ static_cast<ck_tile::long_index_t>(a.batch_stride_dq_acc),
/* batch_stride_dq */ a.batch_stride_dq,
/* batch_stride_dk */ a.batch_stride_dk,
/* batch_stride_dv */ a.batch_stride_dv,
2 changes: 2 additions & 0 deletions csrc/include/mha_bwd.h
@@ -47,6 +47,8 @@ struct mha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* sink_ptr = nullptr; // sink scores [batch, nhead] log-space (LSEDataType=float); nullptr disables sink
void* d_sink_ptr = nullptr; // sink gradient accumulator [nhead] (LSEDataType=float); nullptr disables sink grad
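For reviewers unfamiliar with sink attention, a hedged single-head PyTorch reference (not the kernel; it assumes the standard attention-sink formulation) showing why the sink score lives in log space next to the LSE: it joins the softmax denominator but contributes no value row:

```python
import torch

def attn_with_sink(q, k, v, sink_logit):
    # q: [sq, d], k: [sk, d], v: [sk, dv], sink_logit: 0-dim log-space score
    scores = (q @ k.T) / q.size(-1) ** 0.5                        # [sq, sk]
    aug = torch.cat([scores, sink_logit.expand(scores.size(0), 1)], dim=-1)
    lse = torch.logsumexp(aug, dim=-1)         # sink enlarges the denominator
    probs = torch.exp(scores - lse[:, None])   # rows now sum to less than 1
    return probs @ v, lse
```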
// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
8 changes: 6 additions & 2 deletions csrc/include/rocm_ops.hpp
@@ -818,7 +818,9 @@ namespace py = pybind11;
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);
py::arg("gen") = std::nullopt, \
py::arg("sink") = std::nullopt, \
py::arg("d_sink") = std::nullopt);

#define MHA_FWD_ASM_PYBIND \
m.def("fmha_v3_fwd", \
@@ -950,7 +952,9 @@ namespace py = pybind11;
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt, \
py::arg("cu_seqlens_q_padded") = std::nullopt, \
py::arg("cu_seqlens_k_padded") = std::nullopt);
py::arg("cu_seqlens_k_padded") = std::nullopt, \
py::arg("sink") = std::nullopt, \
py::arg("d_sink") = std::nullopt);

#define MOE_CK_2STAGES_PYBIND \
m.def("ck_moe_stage1", \
4 changes: 3 additions & 1 deletion csrc/include/torch/mha_bwd.h
@@ -24,6 +24,8 @@ std::vector<at::Tensor> mha_bwd(const at::Tensor& dout, // [b, sq, hq, d]
std::optional<const at::Tensor> bias_, // [sq, sk]
std::optional<const at::Tensor> alibi_slopes, // [hq] or [b, hq]
std::optional<const at::Tensor> rng_state,
std::optional<at::Generator> gen);
std::optional<at::Generator> gen,
std::optional<const at::Tensor> sink, // [b, hq] log-space sink scores (float)
std::optional<at::Tensor> d_sink); // [hq] sink gradient output (float)
} // namespace torch_itfs
} // namespace aiter
4 changes: 3 additions & 1 deletion csrc/include/torch/mha_varlen_bwd.h
@@ -30,7 +30,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d]
std::optional<const at::Tensor> rng_state,
std::optional<at::Generator> gen,
std::optional<const at::Tensor> cu_seqlens_q_padded, // [b+1]
std::optional<const at::Tensor> cu_seqlens_k_padded // [b+1]
std::optional<const at::Tensor> cu_seqlens_k_padded, // [b+1]
std::optional<const at::Tensor> sink, // [b, hq] log-space sink scores (float)
std::optional<at::Tensor> d_sink // [hq] sink gradient output (float)
);
} // namespace torch_itfs
} // namespace aiter
66 changes: 39 additions & 27 deletions csrc/py_itfs_ck/mha_bwd_kernels.cu
@@ -31,7 +31,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
std::optional<const at::Tensor> bias_, // [sq, sk]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
std::optional<const at::Tensor> rng_state_,
std::optional<at::Generator> gen_)
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> sink_, // [b, hq] log-space sink scores (float)
std::optional<at::Tensor> d_sink_) // [hq] sink gradient output (float)
{
if (is_causal) { window_size_right = 0; }

@@ -131,37 +133,19 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
alibi_slopes_.has_value() ? bias_type = bias_enum::alibi : bias_enum::no_bias;
bool has_dbias = dbias_.has_value();
auto opts = q.options();
const fmha_bwd_traits traits{
seqlen_q,
seqlen_k,
batch_size,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
head_size_q,
head_size_v,
num_heads,
num_heads_k,
q_dtype_str,
false, // is_group_mode
mask.type,
bias_type,
has_dbias,
p_dropout > 0,
false, // is_store_randval
deterministic,
};

const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()};
auto stream = at::hip::getCurrentHIPStream();

auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
const fmha_bwd_launcher launcher(traits);
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
at::Tensor dq_accum;
if (launcher.needs_zero_dq_acc)
dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat));
else
dq_accum = torch::empty({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat));
// nsplits: in deterministic mode the dQ accumulator gets one slice per 16-wide
// seqlen_k tile (ceil(seqlen_k/16) splits) so partial dQ sums reduce in a fixed order.
constexpr ck_tile::index_t kN0 = 16;
const ck_tile::index_t nsplits = deterministic
? ck_tile::integer_divide_ceil(seqlen_k, kN0)
: 1;
// Always zero dq_accum: the dq_dk_dv kernel writes via atomicAdd regardless of
// deterministic mode, so an uninitialized accumulator would corrupt dQ.
at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size_q}, opts.dtype(at::kFloat));
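A quick arithmetic check of the split count, mirroring the `integer_divide_ceil` call above:

```python
import math

kN0 = 16                        # k-tile width assumed by the bwd pipeline
seqlen_k, deterministic = 1000, True
nsplits = math.ceil(seqlen_k / kN0) if deterministic else 1
assert nsplits == 63            # 1000 / 16 = 62.5 -> 63 dQ slices
```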

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -198,6 +182,9 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
hipLaunchKernelGGL(
aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, stream,
philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
} else {
// No dropout: allocate a dummy tensor so data_ptr() is always valid.
rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
}

if (seqlen_q > 0) {
@@ -299,6 +286,29 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
nhead_stride_dbias = dbias.stride(2);
}

void* sink_data_ptr = nullptr;
void* d_sink_data_ptr = nullptr;
if (sink_.has_value() && sink_.value().defined()) {
const auto& sink = sink_.value();
CHECK_DEVICE(sink);
TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must be float32");
TORCH_CHECK(sink.is_contiguous(), "sink must be contiguous");
TORCH_CHECK(sink.dim() == 2 && sink.size(0) == batch_size && sink.size(1) == num_heads,
"sink must have shape [batch_size, num_heads]");
sink_data_ptr = sink.data_ptr();
}
if (d_sink_.has_value() && d_sink_.value().defined()) {
TORCH_CHECK(sink_data_ptr != nullptr,
"d_sink requires sink to also be provided");
const auto& d_sink = d_sink_.value();
CHECK_DEVICE(d_sink);
TORCH_CHECK(d_sink.dtype() == torch::kFloat32, "d_sink must be float32");
TORCH_CHECK(d_sink.is_contiguous(), "d_sink must be contiguous");
TORCH_CHECK(d_sink.dim() == 1 && d_sink.size(0) == num_heads,
"d_sink must have shape [num_heads]");
d_sink_data_ptr = d_sink.data_ptr();
}
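The same contract restated as a hedged Python-side precondition helper (the function name is hypothetical) mirroring the TORCH_CHECKs above:

```python
import torch

def check_sink_args(sink, d_sink, batch_size, num_heads):
    # d_sink is only meaningful alongside sink; both must be contiguous fp32.
    if sink is not None:
        assert sink.dtype == torch.float32 and sink.is_contiguous()
        assert sink.shape == (batch_size, num_heads)
    if d_sink is not None:
        assert sink is not None, "d_sink requires sink to also be provided"
        assert d_sink.dtype == torch.float32 and d_sink.is_contiguous()
        assert d_sink.shape == (num_heads,)
```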

return mha_bwd_args{false, // use_v3
false, // is_v3_atomic_fp32
false, // how_v3_bf16_cvt
@@ -329,6 +339,8 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
dv_expanded.data_ptr(),
dbias_ptr,
dq_accum.data_ptr(),
sink_data_ptr, // sink_ptr [b, hq]
d_sink_data_ptr, // d_sink_ptr [hq]
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr