
[GPU] Add initial SDPA implementation #24466

Merged: 4 commits merged from sdpa_impl into openvinotoolkit:master on May 22, 2024

Conversation


@sshlyapn (Contributor) commented May 10, 2024

Details:

  • Add initial SDPA implementation, input transpose fusions, and a GQA-related optimization (broadcast fusion)

@sshlyapn added the category: GPU (OpenVINO GPU plugin) label May 10, 2024
@sshlyapn added this to the 2024.2 milestone May 10, 2024
@sshlyapn requested review from a team as code owners May 10, 2024 13:18
@sshlyapn force-pushed the sdpa_impl branch 3 times, most recently from 75c4dd0 to 280cc81, on May 10, 2024 17:35
@rkazants (Contributor) left a comment:


@itikhono, @sshlyapn,
do we expect that https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention, which corresponds to SDPA, will be fused into the internal SDPA operation on CPU and GPU?

@sshlyapn (Contributor, Author):

> @itikhono, @sshlyapn, do we expect that https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention, which corresponds to SDPA, will be fused into the internal SDPA operation on CPU and GPU?

@rkazants I think it makes sense, but currently not all cases are well optimized, so it sounds good as long as there is a way to decompose it back to the current representation for those cases.

@sshlyapn force-pushed the sdpa_impl branch 2 times, most recently from c887523 to 63fd7cc, on May 21, 2024 06:12
@sshlyapn requested a review from a team as a code owner May 22, 2024 07:35
@github-actions bot added the category: inference (OpenVINO Runtime library - Inference) and category: CPP API (OpenVINO CPP API bindings) labels May 22, 2024
* @brief Turning on this key disables SDPA operation decomposition and keeps the SDPA operation in the graph.
* Enabling SDPA optimization may provide performance improvements and reduce memory usage.
* This key serves as a recommendation and may be ignored in known sub-optimal cases.
* @ingroup ov_runtime_ocl_gpu_prop_cpp_api
Contributor:

what is the default value?

Contributor (Author):

@ilya-lavrenov currently it's disabled by default. However, in the final version it will depend on whether support for indirect inputs is implemented for SDPA in time.

Contributor:

This allows switching it on for models where indirect inputs are not required.
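
A minimal usage sketch of that per-model opt-in, assuming the key from the doc comment above is exposed as ov::intel_gpu::hint::enable_sdpa_optimization (the property name is inferred here, not confirmed by this PR excerpt):

```cpp
#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_gpu/properties.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");
    // Hypothetical property name: keep SDPA fused for a model that is
    // known not to require indirect inputs.
    auto compiled = core.compile_model(model, "GPU",
                                       ov::intel_gpu::hint::enable_sdpa_optimization(true));
    return 0;
}
```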

@vladimir-paramuzov (Contributor) left a comment:

No critical issues found so far. Feel free to follow up on the comments in the next PR.

return params;
}

static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg, const kernel_impl_params& impl_param) {
Contributor:

Does this impl differ in any way from the common version in the base class?
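
If it doesn't, a sketch of the simplification being suggested, assuming the common helper in typed_primitive_impl_ocl can be reused as other GPU impls do (the impl class name here is assumed):

```cpp
static std::unique_ptr<primitive_impl> create(const typed_program_node<scaled_dot_product_attention>& arg,
                                              const kernel_impl_params& impl_param) {
    // Delegate to the shared creation path in the base class instead of duplicating it.
    return typed_primitive_impl_ocl<scaled_dot_product_attention>::create<scaled_dot_product_attention_impl>(arg, impl_param);
}
```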

CeilDiv(target_seq_len, target_seq_len_block_size),
head_size * num_of_partitions };
dispatch_data.lws = { 1, 1, head_size };
} else if (kernel_idx == 2) {
Contributor:

Probably it should be kernel_idx == KernelsTypes::FINALIZATION
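
A one-line sketch of that change, assuming KernelsTypes enumerates the kernel stages as the suggestion implies:

```cpp
// Compare against the named finalization stage instead of a magic number.
} else if (kernel_idx == KernelsTypes::FINALIZATION) {
```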

kernel_impl_params const& impl_param) {
auto desc = impl_param.typed_desc<scaled_dot_product_attention>();

return impl_param.get_input_layout(0);
Contributor:

I suppose this impl is incorrect, so maybe just throw an unimplemented exception here instead?
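
A minimal sketch of that suggestion, using the standard OpenVINO exception macro (the message text is illustrative):

```cpp
// Fail loudly rather than returning a layout that is known to be wrong.
OPENVINO_THROW("[GPU] calc_output_layout for scaled_dot_product_attention is not implemented");
```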

} else if (pattern_map.find(sdpa_with_attn_mask_m) != pattern_map.end()) {
sdpa = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(pattern_map.at(sdpa_with_attn_mask_m).get_node_shared_ptr());
} else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) {
sdpa = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(pattern_map.at(sdpa_with_attn_mask_and_scale_m).get_node_shared_ptr());
Contributor:

I think the code above can be replaced with m.get_match_root()
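
A sketch of that collapse, assuming SDPA is the root node of every registered pattern variant:

```cpp
// The matcher already knows which pattern fired; its root is the SDPA node.
auto sdpa = std::dynamic_pointer_cast<ov::op::v13::ScaledDotProductAttention>(m.get_match_root());
if (!sdpa)
    return false;
```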

} else if (pattern_map.find(sdpa_with_attn_mask_and_scale_m) != pattern_map.end()) {
auto attn_mask = sdpa->get_input_source_output(3);
auto scale = sdpa->get_input_source_output(4);
sdpa_new = std::make_shared<op::SDPA>(input_q, input_k, input_v, attn_mask, scale, order_q, order_k, order_v, order_output, sdpa->get_causal());
Contributor:

We can probably have an internal SDPA op with optional inputs to simplify the converters.
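
A hypothetical sketch of the shape this could take; op::SDPA does not necessarily have a variadic constructor like this, and has_attn_mask / has_scale stand in for the pattern-map checks above:

```cpp
// Collect only the optional inputs that were actually matched.
ov::OutputVector args{input_q, input_k, input_v};
if (has_attn_mask)
    args.push_back(attn_mask);
if (has_scale)
    args.push_back(scale);
auto sdpa_new = std::make_shared<op::SDPA>(args, order_q, order_k, order_v, order_output, sdpa->get_causal());
```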

JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const {
auto jit = SDPAKernelBase::GetJitConstants(params);

const auto softmax_acc_dt = params.inputs[0].GetDType();
Contributor:

It would be great to try an FP32 accumulator. If it doesn't impact performance, we can use it to get better accuracy in some cases.
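
A sketch of that experiment, assuming the kernel_selector jitter helpers used elsewhere in this file:

```cpp
// Pin the softmax accumulator to FP32 instead of the input precision.
const auto softmax_acc_dt = Datatype::F32;
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
```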

}

KernelsPriority SDPAKernelOpt::GetKernelsPriority(const Params& /*params*/) const {
return FORCE_PRIORITY_1;
Contributor:

It should probably be > 1, at least for platforms with DPAS.
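
A sketch of that gating, assuming engineInfo exposes a supports_immad flag the way other kernels in the selector use it:

```cpp
KernelsPriority SDPAKernelOpt::GetKernelsPriority(const Params& params) const {
    // Demote this kernel on DPAS-capable hardware, where a systolic-array
    // implementation is expected to win.
    const auto& sdpa = static_cast<const sdpa_params&>(params);
    return sdpa.engineInfo.supports_immad ? FORCE_PRIORITY_2 : FORCE_PRIORITY_1;
}
```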

#elif INPUT1_DIMS == 6
return INPUT1_GET_INDEX_SAFE(b, f, w, z, y, x);
#else
# error sdpa_ref.cl : Unsupported input 1 format
Contributor:

sdpa_opt (this error message still refers to sdpa_ref.cl)
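
I.e., the directive would become:

```c
# error sdpa_opt.cl : Unsupported input 1 format
```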

DO_BROADCAST_KEY_VALUE;
#endif
#if INPUT1_SIMPLE
return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x);
Contributor:

Should we use the _SAFE version here? As I understand it, DO_BROADCAST_KEY_VALUE divides the index by the group size internally, so do we ever get out-of-bounds indices? Same question for the other functions.
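
If the broadcast macro already maps f into the valid key/value head range, a sketch of the non-_SAFE alternative (assuming GET_DATA_INDEX_6D exists alongside the _SAFE variant shown above):

```c
// No clamping needed if the index is guaranteed in-bounds after broadcasting.
return GET_DATA_INDEX_6D(INPUT1, b, f, w, z, y, x);
```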

exp_sum[seq_idx] = sub_group_reduce_add(exp_sum[seq_idx]);
}

// const SOFTMAX_ACCUMULATOR_TYPE inv_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ONE / exp_sum[seq_idx];
Contributor:

Not needed, I think.

@p-durandin added this pull request to the merge queue May 22, 2024
Merged via the queue into openvinotoolkit:master with commit 1e5f025 May 22, 2024
101 checks passed
Labels: category: CPP API (OpenVINO CPP API bindings), category: GPU (OpenVINO GPU plugin), category: inference (OpenVINO Runtime library - Inference), Code Freeze, priority: high (High priority)
Projects: none yet
Linked issues: none
5 participants