From 6af2e427520a5919155bd44baebedb09cf406396 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 25 Sep 2025 11:26:33 +0800 Subject: [PATCH] issue/477 - Cambricon MLU NeoX Added NeoX support to Cambricon RoPE; Added a missing argument in the profiling script; --- src/infiniop/ops/rope/bang/rope_bang.mlu | 7 +- .../ops/rope/bang/rope_bang_kernel.mlu | 91 +++++++++++++------ test/infiniop/rope.py | 4 +- 3 files changed, 68 insertions(+), 34 deletions(-) diff --git a/src/infiniop/ops/rope/bang/rope_bang.mlu b/src/infiniop/ops/rope/bang/rope_bang.mlu index 789319f5d..b77e32d6c 100644 --- a/src/infiniop/ops/rope/bang/rope_bang.mlu +++ b/src/infiniop/ops/rope/bang/rope_bang.mlu @@ -21,10 +21,6 @@ infiniStatus_t Descriptor::create( auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo); CHECK_RESULT(info); - if (algo != INFINIOP_ROPE_ALGO_GPT_J) { - return INFINI_STATUS_NOT_IMPLEMENTED; - } - // Create descriptor *desc_ptr = new Descriptor( info.take(), @@ -62,7 +58,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, y, x, pos_ids, sin_table, cos_table, dimx, dimy, table_dim, info.y_stride_seqlen, info.y_stride_nhead, - info.x_stride_seqlen, info.x_stride_nhead); + info.x_stride_seqlen, info.x_stride_nhead, + info.algo); cnrtQueueSync(queue); diff --git a/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu b/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu index 960beb15f..fde035b4e 100644 --- a/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu +++ b/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu @@ -1,4 +1,5 @@ #include "../../../devices/bang/common_bang.h" +#include "rope_bang.h" __nram__ char nram_buffer[NRAM_MAX_SIZE]; @@ -11,7 +12,9 @@ __mlu_device__ void calculateRope( Tdata *input_0, Tdata *input_1, Tdata *input_cache, int theta_index, int out_index, int in_index, int chunk_size, int half_chunk_size, int data_segsize, - int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) { + int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride, + bool is_gpt_j_style) { + // Load sin/cos data __memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM); __memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM); @@ -19,11 +22,18 @@ __mlu_device__ void calculateRope( // Load input data __memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM); - // Split input into even and odd positions - __memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1); - __memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1); + if (is_gpt_j_style) { + // GPT-J: (x0, x1), (x2, x3), ... + // Split input into even and odd positions + __memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1); + __memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1); + } else { + // GPT-NeoX: (x0...xd/2-1), (xd/2...xd-1) + __memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata), NRAM2NRAM); + __memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata), NRAM2NRAM); + } - // Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos + // Compute rotations __bang_mul(x0cos, input_0, cos_cache, half_chunk_size); __bang_mul(x1sin, input_1, sin_cache, half_chunk_size); __bang_mul(x0sin, input_0, sin_cache, half_chunk_size); @@ -31,9 +41,15 @@ __mlu_device__ void calculateRope( __bang_sub(input_0, x0cos, x1sin, half_chunk_size); __bang_add(input_1, x0sin, x1cos, half_chunk_size); - // Interleave results back into output buffer - __memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1); - __memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1); + if (is_gpt_j_style) { + // GPT-J + __memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1); + __memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1); + } else { + // GPT-NeoX + __memcpy(input_cache, input_0, half_chunk_size * sizeof(Tdata), NRAM2NRAM); + __memcpy(input_cache + half_chunk_size, input_1, half_chunk_size * sizeof(Tdata), NRAM2NRAM); + } // Write back results __memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM); @@ -52,22 +68,42 @@ __mlu_global__ void ropeKernel( ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_nhead, ptrdiff_t x_stride_seqlen, - ptrdiff_t x_stride_nhead) { + ptrdiff_t x_stride_nhead, + infiniopRoPEAlgo_t algo) { + + const bool is_gpt_j_style = (algo == INFINIOP_ROPE_ALGO_GPT_J); // Calculate available NRAM space after alignment - const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment + const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata)); // Key variables that determine execution path const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2)); - const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim); - // Common stride configurations - const int data_segsize = sizeof(Tdata); - const int src_load_stride = 2 * sizeof(Tdata); - const int dst_load_stride = 1 * sizeof(Tdata); - const int src_write_stride = 1 * sizeof(Tdata); - const int dst_write_stride = 2 * sizeof(Tdata); + int half_chunk_size; + if (is_gpt_j_style) { + half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim); + } else { + half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim); + } + + int data_segsize, src_load_stride, dst_load_stride, src_write_stride, dst_write_stride; + + if (is_gpt_j_style) { + // GPT-J + data_segsize = sizeof(Tdata); + src_load_stride = 2 * sizeof(Tdata); + dst_load_stride = 1 * sizeof(Tdata); + src_write_stride = 1 * sizeof(Tdata); + dst_write_stride = 2 * sizeof(Tdata); + } else { + // GPT-NeoX + data_segsize = half_chunk_size * sizeof(Tdata); + src_load_stride = 1 * sizeof(Tdata); + dst_load_stride = 1 * sizeof(Tdata); + src_write_stride = 1 * sizeof(Tdata); + dst_write_stride = 1 * sizeof(Tdata); + } // Task distribution const int batch_volume = seqlen * nhead; @@ -100,29 +136,29 @@ __mlu_global__ void ropeKernel( // Main processing loop for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) { - // Calculate output and input indices int seq_idx = i / nhead; int head_idx = i % nhead; - // Output indices (y) int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead; - - // Input indices (x) int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead; - // Get position index Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx]; int rot_offset = pos_idx * table_dim; - // Process in chunks that fit in NRAM int processed = 0; while (processed < table_dim) { - // Calculate current chunk size int current_half_chunk = std::min(half_chunk_size, table_dim - processed); int current_chunk_size = 2 * current_half_chunk; int theta_offset = rot_offset + processed; - int dst_offset = out_offset + processed * 2; - int src_offset = in_offset + processed * 2; + + int dst_offset, src_offset; + if (is_gpt_j_style) { + dst_offset = out_offset + processed * 2; + src_offset = in_offset + processed * 2; + } else { + dst_offset = out_offset + processed; + src_offset = in_offset + processed; + } // Set up NRAM buffers for this chunk char *chunk_base = aligned_nram; @@ -143,7 +179,8 @@ __mlu_global__ void ropeKernel( theta_offset, dst_offset, src_offset, current_chunk_size, current_half_chunk, data_segsize, - src_load_stride, dst_load_stride, src_write_stride, dst_write_stride); + src_load_stride, dst_load_stride, src_write_stride, dst_write_stride, + is_gpt_j_style); processed += current_half_chunk; } diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py index 9fa46740c..040f386c7 100644 --- a/test/infiniop/rope.py +++ b/test/infiniop/rope.py @@ -97,7 +97,6 @@ def _torch_rope(sin, cos, t1, t2): return t_out_1, t_out_2 - dh = t.shape[-1] dt = t.dtype assert dh % 2 == 0, "Embedding dimension must be even." @@ -111,7 +110,7 @@ def _torch_rope(sin, cos, t1, t2): ans[..., 0::2] = t_out_even.to(dt) ans[..., 1::2] = t_out_odd.to(dt) else: - half_dim = dh // 2 + half_dim = dh // 2 t_first = t[..., :half_dim] t_second = t[..., half_dim:] @@ -232,6 +231,7 @@ def lib_rope(): sin_table.torch_tensor(), cos_table.torch_tensor(), device, + algo, ), device, NUM_PRERUN,