Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions mllm/backends/cpu/ops/CausalMaskOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten

if (!options_.sliding_window) {
// Standard causal mask
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
const __m256 mask_val = _mm256_set1_ps(-1e10f);
for (size_t r = 0; r < S; ++r) {
const size_t row_offset = r * D;
const size_t copy_count = D - S + r + 1;
const size_t fill_count = std::max(D - copy_count, (size_t)0);

memcpy(o_ptr + r * D, i_ptr + r * D, copy_count * sizeof(float));
memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));

float* fill_start = o_ptr + row_offset + copy_count;
size_t avx_iters = fill_count / 8;
Expand All @@ -81,6 +82,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
for (size_t i = 0; i < neon_iters; ++i) { vst1q_f32(fill_start + i * 4, mask_val); }
for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 4 + i] = -1e10f; }
}
#else
for (size_t r = 0; r < S; ++r) {
const size_t row_offset = r * D;
const size_t copy_count = D - S + r + 1;
const size_t fill_count = std::max(D - copy_count, (size_t)0);

memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(float));

float* fill_start = o_ptr + row_offset + copy_count;
for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -1e10f; }
}
#endif
} else {
// Sliding window causal mask
Expand All @@ -98,7 +110,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
const size_t suffix_fill_start_idx = s + 1;
const size_t suffix_fill_count = S - suffix_fill_start_idx;

#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__)
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__)
const __m256 mask_val = _mm256_set1_ps(-1e10f);
// Fill prefix
float* prefix_fill_start = o_ptr + row_offset;
Expand All @@ -118,6 +130,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
for (size_t i = 0; i < suffix_fill_count / 4; ++i) vst1q_f32(suffix_fill_start + i * 4, mask_val);
for (size_t i = (suffix_fill_count / 4) * 4; i < suffix_fill_count; ++i) suffix_fill_start[i] = -1e10f;
#else
float* prefix_fill_start = o_ptr + row_offset;
for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -1e10f; }
float* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -1e10f; }
#endif
}
}
Expand All @@ -143,7 +160,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten

