
Conversation


@ahmadbasyouni10 ahmadbasyouni10 commented Nov 24, 2025

Add top_logprobs support

I came across this repo while browsing GitHub. I'm an intern on AWS SageMaker JumpStart this fall and want to get into the lower-level side of things!

What This PR Does

Implements the top_logprobs feature that was marked as TODO in the codebase. Now, when the model generates text, it returns not just the sampled token but also the top 20 most likely candidates at each position, along with their log probabilities.

What Changed

Memory Layer (req.py)

  • Added MAX_TOP_K_LOGPROBS = 20 (matches OpenAI's API limit)
  • Created shared memory structures (shm_top_logprobs_ids, shm_top_logprobs_val) to store top-20 candidates
  • Used C structs to allocate 2D arrays of shape [sequence_length, 20] for both IDs and values (the memory layout is sketched after this list)
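
A minimal, self-contained sketch of that [sequence_length, 20] layout, using multiprocessing.shared_memory directly rather than the repo's ShmArray wrapper (all names and sizes below are illustrative, not the PR's exact code):

    import numpy as np
    from multiprocessing import shared_memory

    MAX_TOP_K_LOGPROBS = 20
    seq_len = 512  # assumed allocation length, just for the example

    shm = shared_memory.SharedMemory(
        create=True,
        size=seq_len * MAX_TOP_K_LOGPROBS * np.dtype(np.int32).itemsize,
    )
    # Row i holds the 20 candidate token ids for generation position i.
    top_ids = np.ndarray((seq_len, MAX_TOP_K_LOGPROBS), dtype=np.int32, buffer=shm.buf)
    top_ids[0, :] = -1  # e.g. initialize position 0's candidates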

GPU/Sampling Layer (generic_post_process.py)

  • Added _get_top_logprobs() helper function
  • Modified sample() to compute the top 20 with torch.topk() on the original probability distribution (sketched after this list)
  • Design decision: Created a new function instead of refactoring _top_p_top_k(), because that function destructively modifies the probabilities (it sets filtered tokens to 0.0) for sampling, while accurate reporting needs the raw distribution
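
A minimal sketch of that design decision (tensor shapes and the filtering threshold below are illustrative, not the repo's actual sampling code): the top-20 logprobs are taken from the raw distribution, while sampling continues to use the filtered copy.

    import torch

    probs = torch.softmax(torch.randn(2, 32000), dim=-1)  # raw per-token distribution

    # Report these: top-20 taken *before* any filtering, so the values are exact.
    top_vals, top_ids = torch.topk(probs, k=20, dim=-1)
    top_logprobs = torch.log(top_vals)

    # Sample from a separately filtered copy (stand-in for _top_p_top_k's zeroing).
    filtered = probs.clone()
    filtered[filtered < 1e-4] = 0.0
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(filtered, num_samples=1)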

Transport Layer (base_backend.py, infer_batch.py)

  • Added GPU→CPU data transfer for top-k arrays
  • Modified _async_copy_next_token_infos_to_pin_mem() to handle the new data
  • Updated set_next_gen_token_id() to write the top-20 data to shared memory (a sketch of the transfer follows this list)
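
A minimal sketch of the GPU→CPU copy into pinned host memory (the function name, stream handling, and buffer allocation here are illustrative assumptions, not the PR's exact code):

    import torch

    def copy_topk_to_pinned(top_ids_gpu: torch.Tensor, top_vals_gpu: torch.Tensor,
                            stream: torch.cuda.Stream):
        # Pinned (page-locked) host buffers allow the copy to be asynchronous.
        pin_ids = torch.empty(top_ids_gpu.shape, dtype=top_ids_gpu.dtype,
                              device="cpu", pin_memory=True)
        pin_vals = torch.empty(top_vals_gpu.shape, dtype=top_vals_gpu.dtype,
                               device="cpu", pin_memory=True)
        with torch.cuda.stream(stream):
            pin_ids.copy_(top_ids_gpu, non_blocking=True)
            pin_vals.copy_(top_vals_gpu, non_blocking=True)
        # Caller must synchronize the stream before reading the pinned buffers.
        return pin_ids, pin_vals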

API Layer (manager.py, api_openai.py)

  • HttpServer reads the top-k data from shared memory
  • Formats token IDs to text using the tokenizer
  • Returns the result in an OpenAI-compatible format (an example of the response shape follows this list)
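
For reference, the rough shape of the OpenAI-compatible logprobs field for one generated token (the tokens and numbers below are made up for illustration):

    choice_logprobs = {
        "content": [
            {
                "token": " Hello",
                "logprob": -0.12,
                "top_logprobs": [
                    {"token": " Hello", "logprob": -0.12},
                    {"token": " Hi", "logprob": -2.31},
                    # ... up to the requested number of entries (at most 20)
                ],
            }
        ]
    }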

Backend Integration (dp_backend/impl.py)

  • Updated all data parallel code paths to thread the new parameters through

Why Max 20?

Hardcoded to match OpenAI's API specification:

"An integer between 0 and 20 specifying the number of most likely tokens to return..."

Users can request fewer (e.g., top_logprobs=5), but we always allocate and calculate 20 for simplicity. The API layer trims the result to the requested amount (sketched below).
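
A minimal sketch of that trimming step (the function and argument names are illustrative):

    MAX_TOP_K_LOGPROBS = 20

    def trim_top_logprobs(ids_row, vals_row, requested: int):
        # Always computed with 20 candidates; return only what the user asked for.
        requested = max(0, min(requested, MAX_TOP_K_LOGPROBS))
        return ids_row[:requested], vals_row[:requested]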

Why Triton Backend Only?

Currently implemented for the Triton sampling backend because it gives us access to probability tensors in Python. The sglang_kernel backend uses fused CUDA kernels where all sampling happens in one black box GPU call, making it harder to extract intermediate probabilities without modifying the kernel itself.

Testing

Manually tested the code changes on macOS (without GPU)
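
For reviewers who want to exercise the feature end to end, a request like the following should do it once a server is running (the host, port, and model name below are assumptions; the payload fields follow the OpenAI chat completions API):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "my-model",
            "messages": [{"role": "user", "content": "Hello"}],
            "logprobs": True,
            "top_logprobs": 5,
            "max_tokens": 8,
        },
    )
    print(resp.json()["choices"][0]["logprobs"])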

@gemini-code-assist
Contributor

Summary of Changes

Hello @ahmadbasyouni10, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the top_logprobs feature, enabling the system to return the top 20 most probable tokens and their log probabilities at each generation step, aligning with OpenAI's API specification. The implementation involves significant changes across the memory, GPU sampling, data transport, and API layers to manage, calculate, transfer, and expose this new data. It leverages shared memory for efficient data handling and is currently supported only on the Triton sampling backend.

Highlights

  • Introduced top_logprobs Feature: Added support for returning the top 20 most likely token candidates and their log probabilities during text generation, fulfilling a previous TODO in the codebase.
  • Shared Memory Integration: Implemented new shared memory structures (shm_top_logprobs_ids, shm_top_logprobs_val) using C structs to efficiently store 2D arrays of [sequence_length, 20] for token IDs and their corresponding log probability values.
  • Sampling Layer Enhancements: Modified the sample() function to calculate top-20 log probabilities using torch.topk() on the raw probability distribution. A new helper function _get_top_logprobs() was introduced for this purpose, avoiding destructive modifications to probabilities used for sampling.
  • Data Transfer Mechanism: Updated the transport layer to handle GPU-to-CPU data transfer for the new top-k arrays, ensuring this information is propagated correctly through the system to the API layer.
  • API Compatibility: The API layer now reads and formats the top-k data from shared memory, converting token IDs to text using the tokenizer, and returns it in an OpenAI-compatible format. The system hardcodes a maximum of 20 log probabilities, filtering to the user's requested amount.
  • Triton Backend Specific: The top_logprobs functionality is currently limited to the Triton sampling backend because it provides direct access to probability tensors in Python, which is necessary for extracting intermediate probabilities.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This PR successfully implements the top_logprobs feature. The changes are well-structured and span across the memory, GPU, transport, and API layers to support this new functionality. The implementation correctly uses shared memory for inter-process communication and handles data transfer from GPU to CPU. My review includes a few suggestions to improve code maintainability by reducing duplication and to enhance performance in the API formatting layer. I've also pointed out a minor inconsistency in return value ordering for better code clarity. Overall, this is a solid implementation.

Comment on lines +709 to +713

    formatted_top_logprobs = {}
    for item in top_logprobs:
        for t_id, t_prob in item.items():
            t_text = tokenizer.decode([t_id], skip_special_tokens=False)
            formatted_top_logprobs[t_text] = t_prob
Severity: medium

The current implementation decodes token IDs one by one inside a loop. This can be inefficient, especially for a large number of top logprobs. You can improve performance by collecting all token IDs and using tokenizer.batch_decode to process them in a single call.

Suggested change

    - formatted_top_logprobs = {}
    - for item in top_logprobs:
    -     for t_id, t_prob in item.items():
    -         t_text = tokenizer.decode([t_id], skip_special_tokens=False)
    -         formatted_top_logprobs[t_text] = t_prob
    + token_ids = [next(iter(item)) for item in top_logprobs]
    + log_probs = [next(iter(item.values())) for item in top_logprobs]
    + token_texts = tokenizer.batch_decode(
    +     [[tid] for tid in token_ids], skip_special_tokens=False
    + )
    + formatted_top_logprobs = dict(zip(token_texts, log_probs))

Comment on lines +224 to +251

    def create_top_logprobs_shm_array(self):
        service_uni_name = get_unique_server_name()
        name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
        self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
        self.shm_top_logprobs_ids.create_shm()

        name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
        self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
        self.shm_top_logprobs_val.create_shm()
        return

    def link_logprobs_shm_array(self):
        service_uni_name = get_unique_server_name()
        name = f"{service_uni_name}_shm_logprobs_{self.index_in_shm_mem}"
        self.shm_logprobs = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.float32)
        self.shm_logprobs.link_shm()
        return

    def link_top_logprobs_shm_array(self):
        service_uni_name = get_unique_server_name()
        name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
        self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
        self.shm_top_logprobs_ids.link_shm()

        name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
        self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
        self.shm_top_logprobs_val.link_shm()
        return
Severity: medium

The methods create_top_logprobs_shm_array and link_top_logprobs_shm_array contain a lot of duplicated code. The only difference is calling create_shm() vs link_shm(). This could be refactored into a helper method to reduce redundancy and improve maintainability. For example, a private helper method could handle the initialization of the ShmArray objects, while the public methods would just call the appropriate create_shm() or link_shm() on them.
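
A possible shape for that refactor, reusing the names from the excerpt above (the private helper itself is hypothetical):

    def _init_top_logprobs_shm_arrays(self):
        service_uni_name = get_unique_server_name()
        shape = (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS)
        name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
        name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
        self.shm_top_logprobs_ids = ShmArray(name_ids, shape, dtype=np.int32)
        self.shm_top_logprobs_val = ShmArray(name_val, shape, dtype=np.float32)

    def create_top_logprobs_shm_array(self):
        self._init_top_logprobs_shm_arrays()
        self.shm_top_logprobs_ids.create_shm()
        self.shm_top_logprobs_val.create_shm()

    def link_top_logprobs_shm_array(self):
        self._init_top_logprobs_shm_arrays()
        self.shm_top_logprobs_ids.link_shm()
        self.shm_top_logprobs_val.link_shm()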

    next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
    next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index))
    return next_token_ids.view(-1), next_token_logprobs.view(-1)
    top_k_logprobs_val, top_k_logprobs_idx = _get_top_logprobs(probs, k=20)
Severity: medium

To improve clarity and consistency with the proposed change in _get_top_logprobs, the variable names here should be swapped to reflect the idx, val order. This makes the data flow from the helper function to the return statement of this function more intuitive.

Suggested change

    - top_k_logprobs_val, top_k_logprobs_idx = _get_top_logprobs(probs, k=20)
    + top_k_logprobs_idx, top_k_logprobs_val = _get_top_logprobs(probs, k=20)

    def _get_top_logprobs(probs: torch.Tensor, k: int = 20):
        top_k_logprobs_val, top_k_logprobs_idx = torch.topk(probs, k=k, dim=-1)
        top_k_logprobs_val = torch.log(top_k_logprobs_val)
        return top_k_logprobs_val, top_k_logprobs_idx
Severity: medium

For consistency with how the values are used in the sample function, it would be clearer if this function returned the indices first, then the values (idx, val). The sample function returns (..., idx, val), so aligning this helper's return signature would improve readability and reduce confusion.

Suggested change

    - return top_k_logprobs_val, top_k_logprobs_idx
    + return top_k_logprobs_idx, top_k_logprobs_val

@shihaobai
Collaborator

Thank you for your contribution. We will review and test it promptly.
