diff --git a/mllm/backends/cpu/ops/CausalMaskOp.cpp b/mllm/backends/cpu/ops/CausalMaskOp.cpp index 074ca019..56fc0b7f 100644 --- a/mllm/backends/cpu/ops/CausalMaskOp.cpp +++ b/mllm/backends/cpu/ops/CausalMaskOp.cpp @@ -50,13 +50,14 @@ void CPUCausalMaskOp::forward(const std::vector& inputs, std::vector& inputs, std::vector& inputs, std::vector& inputs, std::vector& inputs, std::vector& inputs, std::vector 0) { memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(mllm_fp16_t)); } + + mllm_fp16_t* fill_start = o_ptr + row_offset + copy_count; + for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -65500.f; } + } #endif } else { // Sliding window causal mask @@ -196,7 +224,7 @@ void CPUCausalMaskOp::forward(const std::vector& inputs, std::vector& inputs, std::vector thread_cnt_) { tasks_[task_slot_idx].first = { - .start = 0, - .end = thread_cnt_, - .step = 1, .func = [tiles_num, &task, &true_idx, this](int thread_idx) { for (int v = thread_idx; v < tiles_num; v += thread_cnt_) { task.func(true_idx[v]); } }, + .start = 0, + .end = thread_cnt_, + .step = 1, }; tiles_num = thread_cnt_; } else { tasks_[task_slot_idx].first = { + .func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); }, .start = 0, .end = tiles_num, .step = 1, - .func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); }, }; } { diff --git a/mllm/models/minicpm_o2_6/streaming_generation.hpp b/mllm/models/minicpm_o2_6/streaming_generation.hpp index 4de6ecc5..eb382e72 100644 --- a/mllm/models/minicpm_o2_6/streaming_generation.hpp +++ b/mllm/models/minicpm_o2_6/streaming_generation.hpp @@ -96,7 +96,7 @@ class StreamingGenerator { config_(config) { // Configure chunk generation models::ChunkGenerationConfig chunk_config{ - .chunk_size = 5, .max_new_tokens = 10, .do_sample = false, .save_first_chunk_hidden_states = true}; + .max_new_tokens = 10, .chunk_size = 5, .do_sample = false, .save_first_chunk_hidden_states = true}; // Add EOS tokens for MiniCPMO auto eos_ids = tokenizer_.convert2Ids({L"<|im_end|>"}); diff --git a/tests/cpu/CausalMaskOpTest.hpp b/tests/cpu/CausalMaskOpTest.hpp new file mode 100644 index 00000000..9969d36c --- /dev/null +++ b/tests/cpu/CausalMaskOpTest.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include + +#include "KernelTestHelper.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/layers/CausalMask.hpp" + +class CausalMaskOpTest : public KernelTest { + public: + void SetUp() override { + KernelTest::SetUp(); + mask_.to(mllm::kCPU); + } + + mllm::test::AllCloseResult runScenario(int B, int H, int S, int D) { + using namespace mllm; // NOLINT + const int64_t total = static_cast(B) * H * S * D; + auto input = Tensor::arange(0, static_cast(total), 1, kFloat32, kCPU).view({B, H, S, D}); + auto output = mask_(input); + auto expected = buildExpectedTensor(input); + auto result = test::allClose(expected, output); + if (!result) { + mllm::print(result); + mllm::print(expected); + mllm::print(output); + } + return result; + } + + private: + static mllm::Tensor buildExpectedTensor(const mllm::Tensor& input) { + using namespace mllm; // NOLINT + auto shape = input.shape(); + const int B = shape[0]; + const int H = shape[1]; + const int S = shape[2]; + const int D = shape[3]; + auto expected = Tensor::zeros(shape, kFloat32, kCPU); + + const float* in_ptr = input.ptr(); + float* exp_ptr = expected.ptr(); + const int context_offset = std::max(0, D - S); + const float mask_value = -1e10f; + + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + const int allowed = std::min(D, context_offset + s + 1); + for (int d = 0; d < D; ++d) { + const int64_t idx = (((static_cast(b) * H) + h) * S + s) * D + d; + if (d < allowed) { + exp_ptr[idx] = in_ptr[idx]; + } else { + exp_ptr[idx] = mask_value; + } + } + } + } + } + return expected; + } + + mllm::nn::CausalMask mask_; +}; + diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index fb595c79..9f8d613e 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -533,6 +533,25 @@ TEST_F(ElementwiseKernelTest, DivScalarInt32) { true); } +//===----------------------------------------------------------------------===// +// CausalMaskOp +//===----------------------------------------------------------------------===// +#include "CausalMaskOpTest.hpp" +TEST_F(CausalMaskOpTest, PrefillScenario) { + auto result = runScenario(1, 1, 4, 4); + EXPECT_TRUE(result.is_close); +} + +TEST_F(CausalMaskOpTest, DecodeScenario) { + auto result = runScenario(1, 1, 1, 6); + EXPECT_TRUE(result.is_close); +} + +TEST_F(CausalMaskOpTest, AppendScenario) { + auto result = runScenario(2, 3, 3, 7); + EXPECT_TRUE(result.is_close); +} + //===----------------------------------------------------------------------===// // GELU //===----------------------------------------------------------------------===//