if (!options_.sliding_window) {
// Standard causal mask
#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
const __m256 mask_ps = _mm256_set1_ps(-65500.f);
const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);
for (size_t s = 0; s < S; ++s) {
Expand Down Expand Up @@ -178,6 +195,17 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
for (size_t i = 0; i < neon_iters; ++i) { vst1q_f16(fill_start + i * 8, mask_val); }
for (size_t i = 0; i < remainder; ++i) { fill_start[neon_iters * 8 + i] = -65500.f; }
}
#else
for (size_t s = 0; s < S; ++s) {
const size_t row_offset = s * S;
const size_t copy_count = s + 1;
const size_t fill_count = S - copy_count;

if (copy_count > 0) { memcpy(o_ptr + row_offset, i_ptr + row_offset, copy_count * sizeof(mllm_fp16_t)); }

mllm_fp16_t* fill_start = o_ptr + row_offset + copy_count;
for (size_t i = 0; i < fill_count; ++i) { fill_start[i] = -65500.f; }
}
#endif
} else {
// Sliding window causal mask
Expand All @@ -196,7 +224,7 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
const size_t suffix_fill_start_idx = s + 1;
const size_t suffix_fill_count = S - suffix_fill_start_idx;

#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) && defined(__AVX2__) && defined(__F16C__)
#if (defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)) && defined(__AVX2__) && defined(__F16C__)
const __m256 mask_ps = _mm256_set1_ps(-65500.f);
const __m128i mask_val = _mm256_cvtps_ph(mask_ps, _MM_FROUND_TO_NEAREST_INT);

Expand All @@ -222,6 +250,11 @@ void CPUCausalMaskOp::forward(const std::vector<Tensor>& inputs, std::vector<Ten
mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
for (size_t i = 0; i < suffix_fill_count / 8; ++i) vst1q_f16(suffix_fill_start + i * 8, mask_val);
for (size_t i = (suffix_fill_count / 8) * 8; i < suffix_fill_count; ++i) suffix_fill_start[i] = -65500.f;
#else
mllm_fp16_t* prefix_fill_start = o_ptr + row_offset;
for (size_t i = 0; i < prefix_fill_count; ++i) { prefix_fill_start[i] = -65500.f; }
mllm_fp16_t* suffix_fill_start = o_ptr + row_offset + suffix_fill_start_idx;
for (size_t i = 0; i < suffix_fill_count; ++i) { suffix_fill_start[i] = -65500.f; }
#endif
}
}
Expand Down
8 changes: 4 additions & 4 deletions mllm/engine/HpcThreadPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,21 @@ void HpcThreadPool::splitTask(HpcThreadPoolTask&& task, int task_slot_idx) {
// 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
if (tiles_num > thread_cnt_) {
tasks_[task_slot_idx].first = {
.start = 0,
.end = thread_cnt_,
.step = 1,
.func =
[tiles_num, &task, &true_idx, this](int thread_idx) {
for (int v = thread_idx; v < tiles_num; v += thread_cnt_) { task.func(true_idx[v]); }
},
.start = 0,
.end = thread_cnt_,
.step = 1,
};
tiles_num = thread_cnt_;
} else {
tasks_[task_slot_idx].first = {
.func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); },
.start = 0,
.end = tiles_num,
.step = 1,
.func = [tiles_num, &task, &true_idx, this](int thread_idx) { task.func(true_idx[thread_idx]); },
};
}
{
Expand Down
2 changes: 1 addition & 1 deletion mllm/models/minicpm_o2_6/streaming_generation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class StreamingGenerator {
config_(config) {
// Configure chunk generation
models::ChunkGenerationConfig chunk_config{
.chunk_size = 5, .max_new_tokens = 10, .do_sample = false, .save_first_chunk_hidden_states = true};
.max_new_tokens = 10, .chunk_size = 5, .do_sample = false, .save_first_chunk_hidden_states = true};

// Add EOS tokens for MiniCPMO
auto eos_ids = tokenizer_.convert2Ids({L"<|im_end|>"});
Expand Down
66 changes: 66 additions & 0 deletions tests/cpu/CausalMaskOpTest.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#pragma once

#include <algorithm>

#include "KernelTestHelper.hpp"
#include "mllm/mllm.hpp"
#include "mllm/nn/layers/CausalMask.hpp"

class CausalMaskOpTest : public KernelTest {
public:
void SetUp() override {
KernelTest::SetUp();
mask_.to(mllm::kCPU);
}

mllm::test::AllCloseResult runScenario(int B, int H, int S, int D) {
using namespace mllm; // NOLINT
const int64_t total = static_cast<int64_t>(B) * H * S * D;
auto input = Tensor::arange(0, static_cast<float>(total), 1, kFloat32, kCPU).view({B, H, S, D});
auto output = mask_(input);
auto expected = buildExpectedTensor(input);
auto result = test::allClose(expected, output);
if (!result) {
mllm::print(result);
mllm::print(expected);
mllm::print(output);
}
return result;
}

private:
static mllm::Tensor buildExpectedTensor(const mllm::Tensor& input) {
using namespace mllm; // NOLINT
auto shape = input.shape();
const int B = shape[0];
const int H = shape[1];
const int S = shape[2];
const int D = shape[3];
auto expected = Tensor::zeros(shape, kFloat32, kCPU);

const float* in_ptr = input.ptr<float>();
float* exp_ptr = expected.ptr<float>();
const int context_offset = std::max(0, D - S);
const float mask_value = -1e10f;

for (int b = 0; b < B; ++b) {
for (int h = 0; h < H; ++h) {
for (int s = 0; s < S; ++s) {
const int allowed = std::min(D, context_offset + s + 1);
for (int d = 0; d < D; ++d) {
const int64_t idx = (((static_cast<int64_t>(b) * H) + h) * S + s) * D + d;
if (d < allowed) {
exp_ptr[idx] = in_ptr[idx];
} else {
exp_ptr[idx] = mask_value;
}
}
}
}
}
return expected;
}

mllm::nn::CausalMask mask_;
};

19 changes: 19 additions & 0 deletions tests/cpu/KernelTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,25 @@ TEST_F(ElementwiseKernelTest, DivScalarInt32) {
true);
}

//===----------------------------------------------------------------------===//
// CausalMaskOp
//===----------------------------------------------------------------------===//
#include "CausalMaskOpTest.hpp"
TEST_F(CausalMaskOpTest, PrefillScenario) {
auto result = runScenario(1, 1, 4, 4);
EXPECT_TRUE(result.is_close);
}

TEST_F(CausalMaskOpTest, DecodeScenario) {
auto result = runScenario(1, 1, 1, 6);
EXPECT_TRUE(result.is_close);
}

TEST_F(CausalMaskOpTest, AppendScenario) {
auto result = runScenario(2, 3, 3, 7);
EXPECT_TRUE(result.is_close);
}

//===----------------------------------------------------------------------===//
// GELU
//===----------------------------------------------------------------------===//
Expand Down