[GPU] Add initial SDPA implementation #24466
Conversation
Force-pushed from 75c4dd0 to 280cc81
@itikhono, @sshlyapn,
do we expect that https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention, which corresponds to SDPA, will be fused into the SDPA internal operation on CPU and GPU?
@rkazants I think it makes sense, but currently not all cases are well optimized, so it sounds good as long as we keep the ability to decompose it back to the current representation for those cases.
Force-pushed from c887523 to 63fd7cc
* @brief Turning on this key disables SDPA operation decomposition and keeps SDPA operation in the graph.
* Enabling SDPA optimization may provide performance improvements and memory usage reduction.
* This key serves as a recommendation and may be ignored in known sub-optimal cases.
* @ingroup ov_runtime_ocl_gpu_prop_cpp_api
what is the default value?
@ilya-lavrenov currently it's disabled by default. However, in the final version it will depend on whether support for indirect inputs is implemented for SDPA in time.
This allows switching it on for models where indirect inputs are not required.
No critical issues found so far. Feel free to follow up on the comments in the next PR.
return params;
}

static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
Does this impl differ in any way from the common version in the base class?
CeilDiv(target_seq_len, target_seq_len_block_size),
head_size * num_of_partitions };
dispatch_data.lws = { 1, 1, head_size };
} else if (kernel_idx == 2) {
Probably it should be kernel_idx == KernelsTypes::FINALIZATION
kernel_impl_params const& impl_param) {
auto desc = impl_param.typed_desc<scaled_dot_product_attention>();

return impl_param.get_input_layout(0);
I suppose this impl is incorrect, so maybe just throw an unimplemented exception here instead?
} else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) {
sdpa = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(pattern_map.at(sdpa_with_attn_mask_m).get_node_shared_ptr());
} else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) {
sdpa = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(pattern_map.at(sdpa_with_attn_mask_and_scale_m).get_node_shared_ptr());
I think the code above can be replaced with m.get_match_root()
} else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) {
auto attn_mask = sdpa->get_input_source_output(3);
auto scale = sdpa->get_input_source_output(4);
sdpa_new = std::make_shared<op::SDPA>(input_q, input_k, input_v, attn_mask, scale, order_q, order_k, order_v, order_output, sdpa->get_causal());
We could probably have an internal SDPA op with optional inputs to simplify the converters.
JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const {
auto jit = SDPAKernelBase::GetJitConstants(params);

const auto softmax_acc_dt = params.inputs[0].GetDType();
Would be great to try an FP32 accumulator. If it doesn't impact perf, then we can use it to get better accuracy in some cases.
}

KernelsPriority SDPAKernelOpt::GetKernelsPriority(const Params& /*params*/) const {
return FORCE_PRIORITY_1;
Should probably be > 1, at least for platforms with DPAS.
#elif INPUT1_DIMS == 6
return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error sdpa_ref.cl : Unsupported input 1 format
sdpa_opt
DO_BROADCAST_KEY_VALUE;
#endif
#if INPUT1_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
Should we use the _SAFE version here? As I understand it, DO_BROADCAST_KEY_VALUE divides the index by the group size internally, so can we still get out-of-bounds indices somehow? Same question for the other functions.
exp_sum[seq_idx] = sub_group_reduce_add(exp_sum[seq_idx]);
}

// const SOFTMAX_ACCUMULATOR_TYPE inv_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ONE / exp_sum[seq_idx];
Not needed, I think.