From 71cc797b0e94c25cbdbec315f5f871b7ac793b1e Mon Sep 17 00:00:00 2001
From: Zijie Tian <1049154785@qq.com>
Date: Tue, 24 Jun 2025 06:03:11 +0800
Subject: [PATCH 1/2] Add PyTorch comparison to flash attention state test

---
 tests/test-flash-attn-state.cpp | 384 ++++++++++++++++++++++----------
 1 file changed, 271 insertions(+), 113 deletions(-)

diff --git a/tests/test-flash-attn-state.cpp b/tests/test-flash-attn-state.cpp
index 2c9144f6e899b..cee710f03dd17 100644
--- a/tests/test-flash-attn-state.cpp
+++ b/tests/test-flash-attn-state.cpp
@@ -1,24 +1,28 @@
-#include "ggml.h"
-#include "ggml-cpu.h"
 #include "../ggml/src/ggml-impl.h"
+#include "ggml-cpu.h"
+#include "ggml.h"
+
+#ifdef LLAMA_TORCH_AVAILABLE
+# include <torch/torch.h>
+#endif
+#include
 #include
 #include
 #include
 #include
-#include
-#include
-#include
 #include
-#include
+#include
 #include
+#include
+#include
 
 // Use fixed seed for reproducible results
 static std::mt19937 g_rng(std::random_device{}());
 
 static void fill_tensor_f32(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) {
-    float* data = (float*)dst->data;
-    size_t n_elements = ggml_nelements(dst);
+    float * data = (float *) dst->data;
+    size_t n_elements = ggml_nelements(dst);
 
     std::uniform_real_distribution<float> dis(min_val, max_val);
     for (size_t i = 0; i < n_elements; i++) {
@@ -27,8 +31,8 @@ static void fill_tensor_f32(ggml_tensor * dst, float min_val = -1.0f, float max_
 }
 
 static void fill_tensor_f16(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) {
-    ggml_fp16_t* data = (ggml_fp16_t*)dst->data;
-    size_t n_elements = ggml_nelements(dst);
+    ggml_fp16_t * data = (ggml_fp16_t *) dst->data;
+    size_t n_elements = ggml_nelements(dst);
 
     std::uniform_real_distribution<float> dis(min_val, max_val);
     for (size_t i = 0; i < n_elements; i++) {
@@ -36,22 +40,57 @@ static void fill_tensor_f16(ggml_tensor * dst, float min_val = -1.0f, float max_
     }
 }
 
-static void print_tensor_info(const char* name, ggml_tensor* tensor) {
-    printf("%s: [%ld, %ld, %ld, %ld] type=%s, elements=%ld\n",
-           name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
-           ggml_type_name(tensor->type), ggml_nelements(tensor));
+#ifdef LLAMA_TORCH_AVAILABLE
+// Convert ggml tensor to a torch::Tensor (float32)
+static torch::Tensor ggml_to_torch(ggml_tensor * tensor) {
+    auto tt = ggml_get_type_traits(tensor->type);
+    size_t n = ggml_nelements(tensor);
+    std::vector<float> data(n);
+    if (tensor->type == GGML_TYPE_F32) {
+        memcpy(data.data(), tensor->data, n * sizeof(float));
+    } else if (tt->to_float) {
+        tt->to_float(tensor->data, data.data(), n);
+    } else {
+        printf("Unsupported tensor type for torch conversion: %s\n", ggml_type_name(tensor->type));
+        return {};
+    }
+
+    std::vector<int64_t> sizes;
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] > 1 || i == 0) {
+            sizes.push_back(tensor->ne[i]);
+        }
+    }
+    return torch::from_blob(data.data(), sizes, torch::kFloat32).clone();
 }
 
-static void print_f32_sample(const char* name, ggml_tensor* tensor, int max_elements = 10) {
+// Simple flash attention using PyTorch for verification
+static torch::Tensor torch_flash_attention(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor mask,
+                                           float scale) {
+    auto scores = torch::matmul(Q, K.transpose(-2, -1)) * scale;
+    if (mask.defined()) {
+        scores = scores + mask;
+    }
+    auto attn = torch::softmax(scores, -1);
+    return torch::matmul(attn, V);
+}
+#endif  // LLAMA_TORCH_AVAILABLE
+
+static void print_tensor_info(const char * name, ggml_tensor * tensor) {
+    printf("%s: [%ld, %ld, %ld, %ld] type=%s, elements=%ld\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2],
+           tensor->ne[3], ggml_type_name(tensor->type), ggml_nelements(tensor));
+}
+
+static void print_f32_sample(const char * name, ggml_tensor * tensor, int max_elements = 10) {
     if (tensor->type != GGML_TYPE_F32) {
         printf("%s: Not F32 tensor (type=%s)\n", name, ggml_type_name(tensor->type));
         return;
     }
-    
-    float* data = (float*)tensor->data;
-    size_t n_elements = ggml_nelements(tensor);
-    size_t elements_to_print = std::min((size_t)max_elements, n_elements);
-    
+
+    float * data = (float *) tensor->data;
+    size_t n_elements = ggml_nelements(tensor);
+    size_t elements_to_print = std::min((size_t) max_elements, n_elements);
+
     printf("%s sample values: ", name);
     for (size_t i = 0; i < elements_to_print; i++) {
         printf("%.6f ", data[i]);
@@ -62,34 +101,34 @@ static void print_f32_sample(const char* name, ggml_tensor* tensor, int max_elem
     printf("\n");
 }
 
-static float tensor_max_diff(ggml_tensor* a, ggml_tensor* b) {
+static float tensor_max_diff(ggml_tensor * a, ggml_tensor * b) {
     if (ggml_nelements(a) != ggml_nelements(b) || a->type != b->type) {
         printf("ERROR: Tensors have different sizes or types\n");
         return -1.0f;
     }
-    
+
     if (a->type != GGML_TYPE_F32) {
         printf("ERROR: Only F32 tensors supported for comparison\n");
        return -1.0f;
     }
-    
-    float* data_a = (float*)a->data;
-    float* data_b = (float*)b->data;
-    size_t n_elements = ggml_nelements(a);
-    
+
+    float * data_a = (float *) a->data;
+    float * data_b = (float *) b->data;
+    size_t n_elements = ggml_nelements(a);
+
     float max_diff = 0.0f;
     for (size_t i = 0; i < n_elements; i++) {
         float diff = std::abs(data_a[i] - data_b[i]);
-        max_diff = std::max(max_diff, diff);
+        max_diff   = std::max(max_diff, diff);
     }
-    
+
     return max_diff;
 }
 
-static void reset_state_tensor(ggml_tensor* state) {
-    float* state_data = (float*)state->data;
-    size_t n_pairs = ggml_nelements(state) / 2;
-    
+static void reset_state_tensor(ggml_tensor * state) {
+    float * state_data = (float *) state->data;
+    size_t n_pairs = ggml_nelements(state) / 2;
+
     for (size_t i = 0; i < n_pairs; i++) {
         state_data[i * 2 + 0] = -INFINITY;  // M (max KQ value)
         state_data[i * 2 + 1] = 0.0f;       // S (sum)
@@ -100,13 +139,13 @@ int main() {
     printf("=== Flash Attention State Tensor - Comprehensive Test ===\n");
 
     // Test parameters
-    const int head_dim = 32;
-    const int n_heads = 8;
-    const int n_kv_heads = 4;
-    const int seq_len = 2;
-    const int kv_len = 4; // Will be split into segments
-    const int n_threads = 4;
-    const int kv_segments = 2; // Split KV into 2 segments
+    const int head_dim = 32;
+    const int n_heads = 8;
+    const int n_kv_heads = 4;
+    const int seq_len = 2;
+    const int kv_len = 4;  // Will be split into segments
+    const int n_threads = 4;
+    const int kv_segments = 2;  // Split KV into 2 segments
     const int kv_segment_len = kv_len / kv_segments;
 
     printf("Test Parameters:\n");
@@ -115,11 +154,11 @@ int main() {
     printf("  kv_segments=%d, kv_segment_len=%d\n", kv_segments, kv_segment_len);
 
     // Initialize ggml context
-    const size_t ctx_size = 1024*1024*1024; // 1GB
-    struct ggml_init_params params = {
-        /*.mem_size =*/ ctx_size,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc =*/ false,
+    const size_t ctx_size = 1024 * 1024 * 1024;  // 1GB
+    struct ggml_init_params params = {
+        /*.mem_size =*/ctx_size,
+        /*.mem_buffer =*/NULL,
+        /*.no_alloc =*/false,
     };
 
     struct ggml_context * ctx = ggml_init(params);
@@ -141,9 +180,9 @@ int main() {
     ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, kv_len, n_kv_heads, 1);
 
     // Create mask tensor with proper padding
-    const int padded_kv_len = GGML_PAD(kv_len, 64);
-    const int padded_seq_len = GGML_PAD(seq_len, GGML_KQ_MASK_PAD);
-    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_kv_len, padded_seq_len);
+    const int padded_kv_len = GGML_PAD(kv_len, 64);
+    const int padded_seq_len = GGML_PAD(seq_len, GGML_KQ_MASK_PAD);
+    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_kv_len, padded_seq_len);
 
     // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
     ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, n_heads * seq_len);
@@ -161,7 +200,7 @@ int main() {
     fill_tensor_f16(v, -0.7f, 0.7f);
 
     // Initialize mask (no causal mask - all positions can see all KV)
-    ggml_fp16_t* mask_data = (ggml_fp16_t*)mask->data;
+    ggml_fp16_t * mask_data = (ggml_fp16_t *) mask->data;
     memset(mask_data, 0, ggml_nbytes(mask));
     for (int i = 0; i < seq_len; i++) {
         for (int j = 0; j < kv_len; j++) {
@@ -177,11 +216,10 @@ int main() {
     // ============================================================================
     printf("\n--- Test 1: Standard Flash Attention (Reference) ---\n");
 
-    ggml_tensor * result_standard = ggml_flash_attn_ext(
-        ctx, q, k, v, mask,
-        1.0f / std::sqrt(head_dim), // scale
-        0.0f, // max_bias
-        0.0f // logit_softcap
+    ggml_tensor * result_standard = ggml_flash_attn_ext(ctx, q, k, v, mask,
+                                                        1.0f / std::sqrt(head_dim),  // scale
+                                                        0.0f,                        // max_bias
+                                                        0.0f                         // logit_softcap
     );
 
     ggml_flash_attn_ext_set_prec(result_standard, GGML_PREC_F32);
@@ -213,10 +251,9 @@ int main() {
 
     // Reset state tensor
     reset_state_tensor(state);
-    
+
     // Create result tensor for accumulation (same shape as standard result)
-    ggml_tensor * result_segmented = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
-                                                        head_dim, seq_len, n_heads, 1);
+    ggml_tensor * result_segmented = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, seq_len, n_heads, 1);
 
     // Initialize segmented result to zero
     memset(result_segmented->data, 0, ggml_nbytes(result_segmented));
@@ -224,40 +261,37 @@ int main() {
     printf("Processing %d segments of KV cache (segment_len=%d)...\n", kv_segments, kv_segment_len);
 
     for (int seg = 0; seg < kv_segments; seg++) {
-        printf("\n  Segment %d/%d (kv_pos %d-%d):\n",
-               seg + 1, kv_segments, seg * kv_segment_len, (seg + 1) * kv_segment_len - 1);
+        printf("\n  Segment %d/%d (kv_pos %d-%d):\n", seg + 1, kv_segments, seg * kv_segment_len,
+               (seg + 1) * kv_segment_len - 1);
 
         // Print state before this segment
         printf("  State before segment %d: ", seg + 1);
-        float* state_data = (float*)state->data;
+        float * state_data = (float *) state->data;
         for (int i = 0; i < std::min(4, n_heads * seq_len); i++) {
             printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]);
         }
         printf("...\n");
 
         // Create views of K and V for this segment using ggml_view_4d
-        ggml_tensor * k_segment = ggml_view_4d(ctx, k,
-                                               head_dim, kv_segment_len, n_kv_heads, 1,  // ne
-                                               k->nb[1], k->nb[2], k->nb[3],             // nb (strides)
-                                               seg * kv_segment_len * k->nb[1]);         // offset
+        ggml_tensor * k_segment = ggml_view_4d(ctx, k, head_dim, kv_segment_len, n_kv_heads, 1,  // ne
+                                               k->nb[1], k->nb[2], k->nb[3],                     // nb (strides)
+                                               seg * kv_segment_len * k->nb[1]);                 // offset
 
-        ggml_tensor * v_segment = ggml_view_4d(ctx, v,
-                                               head_dim, kv_segment_len, n_kv_heads, 1,  // ne
-                                               v->nb[1], v->nb[2], v->nb[3],             // nb (strides)
-                                               seg * kv_segment_len * v->nb[1]);         // offset
+        ggml_tensor * v_segment = ggml_view_4d(ctx, v, head_dim, kv_segment_len, n_kv_heads, 1,  // ne
+                                               v->nb[1], v->nb[2], v->nb[3],                     // nb (strides)
+                                               seg * kv_segment_len * v->nb[1]);                 // offset
 
         // Create mask for this segment
-        const int padded_segment_len = GGML_PAD(kv_segment_len, 64);
-        ggml_tensor * mask_segment = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
-                                                        padded_segment_len, padded_seq_len);
+        const int padded_segment_len = GGML_PAD(kv_segment_len, 64);
+        ggml_tensor * mask_segment = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_segment_len, padded_seq_len);
 
         // Fill segment mask
-        ggml_fp16_t* mask_seg_data = (ggml_fp16_t*)mask_segment->data;
+        ggml_fp16_t * mask_seg_data = (ggml_fp16_t *) mask_segment->data;
         memset(mask_seg_data, 0, ggml_nbytes(mask_segment));
-        
+
         for (int i = 0; i < seq_len; i++) {
             for (int j = 0; j < kv_segment_len; j++) {
-                int global_j = seg * kv_segment_len + j;
+                int global_j = seg * kv_segment_len + j;
                 // No masking for segment - all positions can see all KV tokens in this segment
                 mask_seg_data[i * padded_segment_len + j] = ggml_fp32_to_fp16(0.0f);
             }
@@ -274,7 +308,7 @@ int main() {
             }
             printf("...\n");
         }
-        
+
         printf("  Debug - Segment mask (first 4 seq positions, all segment positions):\n");
         for (int i = 0; i < std::min(4, seq_len); i++) {
             printf("    seq[%d]: ", i);
@@ -291,11 +325,10 @@ int main() {
 
         // Compute flash attention with state for this segment
         // CRITICAL: Create the operation but redirect its output to our accumulation tensor
-        ggml_tensor * result_seg = ggml_flash_attn_ext_with_state(
-            ctx, q, k_segment, v_segment, mask_segment, state,
-            1.0f / std::sqrt(head_dim), // scale
-            0.0f, // max_bias
-            0.0f // logit_softcap
+        ggml_tensor * result_seg = ggml_flash_attn_ext_with_state(ctx, q, k_segment, v_segment, mask_segment, state,
+                                                                  1.0f / std::sqrt(head_dim),  // scale
+                                                                  0.0f,                        // max_bias
+                                                                  0.0f                         // logit_softcap
         );
 
         ggml_flash_attn_ext_set_prec(result_seg, GGML_PREC_F32);
@@ -307,7 +340,7 @@ int main() {
 
         // CRITICAL FIX: Redirect the operation's output to our accumulation tensor
        // This ensures that each segment reads from and writes to the same tensor
-        result_seg->data = result_segmented->data;
+        result_seg->data  = result_segmented->data;
         result_seg->nb[0] = result_segmented->nb[0];
         result_seg->nb[1] = result_segmented->nb[1];
         result_seg->nb[2] = result_segmented->nb[2];
@@ -340,27 +373,149 @@ int main() {
     printf("\nSegmented computation completed\n");
     print_f32_sample("Final segmented result", result_segmented, 8);
 
+    // =====================================================================
+    // Test 3: PyTorch Verification using scaled_dot_product_attention
+    // =====================================================================
+    printf("\n--- PyTorch Verification ---\n");
+
+    std::vector<float> torch_result_data;
+    bool torch_success = false;
+
+#ifdef LLAMA_TORCH_AVAILABLE
+    try {
+        auto options = torch::TensorOptions().dtype(torch::kFloat32);
+
+        auto q_t = torch::zeros({ 1, n_heads, seq_len, head_dim }, options);
+        auto k_t = torch::zeros({ 1, n_kv_heads, kv_len, head_dim }, options);
+        auto v_t = torch::zeros({ 1, n_kv_heads, kv_len, head_dim }, options);
+
+        float * qd = q_t.data_ptr<float>();
+        float * kd = k_t.data_ptr<float>();
+        float * vd = v_t.data_ptr<float>();
+
+        for (int h = 0; h < n_heads; ++h) {
+            for (int s = 0; s < seq_len; ++s) {
+                for (int d = 0; d < head_dim; ++d) {
+                    int gi = d + s * head_dim + h * head_dim * seq_len;
+                    int ti = h * seq_len * head_dim + s * head_dim + d;
+                    qd[ti] = ((float *) q->data)[gi];
+                }
+            }
+        }
+
+        for (int h = 0; h < n_kv_heads; ++h) {
+            for (int s = 0; s < kv_len; ++s) {
+                for (int d = 0; d < head_dim; ++d) {
+                    int gi = d + s * head_dim + h * head_dim * kv_len;
+                    int ti = h * kv_len * head_dim + s * head_dim + d;
+                    kd[ti] = ggml_fp16_to_fp32(((ggml_fp16_t *) k->data)[gi]);
+                    vd[ti] = ggml_fp16_to_fp32(((ggml_fp16_t *) v->data)[gi]);
+                }
+            }
+        }
+
+        auto mask_t = torch::ones({ 1, n_heads, seq_len, kv_len }, torch::TensorOptions().dtype(torch::kBool));
+        bool * mask_td = mask_t.data_ptr<bool>();
+        ggml_fp16_t * mask_d = (ggml_fp16_t *) mask->data;
+
+        for (int h = 0; h < n_heads; ++h) {
+            for (int s = 0; s < seq_len; ++s) {
+                for (int d = 0; d < kv_len; ++d) {
+                    int gi = d + s * padded_kv_len;
+                    float val = GGML_FP16_TO_FP32(mask_d[gi]);
+                    int ti = h * seq_len * kv_len + s * kv_len + d;
+                    mask_td[ti] = (val == 0.0f);
+                }
+            }
+        }
+
+        if (n_heads > n_kv_heads) {
+            k_t = k_t.repeat_interleave(n_heads / n_kv_heads, 1);
+            v_t = v_t.repeat_interleave(n_heads / n_kv_heads, 1);
+        }
+
+        float scale = 1.0f / std::sqrt((float) head_dim);
+        auto torch_res = torch::scaled_dot_product_attention(q_t, k_t, v_t, mask_t, 0.0, false, scale);
+        torch_res = torch_res.permute({ 0, 2, 1, 3 }).contiguous();
+
+        float * trd = torch_res.data_ptr<float>();
+        size_t numel = torch_res.numel();
+        torch_result_data.resize(numel);
+        for (int h = 0; h < n_heads; ++h) {
+            for (int s = 0; s < seq_len; ++s) {
+                for (int d = 0; d < head_dim; ++d) {
+                    int ti = h * seq_len * head_dim + s * head_dim + d;
+                    int ci = d + s * head_dim + h * head_dim * seq_len;
+                    torch_result_data[ci] = trd[ti];
+                }
+            }
+        }
+        torch_success = true;
+        printf("PyTorch computation successful\n");
+    } catch (const std::exception & e) {
+        printf("PyTorch verification failed: %s\n", e.what());
+        torch_success = false;
+    }
+#else
+    printf("PyTorch verification skipped (PyTorch not available)\n");
+#endif
+
     // ============================================================================
     // Test 3: Compare Results
     // ============================================================================
-    printf("\n--- Test 3: Comparing Results ---\n");
-    
-    float max_diff = tensor_max_diff(result_standard, result_segmented);
-    
-    printf("Comparison between standard and segmented results:\n");
-    printf("  Maximum absolute difference: %.2e\n", max_diff);
-    
-    const float tolerance = 1e-4; // Reasonable tolerance for F16/F32 precision
-    
-    if (max_diff < tolerance) {
-        printf("  ✅ PASS: Results match within tolerance (%.2e)\n", tolerance);
+    printf("\n--- Unified Results Comparison ---\n");
+
+    float * standard_data = (float *) result_standard->data;
+    float * segmented_data = (float *) result_segmented->data;
+    size_t n_elems = ggml_nelements(result_standard);
+    if (torch_success) {
+        n_elems = std::min(n_elems, torch_result_data.size());
+    }
+
+    float max_std_seg = 0.0f, max_std_torch = 0.0f, max_seg_torch = 0.0f;
+    for (size_t i = 0; i < n_elems; ++i) {
+        float s = standard_data[i];
+        float g = segmented_data[i];
+        max_std_seg = std::max(max_std_seg, std::abs(s - g));
+        if (torch_success) {
+            float t = torch_result_data[i];
+            max_std_torch = std::max(max_std_torch, std::abs(s - t));
+            max_seg_torch = std::max(max_seg_torch, std::abs(g - t));
+        }
+    }
+
+    printf("Max diff standard vs segmented : %.6e\n", max_std_seg);
+    if (torch_success) {
+        printf("Max diff standard vs torch     : %.6e\n", max_std_torch);
+        printf("Max diff segmented vs torch    : %.6e\n", max_seg_torch);
+    }
+
+    printf("\nDetailed Comparison Table (first 128 elements):\n");
+    if (torch_success) {
+        printf("Idx | Standard    | Segmented   | Torch       | S-G Diff  | S-T Diff  | G-T Diff\n");
+        printf("----|-------------|-------------|-------------|-----------|-----------|-----------\n");
     } else {
-        printf("  ❌ FAIL: Results differ beyond tolerance (%.2e)\n", tolerance);
-        
-        // Print detailed comparison for debugging
-        printf("\nDetailed comparison:\n");
-        print_f32_sample("Standard", result_standard, 20);
-        print_f32_sample("Segmented", result_segmented, 20);
+        printf("Idx | Standard    | Segmented   | S-G Diff\n");
+        printf("----|-------------|-------------|-----------\n");
+    }
+
+    size_t show = std::min((size_t) 128, n_elems);
+    for (size_t i = 0; i < show; ++i) {
+        float s = standard_data[i];
+        float g = segmented_data[i];
+        if (torch_success) {
+            float t = torch_result_data[i];
+            printf("%3zu | %11.6f | %11.6f | %11.6f | %.6e | %.6e | %.6e\n", i, s, g, t, std::abs(s - g),
+                   std::abs(s - t), std::abs(g - t));
+        } else {
+            printf("%3zu | %11.6f | %11.6f | %.6e\n", i, s, g, std::abs(s - g));
+        }
+    }
+
+    const float tolerance = 1e-3f;
+    bool pass = max_std_seg < tolerance;
+    if (torch_success) {
+        pass = pass && max_std_torch < tolerance && max_seg_torch < tolerance;
     }
 
     // ============================================================================
@@ -371,19 +526,19 @@ int main() {
     printf("Final state tensor values:\n");
     print_f32_sample("Final state", state, 16);
 
-    float* state_data = (float*)state->data;
-    float min_m = INFINITY, max_m = -INFINITY;
-    float min_s = INFINITY, max_s = -INFINITY;
-    
+    float * state_data = (float *) state->data;
+    float min_m = INFINITY, max_m = -INFINITY;
+    float min_s = INFINITY, max_s = -INFINITY;
+
     for (int i = 0; i < n_heads * seq_len; i++) {
         float m_val = state_data[i * 2 + 0];
         float s_val = state_data[i * 2 + 1];
-        
+
         if (m_val != -INFINITY) {
            min_m = std::min(min_m, m_val);
            max_m = std::max(max_m, m_val);
        }
-        
+
        min_s = std::min(min_s, s_val);
        max_s = std::max(max_s, s_val);
    }
@@ -396,21 +551,24 @@ int main() {
     // Final Results
     // ============================================================================
     printf("\n=== Final Test Results ===\n");
-    
-    if (max_diff < tolerance) {
+
+    if (pass) {
         printf("🎉 ALL TESTS PASSED!\n");
         printf("✅ Segmented flash attention with state produces identical results\n");
-        printf("✅ State tensor correctly accumulates across segments\n");
-        printf("✅ Implementation is working correctly\n");
+        if (torch_success) {
+            printf("✅ PyTorch results match GGML outputs\n");
+        }
     } else {
         printf("❌ TESTS FAILED!\n");
-        printf("❌ Results differ beyond acceptable tolerance\n");
-        printf("❌ Implementation needs debugging\n");
     }
 
-    printf("\nMax difference: %.2e (tolerance: %.2e)\n", max_diff, tolerance);
+    printf("\nMax difference S-G: %.2e (tolerance: %.2e)\n", max_std_seg, tolerance);
+    if (torch_success) {
+        printf("Max difference S-T: %.2e\n", max_std_torch);
+        printf("Max difference G-T: %.2e\n", max_seg_torch);
+    }
 
     // Cleanup
     ggml_free(ctx);
 
-    return (max_diff < tolerance) ? 0 : 1;
-}
\ No newline at end of file
+    return pass ? 0 : 1;
+}

From 342211644513e2f882e3d293f6025f6a4eaeab8e Mon Sep 17 00:00:00 2001
From: Zijie Tian <1049154785@qq.com>
Date: Tue, 24 Jun 2025 06:19:35 +0800
Subject: [PATCH 2/2] test(flash-attn): update head dimensions in test
 parameters

---
 tests/test-flash-attn-state.cpp | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/tests/test-flash-attn-state.cpp b/tests/test-flash-attn-state.cpp
index cee710f03dd17..b2910f69a237f 100644
--- a/tests/test-flash-attn-state.cpp
+++ b/tests/test-flash-attn-state.cpp
@@ -63,17 +63,6 @@ static torch::Tensor ggml_to_torch(ggml_tensor * tensor) {
     }
     return torch::from_blob(data.data(), sizes, torch::kFloat32).clone();
 }
-
-// Simple flash attention using PyTorch for verification
-static torch::Tensor torch_flash_attention(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor mask,
-                                           float scale) {
-    auto scores = torch::matmul(Q, K.transpose(-2, -1)) * scale;
-    if (mask.defined()) {
-        scores = scores + mask;
-    }
-    auto attn = torch::softmax(scores, -1);
-    return torch::matmul(attn, V);
-}
 #endif  // LLAMA_TORCH_AVAILABLE
 
 static void print_tensor_info(const char * name, ggml_tensor * tensor) {
@@ -140,8 +129,8 @@ int main() {
 
     // Test parameters
     const int head_dim = 32;
-    const int n_heads = 8;
-    const int n_kv_heads = 4;
+    const int n_heads = 32;
+    const int n_kv_heads = 8;
     const int seq_len = 2;
     const int kv_len = 4;  // Will be split into segments
     const int n_threads = 4;
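Note on the math this series exercises: the [M, S] pairs that reset_state_tensor() initializes to [-INFINITY, 0.0f] and that ggml_flash_attn_ext_with_state() carries across KV segments follow the standard online-softmax (flash-attention) recurrence, where M is the running maximum KQ logit and S the running sum of exp(logit - M). The sketch below is a minimal, self-contained C++ illustration of that merge rule under those assumptions; it is independent of ggml and libtorch, it is not the ggml kernel, and every name in it is illustrative rather than taken from this patch. It shows why folding KV segments one at a time into (M, S, O) reproduces a single softmax over the full KV range, which is exactly the standard-vs-segmented equality the test asserts.

// online_softmax_sketch.cpp (hypothetical file) - illustrates the [M, S]
// state merge that segmented flash attention relies on.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Fold one KV segment into the running state for a single query row.
// logits: scaled q.k scores for this segment; V: the segment's value rows.
// M: running max logit, S: running sum of exp(logit - M),
// O: unnormalized output accumulator (length head_dim).
static void attend_segment(const std::vector<float> & logits,
                           const std::vector<std::vector<float>> & V,
                           float & M, float & S, std::vector<float> & O) {
    float seg_max = -INFINITY;
    for (float l : logits) {
        seg_max = std::max(seg_max, l);
    }
    const float M_new = std::max(M, seg_max);
    // Rescale everything accumulated so far to the new running max.
    const float r = (M == -INFINITY) ? 0.0f : std::exp(M - M_new);
    S *= r;
    for (float & o : O) {
        o *= r;
    }
    // Accumulate this segment's contribution.
    for (size_t j = 0; j < logits.size(); ++j) {
        const float w = std::exp(logits[j] - M_new);
        S += w;
        for (size_t d = 0; d < O.size(); ++d) {
            O[d] += w * V[j][d];
        }
    }
    M = M_new;
}

int main() {
    // Two KV positions processed as two single-position segments must match
    // one softmax over both positions.
    const std::vector<float> logits = { 0.3f, -0.8f };
    const std::vector<std::vector<float>> V = { { 1.0f, 2.0f }, { -3.0f, 0.5f } };

    float M = -INFINITY, S = 0.0f;  // same reset values the test's state tensor uses
    std::vector<float> O(2, 0.0f);
    attend_segment({ logits[0] }, { V[0] }, M, S, O);
    attend_segment({ logits[1] }, { V[1] }, M, S, O);

    for (size_t d = 0; d < O.size(); ++d) {
        printf("out[%zu] = %.6f\n", d, O[d] / S);  // normalize once at the end
    }
    return 0;
}

Because O stays unnormalized and (M, S) are rescaled at each merge, the division by S happens only once, after the last segment; that mirrors how the test redirects each segment's output into result_segmented and compares against the single-shot reference only after the loop finishes.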