diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a614b2001bf64..c98d7c48e49df 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3772,7 +3772,6 @@ static void ggml_compute_forward_out_prod_f32( } } } - static void ggml_compute_forward_out_prod_q_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4560,7 +4559,6 @@ void ggml_compute_forward_get_rows_back( // exit(0); //} } - // ggml_compute_forward_diag static void ggml_compute_forward_diag_f32( @@ -5350,7 +5348,6 @@ static void ggml_compute_forward_rope_f32( } } } - // TODO: deduplicate f16/f32 code static void ggml_compute_forward_rope_f16( const ggml_compute_params * params, @@ -6142,7 +6139,6 @@ void ggml_compute_forward_conv_transpose_2d( } } } - // ggml_compute_forward_conv_2d_dw struct ggml_conv_2d_dw_params { @@ -6929,7 +6925,6 @@ void ggml_compute_forward_argsort( } } } - // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -7274,20 +7269,49 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state( float S = state_data[state_idx * 2 + 1]; // sum (index 1) float M = state_data[state_idx * 2 + 0]; // maximum KQ value (index 0) - // If this is the first call (indicated by M == -INFINITY), initialize properly - if (M == -INFINITY) { - S = 0.0f; - } + // Check if this is a continuation of previous segments + bool is_continuation = (M != -INFINITY && S > 0.0f); float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16 + // Initialize VKQ accumulator - CRITICAL FIX: restore previous accumulated results if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); + if (is_continuation) { + 
// Load previous accumulated result from dst tensor and scale by previous sum S + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + float * prev_result = (float *) ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1); + + // Scale previous result by S and convert to FP16 + for (int64_t d = 0; d < DV; ++d) { + VKQ16[d] = GGML_FP32_TO_FP16(prev_result[d] * S); + } + } else { + memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); + S = 0.0f; + M = -INFINITY; + } } else { - memset(VKQ32, 0, DV*sizeof(float)); + if (is_continuation) { + // Load previous accumulated result from dst tensor and scale by previous sum S + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + float * prev_result = (float *) ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1); + + // Scale previous result by S + for (int64_t d = 0; d < DV; ++d) { + VKQ32[d] = prev_result[d] * S; + } + } else { + memset(VKQ32, 0, DV*sizeof(float)); + S = 0.0f; + M = -INFINITY; + } } const ggml_fp16_t * mp = mask ? 
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -8392,7 +8416,6 @@ void ggml_compute_forward_win_unpart( } } } - //gmml_compute_forward_unary void ggml_compute_forward_unary( @@ -9193,7 +9216,6 @@ void ggml_compute_forward_map_custom2( p.fun(dst, a, b, params->ith, params->nth, p.userdata); } - // ggml_compute_forward_map_custom3 void ggml_compute_forward_map_custom3( diff --git a/tests/test-flash-attn-state.cpp b/tests/test-flash-attn-state.cpp index 7d1be7f02551f..2c9144f6e899b 100644 --- a/tests/test-flash-attn-state.cpp +++ b/tests/test-flash-attn-state.cpp @@ -14,7 +14,7 @@ #include // Use fixed seed for reproducible results -static std::mt19937 g_rng(42); +static std::mt19937 g_rng(42); static void fill_tensor_f32(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) { float* data = (float*)dst->data; @@ -290,6 +290,7 @@ int main() { print_tensor_info(" V segment", v_segment); // Compute flash attention with state for this segment + // CRITICAL: Create the operation but redirect its output to our accumulation tensor ggml_tensor * result_seg = ggml_flash_attn_ext_with_state( ctx, q, k_segment, v_segment, mask_segment, state, 1.0f / std::sqrt(head_dim), // scale @@ -304,6 +305,14 @@ int main() { return 1; } + // CRITICAL FIX: Redirect the operation's output to our accumulation tensor + // This ensures that each segment reads from and writes to the same tensor + result_seg->data = result_segmented->data; + result_seg->nb[0] = result_segmented->nb[0]; + result_seg->nb[1] = result_segmented->nb[1]; + result_seg->nb[2] = result_segmented->nb[2]; + result_seg->nb[3] = result_segmented->nb[3]; + struct ggml_cgraph * graph_seg = ggml_new_graph(ctx); ggml_build_forward_expand(graph_seg, result_seg); @@ -316,7 +325,7 @@ int main() { } printf(" Segment %d computed successfully\n", seg + 1); - print_f32_sample(" Segment result", result_seg, 6); + print_f32_sample(" Segment result", result_segmented, 6); // Print state after 
this segment printf(" State after segment %d: ", seg + 1); @@ -325,11 +334,7 @@ int main() { } printf("...\n"); - // For the final segment, copy the result (this contains the accumulated result of all segments) - if (seg == kv_segments - 1) { - memcpy(result_segmented->data, result_seg->data, ggml_nbytes(result_seg)); - printf(" Final accumulated result copied from segment %d\n", seg + 1); - } + // No need to copy result since we're already writing to result_segmented } printf("\nSegmented computation completed\n");