48 changes: 35 additions & 13 deletions ggml/src/ggml-cpu/ops.cpp
@@ -3772,7 +3772,6 @@ static void ggml_compute_forward_out_prod_f32(
}
}
}

static void ggml_compute_forward_out_prod_q_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -4560,7 +4559,6 @@ void ggml_compute_forward_get_rows_back(
// exit(0);
//}
}

// ggml_compute_forward_diag

static void ggml_compute_forward_diag_f32(
@@ -5350,7 +5348,6 @@ static void ggml_compute_forward_rope_f32(
}
}
}

// TODO: deduplicate f16/f32 code
static void ggml_compute_forward_rope_f16(
const ggml_compute_params * params,
@@ -6142,7 +6139,6 @@ void ggml_compute_forward_conv_transpose_2d(
}
}
}

// ggml_compute_forward_conv_2d_dw

struct ggml_conv_2d_dw_params {
@@ -6929,7 +6925,6 @@ void ggml_compute_forward_argsort(
}
}
}

// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16(
@@ -7274,20 +7269,49 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
float S = state_data[state_idx * 2 + 1]; // sum (index 1)
float M = state_data[state_idx * 2 + 0]; // maximum KQ value (index 0)

// If this is the first call (indicated by M == -INFINITY), initialize properly
if (M == -INFINITY) {
S = 0.0f;
}
// Check if this is a continuation of previous segments
bool is_continuation = (M != -INFINITY && S > 0.0f);

float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16

// Initialize VKQ accumulator - CRITICAL FIX: restore previous accumulated results
if (v->type == GGML_TYPE_F16) {
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
if (is_continuation) {
// Load previous accumulated result from dst tensor and scale by previous sum S
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
float * prev_result = (float *) ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1);

// Scale previous result by S and convert to FP16
for (int64_t d = 0; d < DV; ++d) {
VKQ16[d] = GGML_FP32_TO_FP16(prev_result[d] * S);
}
} else {
memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
S = 0.0f;
M = -INFINITY;
}
} else {
memset(VKQ32, 0, DV*sizeof(float));
if (is_continuation) {
// Load previous accumulated result from dst tensor and scale by previous sum S
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
float * prev_result = (float *) ((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1);

// Scale previous result by S
for (int64_t d = 0; d < DV; ++d) {
VKQ32[d] = prev_result[d] * S;
}
} else {
memset(VKQ32, 0, DV*sizeof(float));
S = 0.0f;
M = -INFINITY;
}
}

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
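
The hunk above continues a row's attention across KV segments: the state tensor keeps one {M, S} pair (running maximum and running exponential sum) per output row, while dst keeps the normalized result VKQ/S of the segments processed so far, so multiplying the previously written dst row by S recovers the unnormalized accumulator before further keys are folded in. A minimal, self-contained sketch of that scheme, independent of the ggml kernel (the `row_state` and `attend_segment` names are illustrative only):

```cpp
// Sketch of segmented online-softmax attention for a single query row.
// out holds the normalized result of earlier segments; st holds {M, S}.
#include <cmath>
#include <vector>

struct row_state {
    float M = -INFINITY; // running maximum of the attention scores
    float S = 0.0f;      // running sum of exp(score - M)
};

// scores[j] is the (already scaled) q.k_j value, vals[j] the matching V row (DV wide).
static void attend_segment(const std::vector<float> & scores,
                           const std::vector<std::vector<float>> & vals,
                           std::vector<float> & out,
                           row_state & st) {
    const size_t DV = out.size();

    // Recover the unnormalized accumulator from the previous segment's output.
    std::vector<float> vkq(DV, 0.0f);
    if (st.M != -INFINITY && st.S > 0.0f) {
        for (size_t d = 0; d < DV; ++d) {
            vkq[d] = out[d] * st.S;
        }
    }

    for (size_t j = 0; j < scores.size(); ++j) {
        const float m_new     = std::max(st.M, scores[j]);
        const float scale_old = std::exp(st.M - m_new);      // rescales everything accumulated so far
        const float scale_new = std::exp(scores[j] - m_new);  // weight of the new position
        for (size_t d = 0; d < DV; ++d) {
            vkq[d] = vkq[d]*scale_old + scale_new*vals[j][d];
        }
        st.S = st.S*scale_old + scale_new;
        st.M = m_new;
    }

    // Write back the normalized result; a later segment can re-scale it by S again.
    for (size_t d = 0; d < DV; ++d) {
        out[d] = vkq[d] / st.S;
    }
}
```

Running `attend_segment` over consecutive score/value chunks with the same `row_state` and `out` buffer produces the same result as a single pass over the full KV range, which is the property the patch relies on.
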
@@ -8392,7 +8416,6 @@ void ggml_compute_forward_win_unpart(
}
}
}

// ggml_compute_forward_unary

void ggml_compute_forward_unary(
@@ -9193,7 +9216,6 @@ void ggml_compute_forward_map_custom2(

p.fun(dst, a, b, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom3

void ggml_compute_forward_map_custom3(
19 changes: 12 additions & 7 deletions tests/test-flash-attn-state.cpp
@@ -14,7 +14,7 @@
#include <iostream>

// Use fixed seed for reproducible results
static std::mt19937 g_rng(42);
static std::mt19937 g_rng(std::random_device{}());

static void fill_tensor_f32(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) {
float* data = (float*)dst->data;
@@ -290,6 +290,7 @@ int main() {
print_tensor_info(" V segment", v_segment);

// Compute flash attention with state for this segment
// CRITICAL: Create the operation but redirect its output to our accumulation tensor
ggml_tensor * result_seg = ggml_flash_attn_ext_with_state(
ctx, q, k_segment, v_segment, mask_segment, state,
1.0f / std::sqrt(head_dim), // scale
@@ -304,6 +305,14 @@
return 1;
}

// CRITICAL FIX: Redirect the operation's output to our accumulation tensor
// This ensures that each segment reads from and writes to the same tensor
result_seg->data = result_segmented->data;
result_seg->nb[0] = result_segmented->nb[0];
result_seg->nb[1] = result_segmented->nb[1];
result_seg->nb[2] = result_segmented->nb[2];
result_seg->nb[3] = result_segmented->nb[3];

struct ggml_cgraph * graph_seg = ggml_new_graph(ctx);
ggml_build_forward_expand(graph_seg, result_seg);

@@ -316,7 +325,7 @@
}

printf(" Segment %d computed successfully\n", seg + 1);
print_f32_sample(" Segment result", result_seg, 6);
print_f32_sample(" Segment result", result_segmented, 6);

// Print state after this segment
printf(" State after segment %d: ", seg + 1);
@@ -325,11 +334,7 @@
}
printf("...\n");

// For the final segment, copy the result (this contains the accumulated result of all segments)
if (seg == kv_segments - 1) {
memcpy(result_segmented->data, result_seg->data, ggml_nbytes(result_seg));
printf(" Final accumulated result copied from segment %d\n", seg + 1);
}
// No need to copy result since we're already writing to result_segmented
}

printf("\nSegmented computation completed\n");
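
The test above aliases the graph node's output onto `result_segmented` by copying the raw data pointer and strides, so every segment reads from and writes to the same buffer. A hedged sketch of that pattern as a small helper (the `alias_output` name is not part of the ggml API, and it assumes the two tensors already have identical shape and type):

```cpp
#include "ggml.h"

// Point a graph node's output at caller-owned storage so successive graphs
// accumulate into the same buffer.
static void alias_output(struct ggml_tensor * node, struct ggml_tensor * storage) {
    GGML_ASSERT(ggml_are_same_shape(node, storage));
    GGML_ASSERT(node->type == storage->type);

    node->data = storage->data;
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        node->nb[i] = storage->nb[i];
    }
}
```

Calling `alias_output(result_seg, result_segmented)` before building the graph would replace the five manual assignments in the test while also asserting the layout assumptions they rely on.
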