Add a new SdpaFwdOp IR node for Flash Attention (#2294)
Based on the PR discussions, this PR is repurposed to introduce a new IR
node `SdpaFwdOp` for the scaled dot product flash attention forward pass (see
#2278 for details).
This PR does not include changes to the scheduler.
Priya2698 authored Jun 10, 2024
1 parent 6a4156d commit 23ee81d
Showing 12 changed files with 605 additions and 10 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -526,6 +526,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_resize.cpp
${NVFUSER_ROOT}/tests/cpp/test_reduction_pointwise.cpp
${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp
${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp
${NVFUSER_ROOT}/tests/cpp/test_segmentation.cpp
${NVFUSER_ROOT}/tests/cpp/test_serial_gridreduce.cpp
${NVFUSER_ROOT}/tests/cpp/test_sharding.cpp
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
@@ -152,6 +152,7 @@ bool isTvOp(const Expr* expr) {
MatmulOp,
MmaOp,
LinearOp,
SdpaFwdOp,
BroadcastOp,
SqueezeOp,
ExpandOp,
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -107,6 +107,7 @@ class Val;
f(Resize); \
f(MatmulOp); \
f(LinearOp); \
f(SdpaFwdOp); \
f(Communication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
f(Allocate); \
97 changes: 97 additions & 0 deletions csrc/ir/internal_nodes.h
@@ -2201,4 +2201,101 @@ class LinearOp : public Expr {
}
};

/*
SDPA node with the same functionality as
at::_scaled_dot_product_flash_attention.

Outputs:
output = [N, H, L, Ev]
logsumexp = [N, H, L]
cum_seq_q = [N + 1]
cum_seq_k = [N + 1]
query_seq_len = scalar(int)
key_seq_len = scalar(int)
philox_seed = scalar tensor
philox_offset = scalar tensor
debug_attn_mask = scalar tensor (Thunder does not return a debug attn mask,
since it sets `return_debug_mask=False` when invoking flash attention)

Inputs:
query = [N, H, L, E]
key = [N, H, S, E]
value = [N, H, S, Ev]
dropout_p = scalar(double)
is_causal = scalar(bool)
scale = scalar(double), optional

Where:
N = number of sequences / batch size
H = number of heads
L = query sequence length / target sequence length
S = key/value sequence length / source sequence length
E = query/key embedding dimension
Ev = value embedding dimension

For flash attention, E = Ev.
*/

class SdpaFwdOp : public Expr {
public:
using Expr::Expr;

SdpaFwdOp(
IrBuilderPasskey,
TensorView* output,
TensorView* log_sumexp,
TensorView* cum_seq_q,
TensorView* cum_seq_k,
Val* query_seq_len,
Val* key_seq_len,
TensorView* philox_seed,
TensorView* philox_offset,
TensorView* debug_attn_mask,
Val* query,
Val* key,
Val* value,
Val* dropout_p,
Val* is_causal,
Val* scale);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "SdpaFwdOp";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

Val* attn_out() const {
return output(0);
}

Val* query() const {
return input(0);
}

Val* key() const {
return input(1);
}

Val* value() const {
return input(2);
}

Val* dropout_p() const {
return input(3);
}

Val* is_causal() const {
return input(4);
}

Val* scale() const {
if (inputs().size() > 5) {
return input(5);
}
return nullptr;
}

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
};

} // namespace nvfuser
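
For illustration only, here is a minimal sketch of wiring up the new node directly through IrBuilder. This is not part of the diff: the makeSymbolicTensor helper is assumed from the test utilities, the output TensorViews are placeholder stand-ins shaped as documented above, and no scheduling is involved since this PR leaves the scheduler untouched.

// Hypothetical usage sketch; not part of this commit.
Fusion fusion;
FusionGuard fg(&fusion);

// Inputs: query [N, H, L, E], key [N, H, S, E], value [N, H, S, Ev].
TensorView* query = makeSymbolicTensor(4, DataType::Half);
TensorView* key = makeSymbolicTensor(4, DataType::Half);
TensorView* value = makeSymbolicTensor(4, DataType::Half);
Val* dropout_p = IrBuilder::create<Val>(0.0, DataType::Double);
Val* is_causal = IrBuilder::create<Val>(false, DataType::Bool);

// Outputs, shaped as documented above (placeholders for the sketch).
TensorView* attn = makeSymbolicTensor(4, DataType::Half);        // [N, H, L, Ev]
TensorView* log_sumexp = makeSymbolicTensor(3, DataType::Float); // [N, H, L]
TensorView* cum_seq_q = makeSymbolicTensor(1, DataType::Int);    // [N + 1]
TensorView* cum_seq_k = makeSymbolicTensor(1, DataType::Int);    // [N + 1]
Val* q_len = IrBuilder::create<Val>(DataType::Int);
Val* k_len = IrBuilder::create<Val>(DataType::Int);
TensorView* philox_seed = makeSymbolicTensor(0, DataType::Int);
TensorView* philox_offset = makeSymbolicTensor(0, DataType::Int);
TensorView* debug_attn_mask = makeSymbolicTensor(0, DataType::Half);

// Constructor argument order matches the declaration above: outputs first, then inputs.
auto* sdpa = IrBuilder::create<SdpaFwdOp>(
    attn, log_sumexp, cum_seq_q, cum_seq_k, q_len, k_len,
    philox_seed, philox_offset, debug_attn_mask,
    query, key, value, dropout_p, is_causal, /*scale=*/nullptr);

// Accessors mirror the input order; the optional scale reads back as nullptr when omitted.
NVF_ERROR(sdpa->attn_out() == attn);
NVF_ERROR(sdpa->query() == query && sdpa->key() == key && sdpa->value() == value);
NVF_ERROR(sdpa->scale() == nullptr);

Since the scheduler is not changed by this PR, a fusion containing this node would currently be executed through ExpressionEvaluator (see the evaluate() implementation below) rather than generated kernel code.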
133 changes: 133 additions & 0 deletions csrc/ir/nodes.cpp
@@ -4260,4 +4260,137 @@ std::vector<PolymorphicValue> LinearOp::evaluate(
return {at::linear(a, b)};
}

SdpaFwdOp::SdpaFwdOp(
IrBuilderPasskey passkey,
TensorView* output,
TensorView* log_sumexp,
TensorView* cum_seq_q,
TensorView* cum_seq_k,
Val* query_seq_len,
Val* key_seq_len,
TensorView* philox_seed,
TensorView* philox_offset,
TensorView* debug_attn_mask,
Val* query,
Val* key,
Val* value,
Val* dropout_p,
Val* is_causal,
Val* scale)
: Expr(passkey) {
addOutput(output);
addOutput(log_sumexp);
addOutput(cum_seq_q);
addOutput(cum_seq_k);
addOutput(query_seq_len);
addOutput(key_seq_len);
addOutput(philox_seed);
addOutput(philox_offset);
addOutput(debug_attn_mask);

addInput(query);
addInput(key);
addInput(value);
addInput(dropout_p);
addInput(is_causal);
if (scale != nullptr) {
addInput(scale);
}
}

NVFUSER_DEFINE_CLONE_AND_CREATE(SdpaFwdOp)

std::string SdpaFwdOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << attn_out()->toString() << "\n";
indent(ss, indent_size + 1) << " = sdpa(" << query()->toString() << ",\n";
indent(ss, indent_size + 1) << " " << key()->toString() << ",\n";
indent(ss, indent_size + 1) << " " << value()->toString() << ",\n";
indent(ss, indent_size + 1)
<< " dropout_p = " << dropout_p()->toInlineString() << ",\n";
indent(ss, indent_size + 1)
<< " is_causal = " << is_causal()->toInlineString();
if (scale() != nullptr) {
indent(ss, indent_size + 1)
<< ",\n scale = " << scale()->toInlineString();
}
indent(ss, indent_size + 1) << ")\n";
return ss.str();
}

std::string SdpaFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
auto query = inputs.at(0).as<at::Tensor>();
auto key = inputs.at(1).as<at::Tensor>();
auto value = inputs.at(2).as<at::Tensor>();

const auto dropout_p = inputs.at(3).as<double>();
const auto is_causal = inputs.at(4).as<bool>();

// Flash attention requires the last dimension to be padded to a multiple of 8.
// https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677
const auto last_dim_size = query.sizes()[3];
auto pad_last_dim = [last_dim_size](
at::Tensor inp, int alignment_size) -> at::Tensor {
if (last_dim_size % alignment_size == 0) {
return inp;
}
auto pad_count = alignment_size - (last_dim_size % alignment_size);
auto padded_inp = at::pad(inp, {0, pad_count});
return padded_inp;
};

query = pad_last_dim(query, 8);
key = pad_last_dim(key, 8);
value = pad_last_dim(value, 8);

// Compute scale using the original size of the last dimension
double scale = inputs.size() > 5 ? inputs.back().as<double>()
: 1.0 / std::sqrt(last_dim_size);

// ATen reference:
// https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L680-L681
auto
[output,
log_sumexp,
cum_seq_q,
cum_seq_k,
query_seq_len,
key_seq_len,
philox_seed,
philox_offset,
debug_attn_mask] =
at::_scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p,
is_causal,
/*return_debug_mask=*/false,
scale);

// If the inputs were padded, slice the output to restore the original size
if (output.sizes()[3] != last_dim_size) {
output = output.slice(-1, 0, last_dim_size);
}

// Query and key seq lens are of type c10::SymInt; convert them to int for
// PolymorphicValue
return {
output,
log_sumexp,
cum_seq_q,
cum_seq_k,
*query_seq_len.maybe_as_int(),
*key_seq_len.maybe_as_int(),
philox_seed,
philox_offset,
debug_attn_mask};
}

} // namespace nvfuser
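
The alignment handling in evaluate() can be seen in isolation in the short standalone sketch below (not part of this commit; plain ATen with illustrative sizes): the last dimension is padded up to the next multiple of 8, and slicing afterwards restores the original extent, as evaluate() does for the attention output.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  const int64_t last_dim_size = 60;   // head dim E, not a multiple of 8
  const int64_t alignment_size = 8;
  auto t = at::randn({2, 4, 16, last_dim_size});

  // Same arithmetic as the pad_last_dim lambda above: 8 - (60 % 8) = 4.
  auto pad_count = alignment_size - (last_dim_size % alignment_size);
  auto padded = at::pad(t, {0, pad_count});
  std::cout << padded.sizes() << std::endl;   // [2, 4, 16, 64]

  // Slicing the last dimension back to its original extent, mirroring how
  // evaluate() trims the attention output when the inputs were padded.
  auto restored = padded.slice(-1, 0, last_dim_size);
  std::cout << restored.sizes() << std::endl; // [2, 4, 16, 60]
  return 0;
}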