Add top_logprobs Support #1124
Conversation
Code Review
This PR successfully implements the top_logprobs feature. The changes are well structured and span the memory, GPU, transport, and API layers. The implementation correctly uses shared memory for inter-process communication and handles data transfer from GPU to CPU. My review includes a few suggestions to reduce code duplication and improve performance in the API formatting layer, and it points out a minor inconsistency in return-value ordering. Overall, this is a solid implementation.
formatted_top_logprobs = {}
for item in top_logprobs:
    for t_id, t_prob in item.items():
        t_text = tokenizer.decode([t_id], skip_special_tokens=False)
        formatted_top_logprobs[t_text] = t_prob
The current implementation decodes token IDs one by one inside a loop. This can be inefficient, especially for a large number of top logprobs. You can improve performance by collecting all token IDs and using tokenizer.batch_decode to process them in a single call.
Suggested change:

-formatted_top_logprobs = {}
-for item in top_logprobs:
-    for t_id, t_prob in item.items():
-        t_text = tokenizer.decode([t_id], skip_special_tokens=False)
-        formatted_top_logprobs[t_text] = t_prob
+token_ids = [next(iter(item)) for item in top_logprobs]
+log_probs = [next(iter(item.values())) for item in top_logprobs]
+token_texts = tokenizer.batch_decode(
+    [[tid] for tid in token_ids], skip_special_tokens=False
+)
+formatted_top_logprobs = dict(zip(token_texts, log_probs))
def create_top_logprobs_shm_array(self):
    service_uni_name = get_unique_server_name()
    name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
    self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
    self.shm_top_logprobs_ids.create_shm()

    name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
    self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
    self.shm_top_logprobs_val.create_shm()
    return

def link_logprobs_shm_array(self):
    service_uni_name = get_unique_server_name()
    name = f"{service_uni_name}_shm_logprobs_{self.index_in_shm_mem}"
    self.shm_logprobs = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.float32)
    self.shm_logprobs.link_shm()
    return

def link_top_logprobs_shm_array(self):
    service_uni_name = get_unique_server_name()
    name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
    self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
    self.shm_top_logprobs_ids.link_shm()

    name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
    self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
    self.shm_top_logprobs_val.link_shm()
    return
The methods create_top_logprobs_shm_array and link_top_logprobs_shm_array contain a lot of duplicated code. The only difference is calling create_shm() vs link_shm(). This could be refactored into a helper method to reduce redundancy and improve maintainability. For example, a private helper method could handle the initialization of the ShmArray objects, while the public methods would just call the appropriate create_shm() or link_shm() on them.
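A minimal sketch of that refactor, assuming the class attributes and ShmArray API shown in the snippet above; the helper name `_init_top_logprobs_shm_arrays` and the `attach_existing` flag are hypothetical:

```python
def _init_top_logprobs_shm_arrays(self, attach_existing: bool):
    # Hypothetical private helper: build both ShmArray objects, then either
    # create new shared memory or attach to existing segments.
    service_uni_name = get_unique_server_name()
    shape = (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS)

    name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
    self.shm_top_logprobs_ids = ShmArray(name_ids, shape, dtype=np.int32)

    name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
    self.shm_top_logprobs_val = ShmArray(name_val, shape, dtype=np.float32)

    for arr in (self.shm_top_logprobs_ids, self.shm_top_logprobs_val):
        if attach_existing:
            arr.link_shm()
        else:
            arr.create_shm()

def create_top_logprobs_shm_array(self):
    self._init_top_logprobs_shm_arrays(attach_existing=False)

def link_top_logprobs_shm_array(self):
    self._init_top_logprobs_shm_arrays(attach_existing=True)
```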
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index))
return next_token_ids.view(-1), next_token_logprobs.view(-1)
top_k_logprobs_val, top_k_logprobs_idx = _get_top_logprobs(probs, k=20)
To improve clarity and consistency with the proposed change in _get_top_logprobs, the variable names here should be swapped to reflect the idx, val order. This makes the data flow from the helper function to the return statement of this function more intuitive.
Suggested change:

-top_k_logprobs_val, top_k_logprobs_idx = _get_top_logprobs(probs, k=20)
+top_k_logprobs_idx, top_k_logprobs_val = _get_top_logprobs(probs, k=20)
def _get_top_logprobs(probs: torch.Tensor, k: int = 20):
    top_k_logprobs_val, top_k_logprobs_idx = torch.topk(probs, k=k, dim=-1)
    top_k_logprobs_val = torch.log(top_k_logprobs_val)
    return top_k_logprobs_val, top_k_logprobs_idx
For consistency with how the values are used in the sample function, it would be clearer if this function returned the indices first, then the values (idx, val). The sample function returns (..., idx, val), so aligning this helper's return signature would improve readability and reduce confusion.
Suggested change:

-    return top_k_logprobs_val, top_k_logprobs_idx
+    return top_k_logprobs_idx, top_k_logprobs_val
Thank you for your contribution. We will review and test it promptly.
Add top_logprobs support
Came across this repo while browsing GitHub. I'm an intern on AWS SageMaker JumpStart this fall and want to get into the lower-level side of things!
What This PR Does
Implements the top_logprobs feature that was marked as TODO in the codebase. Now when the model generates text, instead of just returning the winning token, it returns the top 20 most likely candidates at each position along with their log probabilities.
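For illustration only, here is what requesting the feature could look like through an OpenAI-compatible endpoint; the URL, model name, and payload fields follow the OpenAI chat-completions convention and are assumptions, not code from this PR:

```python
import requests

# Hypothetical request against a locally running server.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 8,
        "logprobs": True,
        "top_logprobs": 5,  # ask for 5 of the 20 computed candidates
    },
)
# Each generated token is expected to carry its own logprob plus the
# requested number of alternative tokens with their logprobs.
print(resp.json())
```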
What Changed
Memory Layer (req.py)
- Added MAX_TOP_K_LOGPROBS = 20 (matches OpenAI's API limit)
- Added new shared-memory arrays (shm_top_logprobs_ids, shm_top_logprobs_val) to store the top-20 candidates
- Arrays are shaped [sequence_length, 20] for both IDs and values (see the sketch below)
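A small self-contained sketch of that layout, using plain NumPy arrays in place of the real ShmArray-backed buffers; names and values are illustrative only:

```python
import numpy as np

MAX_TOP_K_LOGPROBS = 20  # matches OpenAI's API limit

# Toy stand-ins for the shared-memory buffers: one int32 array for candidate
# token IDs and one float32 array for their log-probabilities, both shaped
# [sequence_length, MAX_TOP_K_LOGPROBS].
seq_len = 4
top_logprobs_ids = np.zeros((seq_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
top_logprobs_val = np.zeros((seq_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)

# For each generated position, the top-20 candidate IDs and log-probs
# (after the GPU-to-CPU copy) are written into that position's row.
pos = 0
top_logprobs_ids[pos] = np.arange(MAX_TOP_K_LOGPROBS, dtype=np.int32)  # dummy IDs
top_logprobs_val[pos] = -np.linspace(0.1, 2.0, MAX_TOP_K_LOGPROBS)     # dummy log-probs
```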
GPU/Sampling Layer (generic_post_process.py)
- Computes the top-20 candidates with torch.topk() on the original probability distribution

Transport Layer (base_backend.py, infer_batch.py)
API Layer (manager.py, api_openai.py)
Backend Integration (dp_backend/impl.py)

Why Max 20?
Hardcoded to match OpenAI's API specification:
Users can request fewer (e.g., top_logprobs=5), but we always allocate and compute 20 for simplicity. The API layer filters down to the requested amount.
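A minimal sketch of that filtering step; the helper name and signature are hypothetical, not the PR's actual code:

```python
def filter_top_logprobs(ids_row, vals_row, requested_k):
    # Hypothetical helper: the backend always fills 20 candidates per position,
    # ordered best-first by torch.topk, so the API layer keeps only the first k.
    k = min(requested_k, len(ids_row))
    return {int(t_id): float(t_val) for t_id, t_val in zip(ids_row[:k], vals_row[:k])}

# e.g. a request with top_logprobs=5 returns only 5 of the 20 stored candidates
```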
Why Triton Backend Only?
Currently implemented for the Triton sampling backend because it gives us access to probability tensors in Python. The sglang_kernel backend uses fused CUDA kernels where all sampling happens in one black-box GPU call, making it harder to extract intermediate probabilities without modifying the kernel itself.

Testing
Manually tested the code changes on macOS (without GPU)