Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/infiniop/ops/rope/bang/rope_bang.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ infiniStatus_t Descriptor::create(
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);

if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
Expand Down Expand Up @@ -62,7 +58,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
y, x, pos_ids, sin_table, cos_table,
dimx, dimy, table_dim,
info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_seqlen, info.x_stride_nhead);
info.x_stride_seqlen, info.x_stride_nhead,
info.algo);

cnrtQueueSync(queue);

Expand Down
91 changes: 64 additions & 27 deletions src/infiniop/ops/rope/bang/rope_bang_kernel.mlu
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "../../../devices/bang/common_bang.h"
#include "rope_bang.h"

__nram__ char nram_buffer[NRAM_MAX_SIZE];

Expand All @@ -11,29 +12,44 @@ __mlu_device__ void calculateRope(
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
int theta_index, int out_index, int in_index,
int chunk_size, int half_chunk_size, int data_segsize,
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride,
bool is_gpt_j_style) {

// Load sin/cos data
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);

// Load input data
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);

// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
if (is_gpt_j_style) {
// GPT-J: (x0, x1), (x2, x3), ...
// Split input into even and odd positions
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
} else {
// GPT-NeoX: (x0...xd/2-1), (xd/2...xd-1)
__memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
__memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
}

// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
// Compute rotations
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
__bang_add(input_1, x0sin, x1cos, half_chunk_size);

// Interleave results back into output buffer
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
if (is_gpt_j_style) {
// GPT-J
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
} else {
// GPT-NeoX
__memcpy(input_cache, input_0, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
__memcpy(input_cache + half_chunk_size, input_1, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
}

// Write back results
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
Expand All @@ -52,22 +68,42 @@ __mlu_global__ void ropeKernel(
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
ptrdiff_t x_stride_nhead,
infiniopRoPEAlgo_t algo) {

const bool is_gpt_j_style = (algo == INFINIOP_ROPE_ALGO_GPT_J);

// Calculate available NRAM space after alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9);
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));

// Key variables that determine execution path
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);

// Common stride configurations
const int data_segsize = sizeof(Tdata);
const int src_load_stride = 2 * sizeof(Tdata);
const int dst_load_stride = 1 * sizeof(Tdata);
const int src_write_stride = 1 * sizeof(Tdata);
const int dst_write_stride = 2 * sizeof(Tdata);
int half_chunk_size;
if (is_gpt_j_style) {
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
} else {
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
}

// Byte-granular segment size and strides for the NRAM<->NRAM __memcpy calls
// in calculateRope that split the loaded chunk into the two rotation operands
// and merge the rotated results back.
int data_segsize, src_load_stride, dst_load_stride, src_write_stride, dst_write_stride;

if (is_gpt_j_style) {
// GPT-J: rotation pairs are interleaved as (x0, x1), (x2, x3), ... — copy one
// element per segment, reading every 2nd element (source stride 2) and packing
// densely (destination stride 1); write-back swaps the strides to re-interleave.
data_segsize = sizeof(Tdata);
src_load_stride = 2 * sizeof(Tdata);
dst_load_stride = 1 * sizeof(Tdata);
src_write_stride = 1 * sizeof(Tdata);
dst_write_stride = 1 * sizeof(Tdata);
} else {
// GPT-NeoX: rotation pairs are (x_i, x_{i+d/2}), i.e. the two operand halves
// are already contiguous, so one plain half-chunk copy suffices; calculateRope's
// NeoX path uses the non-strided __memcpy form, leaving these stride values unused.
// NOTE(review): when table_dim > half_chunk_size the head is processed in several
// chunks — confirm the contiguous half-split in calculateRope still pairs
// x_i with x_{i+table_dim} correctly across chunk boundaries.
data_segsize = half_chunk_size * sizeof(Tdata);
src_load_stride = 1 * sizeof(Tdata);
dst_load_stride = 1 * sizeof(Tdata);
src_write_stride = 1 * sizeof(Tdata);
dst_write_stride = 1 * sizeof(Tdata);
}

// Task distribution
const int batch_volume = seqlen * nhead;
Expand Down Expand Up @@ -100,29 +136,29 @@ __mlu_global__ void ropeKernel(

// Main processing loop
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
// Calculate output and input indices
int seq_idx = i / nhead;
int head_idx = i % nhead;

// Output indices (y)
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;

// Input indices (x)
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;

// Get position index
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
int rot_offset = pos_idx * table_dim;

// Process in chunks that fit in NRAM
int processed = 0;
while (processed < table_dim) {
// Calculate current chunk size
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
int current_chunk_size = 2 * current_half_chunk;
int theta_offset = rot_offset + processed;
int dst_offset = out_offset + processed * 2;
int src_offset = in_offset + processed * 2;

int dst_offset, src_offset;
if (is_gpt_j_style) {
dst_offset = out_offset + processed * 2;
src_offset = in_offset + processed * 2;
} else {
dst_offset = out_offset + processed;
src_offset = in_offset + processed;
}

// Set up NRAM buffers for this chunk
char *chunk_base = aligned_nram;
Expand All @@ -143,7 +179,8 @@ __mlu_global__ void ropeKernel(
theta_offset, dst_offset, src_offset,
current_chunk_size, current_half_chunk,
data_segsize,
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride,
is_gpt_j_style);

processed += current_half_chunk;
}
Expand Down
4 changes: 2 additions & 2 deletions test/infiniop/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _torch_rope(sin, cos, t1, t2):

return t_out_1, t_out_2


dh = t.shape[-1]
dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even."
Expand All @@ -111,7 +110,7 @@ def _torch_rope(sin, cos, t1, t2):
ans[..., 0::2] = t_out_even.to(dt)
ans[..., 1::2] = t_out_odd.to(dt)
else:
half_dim = dh // 2
half_dim = dh // 2
t_first = t[..., :half_dim]
t_second = t[..., half_dim:]

Expand Down Expand Up @@ -232,6 +231,7 @@ def lib_rope():
sin_table.torch_tensor(),
cos_table.torch_tensor(),
device,
algo,
),
device,
NUM_PRERUN,
Expand Down
Loading