diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 89634574027..14b2bc694a6 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -382,7 +382,7 @@ void cpu_flash_attention(
       /* qk_sum */ qSplitSize +
       /* dst    */ qSplitSize * headSize;
 
-  int64_t size_bytes = size_per_thread * num_thread * query.element_size() * 4;
+  int64_t size_bytes = size_per_thread * num_thread * query.element_size();
   std::vector<char> buf_vec(size_bytes);
   void* buf = reinterpret_cast<void*>(buf_vec.data());
   // Need to double check the following