#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include "../speculate_msg.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// #define SAVE_WITH_OUTPUT_DEBUG

// MAX_BSZ, MAX_DRAFT_TOKENS and struct speculate_msgdata come from
// "../speculate_msg.h".
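// Illustrative sketch (not part of the op): the indexing used throughout this
// file implies the payload layout below. The struct mirrors how mtext is
// addressed here; treat it as an assumption about speculate_msg.h, not a copy
// of its actual text.
#if 0
struct speculate_msgdata {
  long mtype;
  // mtext[0]                : signed status flag (+msg_id = running, -msg_id = stop)
  // mtext[1]                : batch size (bsz)
  // mtext[2 .. 2 + MAX_BSZ) : per-query saved-token count (0 = skipped, 2 = two tokens)
  // mtext[2 + MAX_BSZ + i * MAX_DRAFT_TOKENS + k] : k-th saved token of query i
  int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];
};
#endif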

void MTPSaveFirstToken(const paddle::Tensor& x,
                       const paddle::Tensor& not_need_stop,
                       const paddle::Tensor& seq_lens_decoder,
                       const paddle::Tensor& prompt_lens,
                       const paddle::Tensor& step_idx,
                       int64_t rank_id,
                       int msg_queue_id,
                       bool save_each_rank,
                       bool skip_chunk_prefill) {
  if (!save_each_rank && rank_id > 0) {
    return;
  }
  int x_dim = x.shape()[1];
  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
  int64_t* x_data = x_cpu.data<int64_t>();

  auto seq_lens_decoder_cpu =
      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();

  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();

  auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
  int64_t* step_idx_data = step_idx_cpu.data<int64_t>();

  static struct speculate_msgdata msg_sed;

  if (const char* inference_msg_queue_id_env_p =
          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
    int inference_msg_queue_id_from_env =
        std::stoi(inference_msg_queue_id_env_str);
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
              << inference_msg_queue_id_from_env << std::endl;
#endif
    msg_queue_id = inference_msg_queue_id_from_env;
  }

  static key_t key = ftok("./", msg_queue_id);
  static int msgid = msgget(key, IPC_CREAT | 0666);

  msg_sed.mtype = 1;
  bool not_need_stop_data = not_need_stop.data<bool>()[0];
  int inference_msg_id_from_env = 1;
  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
    std::string inference_msg_id_env_str(inference_msg_id_env_p);
    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
    if (inference_msg_id_from_env == 2) {
      // 2 and -2 are reserved for the no-output indication.
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be 2, please use another number.");
    }
    if (inference_msg_id_from_env < 0) {
      throw std::runtime_error(
          " INFERENCE_MSG_ID cannot be negative, please use another "
          "number.");
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
              << std::endl;
#endif
  } else {
#ifdef SAVE_WITH_OUTPUT_DEBUG
    std::cout
        << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as default."
        << std::endl;
#endif
  }
#ifdef SAVE_WITH_OUTPUT_DEBUG
  std::cout << "save_output_key: " << key << std::endl;
  std::cout << "save msgid: " << msgid << std::endl;
#endif
  msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                        : -inference_msg_id_from_env;
  int bsz = x.shape()[0];
  msg_sed.mtext[1] = bsz;
  for (int i = 0; i < bsz; i++) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
    printf("bid: %d. 1: %d. 2: %d.\n",
           i,
           (int)x_data[i * x_dim],
           (int)x_data[i * x_dim + 1]);
#endif
    // Skip queries still in chunked prefill (decoded length has not yet
    // reached the prompt length) and queries that have not decoded a step.
    if ((skip_chunk_prefill &&
         seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
        step_idx_data[i] == 0) {
      msg_sed.mtext[i + 2] = 0;
#ifdef SAVE_WITH_OUTPUT_DEBUG
      printf("bid[%d] skip save mtp output\n", i);
#endif
      continue;
    } else if (step_idx_data[i] == 1) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
      printf("bid[%d] save mtp tokens\n", i);
#endif
      // First decode step: save this query's two first tokens.
      msg_sed.mtext[i + 2] = 2;
      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
          (int)x_data[i * x_dim];
      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
          (int)x_data[i * x_dim + 1];
    }

#ifdef SAVE_WITH_OUTPUT_DEBUG
    printf("mtext[%d]:%d. mtext[%d]:%d.\n",
           i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
           i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
#endif
  }

#ifdef SAVE_WITH_OUTPUT_DEBUG
  std::cout << "msg data: ";
  for (int i = 0; i < bsz; i++) {
    std::cout << " " << (int)x_data[i * x_dim] << " ";
    std::cout << " " << (int)x_data[i * x_dim + 1];
  }
  std::cout << std::endl;
#endif
  if ((msgsnd(msgid,
              &msg_sed,
              (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
              0)) == -1) {
    printf("full msg buffer\n");
  }
  return;
}
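// Illustrative sketch (not part of the op): how a consumer process on the
// other end of this SysV queue could receive one message. It assumes the same
// speculate_msg.h definitions, the same working directory for ftok, and
// sizeof(int) == 4, matching the sender's msgsnd size; the function and
// variable names are hypothetical.
#if 0
static void read_first_tokens(int msg_queue_id) {
  key_t key = ftok("./", msg_queue_id);  // must match the sender's ftok
  int msgid = msgget(key, IPC_CREAT | 0666);

  struct speculate_msgdata msg_rcv;
  // Payload size must match the sender's msgsnd call.
  msgrcv(msgid,
         &msg_rcv,
         (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
         0,
         0);

  int status = msg_rcv.mtext[0];  // +INFERENCE_MSG_ID = running, negative = stop
  int bsz = msg_rcv.mtext[1];
  for (int i = 0; i < bsz; i++) {
    int token_num = msg_rcv.mtext[i + 2];  // 0 = skipped, 2 = two tokens saved
    for (int k = 0; k < token_num; k++) {
      int token = msg_rcv.mtext[2 + MAX_BSZ + i * MAX_DRAFT_TOKENS + k];
      // ... hand the token to the serving layer ...
      (void)token;
    }
  }
  (void)status;
}
#endif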

void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
                             const paddle::Tensor& not_need_stop,
                             const paddle::Tensor& seq_lens_decoder,
                             const paddle::Tensor& prompt_lens,
                             const paddle::Tensor& step_idx,
                             int64_t rank_id,
                             bool save_each_rank,
                             bool skip_chunk_prefill) {
  MTPSaveFirstToken(x,
                    not_need_stop,
                    seq_lens_decoder,
                    prompt_lens,
                    step_idx,
                    rank_id,
                    1,  // fixed msg_queue_id for the static op
                    save_each_rank,
                    skip_chunk_prefill);
}

void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
                              const paddle::Tensor& not_need_stop,
                              const paddle::Tensor& seq_lens_decoder,
                              const paddle::Tensor& prompt_lens,
                              const paddle::Tensor& step_idx,
                              int64_t rank_id,
                              int msg_queue_id,
                              bool save_each_rank,
                              bool skip_chunk_prefill) {
  MTPSaveFirstToken(x,
                    not_need_stop,
                    seq_lens_decoder,
                    prompt_lens,
                    step_idx,
                    rank_id,
                    msg_queue_id,
                    save_each_rank,
                    skip_chunk_prefill);
}

PD_BUILD_STATIC_OP(mtp_save_first_token)
    .Inputs(
        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
    .Attrs({"rank_id: int64_t",
            "save_each_rank: bool",
            "skip_chunk_prefill: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));

PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
    .Inputs(
        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
    .Attrs({"rank_id: int64_t",
            "msg_queue_id: int",
            "save_each_rank: bool",
            "skip_chunk_prefill: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
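// Note on the two registrations above: both ops mutate "x" in place via
// SetInplaceMap ("x" -> "x_out") and differ only in how the SysV queue is
// selected. mtp_save_first_token hard-codes msg_queue_id = 1, while
// mtp_save_first_token_dynamic takes it as an attribute. In either case the
// INFERENCE_MSG_QUEUE_ID environment variable, when set, overrides the
// chosen queue id.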