diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 89634574027..14b2bc694a6 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -382,7 +382,7 @@ void cpu_flash_attention( /* qk_sum */ qSplitSize + /* dst */ qSplitSize * headSize; - int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4; + int64_t size_bytes = size_per_thread * num_thread * query.element_size(); std::vector<char> buf_vec(size_bytes); void* buf = reinterpret_cast<void*>(buf_vec.data()); // Need to double check the following