90 changes: 87 additions & 3 deletions tests/test-flash-decoding-custom-op.cpp
@@ -292,15 +292,99 @@ int main() {
size_t quant_nb2 = quant_len * quant_nb1;
size_t quant_nb3 = quant_nb2 * n_kv_heads;

// Fix: compute the offset of token position fp16_window using the tensor's
// actual dim-1 stride. K is laid out as [head_dim, kv_len, n_kv_heads, 1],
// so the previous formula `n_kv_heads * fp16_window * fp16_nb1` overshot the
// intended token by a factor of n_kv_heads.
size_t kv_quant_offset = fp16_window * k->nb[1];
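// Sanity check (illustrative sketch, not part of the original change): if k
// is contiguous along dim 1 -- as the copy loop below assumes with its
// GGML_PAD(kv_len, n_pad) row extent -- then k->nb[1] equals
// head_dim * sizeof(ggml_fp16_t) and the offset reduces to a plain
// token-count-times-row-size expression.
GGML_ASSERT(kv_quant_offset == fp16_window * head_dim * sizeof(ggml_fp16_t));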

ggml_tensor * k_fp16 = ggml_view_4d(ctx, k, head_dim, fp16_window, n_kv_heads, 1, fp16_nb1, fp16_nb2, fp16_nb3, 0);
ggml_tensor * v_fp16 = ggml_view_4d(ctx, v, head_dim, fp16_window, n_kv_heads, 1, fp16_nb1, fp16_nb2, fp16_nb3, 0);

// Only create quantized tensors if there are quantized tokens; note that
// quant_len can be 0. Previously, k_quant and v_quant were plain views into
// k and v:
//   ggml_view_4d(ctx, k, head_dim, quant_len, n_kv_heads, 1,
//                quant_nb1, quant_nb2, quant_nb3, kv_quant_offset);
// They are now materialized as standalone Q4_0 tensors below.
ggml_tensor * k_quant = nullptr;
ggml_tensor * v_quant = nullptr;

// Create Q4_0 quantized tensors for k_quant and v_quant if we have quantized tokens
if (quant_len > 0) {
printf("Creating simple Q4_0 quantized tensors for %zu tokens\n", quant_len);

// Calculate total elements for the quantized portion
size_t total_elements = head_dim * quant_len * n_kv_heads;

// Create simple 1D tensors for quantization (based on successful test_unified_cache_copy.cpp example)
ggml_tensor * k_quant_src = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, total_elements);
ggml_tensor * v_quant_src = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, total_elements);
k_quant = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, total_elements);
v_quant = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, total_elements);
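// Note (added for clarity; assumes head_dim is a multiple of 32, as it is
// for typical models): Q4_0 stores values in blocks of 32 -- one ggml_fp16_t
// scale plus 32 packed 4-bit quants, 18 bytes per block -- so the element
// count of a Q4_0 tensor must be a whole number of blocks.
GGML_ASSERT(total_elements % ggml_blck_size(GGML_TYPE_Q4_0) == 0);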

printf("Created 1D tensors: src=%zu elements, dst=%zu elements\n",
total_elements, total_elements);
printf("K_src: %zu bytes, K_quant: %zu bytes\n",
ggml_nbytes(k_quant_src), ggml_nbytes(k_quant));

// Fill source tensors with data from the quantized portion (tokens fp16_window to fp16_window+quant_len)
ggml_fp16_t* k_src_data = (ggml_fp16_t*)k_quant_src->data;
ggml_fp16_t* v_src_data = (ggml_fp16_t*)v_quant_src->data;
ggml_fp16_t* k_orig_data = (ggml_fp16_t*)k->data;
ggml_fp16_t* v_orig_data = (ggml_fp16_t*)v->data;

// Copy data from the quantized portion to the 1D tensors
size_t idx = 0;
for (size_t h = 0; h < n_kv_heads; h++) {
for (size_t t = 0; t < quant_len; t++) {
for (size_t d = 0; d < head_dim; d++) {
// Source position: token (fp16_window + t) in the original tensor, whose
// dim-1 extent is padded to GGML_PAD(kv_len, n_pad) tokens per head
size_t orig_idx = d + (fp16_window + t) * head_dim + h * head_dim * GGML_PAD(kv_len, n_pad);

k_src_data[idx] = k_orig_data[orig_idx];
v_src_data[idx] = v_orig_data[orig_idx];
idx++;
}
}
}
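// Cross-check (illustrative sketch, assuming k is contiguous): the flat
// index used above matches the stride-based offset computed earlier; for
// h = 0, t = 0, d = 0 the first copied element sits exactly kv_quant_offset
// bytes into k->data.
GGML_ASSERT((char *) &k_orig_data[fp16_window * head_dim]
            == (char *) k->data + kv_quant_offset);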

printf("Data copy completed successfully\n");

// Use ggml_cpy to quantize the data from F16 to Q4_0 (based on successful example)
printf("Creating ggml_cpy operations...\n");
ggml_tensor * k_quantize_op = ggml_cpy(ctx, k_quant_src, k_quant);
ggml_tensor * v_quantize_op = ggml_cpy(ctx, v_quant_src, v_quant);
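// ggml_cpy between tensors of different types performs the type conversion
// when the graph is computed, so these two nodes are what actually quantize
// the F16 data into the Q4_0 destination blocks.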

printf("ggml_cpy operations created successfully\n");

// Build quantization graph and execute it
printf("Building computation graph...\n");
struct ggml_cgraph * graph_quantize = ggml_new_graph(ctx);
ggml_build_forward_expand(graph_quantize, k_quantize_op);
ggml_build_forward_expand(graph_quantize, v_quantize_op);

printf("Computing quantization (F16 -> Q4_0)...\n");
enum ggml_status status_quantize = ggml_graph_compute_with_ctx(ctx, graph_quantize, n_threads);

if (status_quantize != GGML_STATUS_SUCCESS) {
printf("ERROR: Quantization failed with status: %d\n", status_quantize);
ggml_free(ctx);
return 1;
}

printf("Quantization completed successfully\n");

// Now we need to create 4D views of our 1D quantized tensors for the flash attention
// Reshape the 1D quantized tensors back to 4D for flash attention compatibility
printf("Creating 4D views for flash attention...\n");

// For flash attention we ultimately need 4D tensors of the correct shape.
// ggml_view_4d on quantized tensors is restricted: offsets and strides must
// fall on Q4_0 block boundaries (32 elements). Here we keep the 1D tensors
// and let the flash attention path handle the reshape.
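// A block-aligned 4D view would look roughly like this (untested sketch,
// hypothetical; assumes head_dim % 32 == 0 so each row is whole blocks):
//   size_t q_nb1 = head_dim / ggml_blck_size(GGML_TYPE_Q4_0)
//                  * ggml_type_size(GGML_TYPE_Q4_0);
//   ggml_tensor * k_q4 = ggml_view_4d(ctx, k_quant,
//       head_dim, quant_len, n_kv_heads, 1,
//       q_nb1, q_nb1 * quant_len, q_nb1 * quant_len * n_kv_heads, 0);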

printf("K_quant final shape: 1D tensor with %ld elements, type: %s\n",
k_quant->ne[0], ggml_type_name(k_quant->type));
printf("V_quant final shape: 1D tensor with %ld elements, type: %s\n",
v_quant->ne[0], ggml_type_name(v_quant->type));

} else {
printf("No quantized tokens to create (quant_len = 0)\n");
}

// ============================================================================
// Test 1: Custom F32 Flash-attention Implementation