-
Notifications
You must be signed in to change notification settings - Fork 287
Add top_logprobs Support #1124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add top_logprobs Support #1124
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,8 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| MAX_TOP_K_LOGPROBS = 20 | ||
|
|
||
|
|
||
| class FinishStatus(ctypes.Structure): | ||
| _pack_ = 4 | ||
|
|
@@ -170,6 +172,7 @@ def init( | |
| self.input_len = len(prompt_ids) | ||
| self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024 # + 1024 for safe | ||
| self.create_logprobs_shm_array() | ||
| self.create_top_logprobs_shm_array() | ||
| self.create_prompt_ids_shm_array() | ||
| self.chunked_prefill_size = chunked_prefill_size | ||
| self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids | ||
|
|
@@ -218,13 +221,35 @@ def create_logprobs_shm_array(self): | |
| self.shm_logprobs.create_shm() | ||
| return | ||
|
|
||
| def create_top_logprobs_shm_array(self): | ||
| service_uni_name = get_unique_server_name() | ||
| name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}" | ||
| self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32) | ||
| self.shm_top_logprobs_ids.create_shm() | ||
|
|
||
| name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}" | ||
| self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32) | ||
| self.shm_top_logprobs_val.create_shm() | ||
| return | ||
|
|
||
| def link_logprobs_shm_array(self): | ||
| service_uni_name = get_unique_server_name() | ||
| name = f"{service_uni_name}_shm_logprobs_{self.index_in_shm_mem}" | ||
| self.shm_logprobs = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.float32) | ||
| self.shm_logprobs.link_shm() | ||
| return | ||
|
|
||
| def link_top_logprobs_shm_array(self): | ||
| service_uni_name = get_unique_server_name() | ||
| name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}" | ||
| self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32) | ||
| self.shm_top_logprobs_ids.link_shm() | ||
|
|
||
| name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}" | ||
| self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32) | ||
| self.shm_top_logprobs_val.link_shm() | ||
| return | ||
|
Comment on lines
+224
to
+251
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The methods |
||
|
|
||
| def get_prompt_ids(self): | ||
| return self.shm_prompt_ids.arr[: self.input_len].tolist() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation decodes token IDs one by one inside a loop. This can be inefficient, especially for a large number of top logprobs. You can improve performance by collecting all token IDs and using
`tokenizer.batch_decode` to process them in a single call.