-
Notifications
You must be signed in to change notification settings - Fork 0
Add PyTorch comparison for flash-attn state test #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PyTorch comparison for flash-attn state test #11
Conversation
Zijie-Tian
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR extends test-flash-attn-state.cpp to include an optional PyTorch-based verification alongside the existing segmented and standard flash-attention implementations, and enriches the result comparison with an element-wise table.
- Added PyTorch headers, tensor conversion, and verification logic under
LLAMA_TORCH_AVAILABLE - Replaced the single diff-based comparison with manual loops computing max differences across standard, segmented, and PyTorch outputs
- Introduced a detailed element-wise comparison table for the first 128 elements and unified print formatting
Comments suppressed due to low confidence (1)
tests/test-flash-attn-state.cpp:464
- [nitpick] The comment label still reads 'Test 3' for the comparison section, which now follows two other 'Test 3' sections; consider renumbering it to 'Test 4' for clarity.
// Test 3: Compare Results
* oai moe * compat with new checkpoint * add attn sink impl * add rope scaling yarn * logits match with latest transformers code * wip chat template * rm trailing space * use ggml_scale_bias * rm redundant is_swa_all * convert interleaved gate_up * graph : fix activation function to match reference (#7) * vocab : handle o200k_harmony special tokens * ggml : add attention sinks support (#1) * llama : add attn sinks * ggml : add attn sinks * cuda : add attn sinks * vulkan : add support for sinks in softmax remove unnecessary return * ggml : add fused swiglu_oai op (#11) * ggml : add fused swiglu_oai op * Update ggml/src/ggml-cpu/ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update CUDA impl * cont : metal impl * add vulkan impl * test-backend-ops : more test cases, clean up * llama : remove unfused impl * remove extra lines --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> * repack mxfp4 upon conversion * clean up a bit * enable thinking * add quick hack to render only some special tokens * fix bf16 conversion * remove vocab hack * webui ok * support chat parsing for gpt-oss * fix webui * direct mapping mxfp4, FINALLY * force using mxfp4 * properly use lazy tensor * ggml : add mxfp4 ggml : use e8m0 conversion instead of powf Co-authored-by: Diego Devesa <slarengh@gmail.com> change kvalues_mxfp4 table to match e2m1 (#6) metal : remove quantization for now (not used) cuda : fix disabled CUDA graphs due to ffn moe bias vulkan : add support for mxfp4 cont : add cm2 dequant * ggml : add ggml_add_id (#13) * ggml : add ggml_add_id * add cuda impl * llama : add weight support check for add_id * perf opt * add vulkan impl * rename cuda files * add metal impl * allow in-place ggml_add_id * llama : keep biases on CPU with --cpu-moe * llama : fix compile error ggml-ci * cuda : add fallback for __nv_cvt_e8m0_to_bf16raw ggml-ci * cleanup ggml-ci * sycl : fix supports_op for MXFP4 ggml-ci * fix Unknown reasoning format * ggml-cpu : fix AVX build ggml-ci * fix hip build ggml-ci * cuda : add mxfp4 dequantization support for cuBLAS ggml-ci * ggml-cpu : fix mxfp4 fallback definitions for some architectures ggml-ci * cuda : fix version required for __nv_cvt_e8m0_to_bf16raw --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: slaren <slarengh@gmail.com>
Summary
test-flash-attn-state.cppwith optional PyTorch verificationTesting
cmake -G Ninja -D GGML_GRAPH_PROFILER=ON -D GGML_CUDA=OFF -D GGML_TMAC=OFF -D LLAMA_TORCH=ON -B build-x86_64cmake --build build-x86_64 --config Release -j12./build-x86_64/bin/test-flash-attn-statehttps://chatgpt.com/codex/tasks/task_e_6859cc8cc3b08332ac84da4077269746