
feat: support invalid_token_ids in sampling params#1305

Merged
shihaobai merged 3 commits into main from feat/invalid-token-ids on May 11, 2026

Conversation

@shihaobai
Collaborator

  • Add InvalidTokenIds ctypes struct and shm field on SamplingParams, populated from request logit_bias keys.
  • Plumb invalid_token_ids through py SamplingParams and InferSamplingParams, including vocab_size validation.
  • Add apply_invalid_token_ids Triton kernel that masks given token ids to -inf, applied during sampling between penalty application and softmax.
  • Move apply_penalty.py and apply_penalty_gpu_cache.py into a new triton_kernel/post_process/ subdirectory and add the new kernel there.
  • Add unit test for the new kernel.
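
The bullet points above can be summarized in a pure-Python sketch of what the masking step computes (the real implementation is a Triton kernel; the function name matches the PR, but the flat-id/cumulative-offset layout shown here is an assumption based on the description):

```python
import math

def apply_invalid_token_ids(logits, invalid_token_ids, cu_invalid_token_num):
    """Reference semantics for the masking step: set each request's invalid
    token logits to -inf so the following softmax assigns them zero probability.

    logits: list of per-request logit rows (one row per request)
    invalid_token_ids: ids for all requests, concatenated into one flat list
    cu_invalid_token_num: cumulative offsets into the flat list, so
        request r owns invalid_token_ids[cu[r]:cu[r + 1]]
    """
    for r, row in enumerate(logits):
        start, end = cu_invalid_token_num[r], cu_invalid_token_num[r + 1]
        for token_id in invalid_token_ids[start:end]:
            row[token_id] = -math.inf
    return logits
```

Because the kernel runs between penalty application and softmax, a masked id ends up with probability exactly zero and can never be sampled.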

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements a feature to mask specific token IDs by setting their logits to negative infinity during the sampling stage. It introduces a new Triton kernel for efficient masking and updates the sampling parameter objects to handle these IDs. Feedback from the review identifies a critical syntax error in the parameter initialization and a logic error where all logit_bias entries were incorrectly treated as tokens to be masked. Additionally, the reviewer recommended using the existing pinned memory manager for better performance and suggested increasing the default limit on the number of invalid tokens.

Comment on lines +420 to +422
invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
self.invalid_token_ids = InvalidTokenIds()
self.invalid_token_ids.initialize(list[int](invalid_token_ids))
Contributor


critical

This initialization logic has several critical issues:

  1. Misused type hint: list[int](invalid_token_ids) calls a types.GenericAlias rather than the list constructor. CPython forwards such calls to the origin type, so it happens to work, but it is unidiomatic and easy to misread as a typed cast. It should be list(invalid_token_ids).
  2. Logic Error: The code populates invalid_token_ids by masking all keys present in logit_bias to -inf. However, logit_bias is typically used for both boosting (positive values) and suppressing (negative values) tokens. This implementation incorrectly invalidates tokens that the user intended to boost.
  3. Field Inconsistency: The invalid_token_ids field added to py_sampling_params.py is ignored here, as the code only looks at logit_bias keys.
Suggested change
- invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
- self.invalid_token_ids = InvalidTokenIds()
- self.invalid_token_ids.initialize(list[int](invalid_token_ids))
+ # Initialize invalid_token_ids from the dedicated field or suppressed logit_bias
+ invalid_token_ids = kwargs.get("invalid_token_ids") or []
+ if not invalid_token_ids:
+     logit_bias = kwargs.get("logit_bias") or {}
+     invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
+ self.invalid_token_ids = InvalidTokenIds()
+ self.invalid_token_ids.initialize(list(invalid_token_ids))
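
The filtering rule in the suggestion can be checked in isolation. The -100 threshold mirrors the suggested code, and the token ids below are made up for illustration:

```python
# logit_bias mixes boosting (positive) and suppression (negative) entries
logit_bias = {"50256": -100, "198": 5.0, "628": -101.5}

# Only strongly suppressed entries become invalid token ids;
# the boosted token (198) must not be masked.
invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
print(invalid_token_ids)  # [50256, 628]
```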

REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
Contributor


medium

The limit of 10 invalid token IDs is quite restrictive for many use cases. Consider increasing INVALID_TOKEN_IDS_MAX_LENGTH to a more reasonable value (e.g., 64 or 128) to provide more flexibility for users who need to mask a larger set of tokens.
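
Note that operators can already raise the cap per deployment through the environment variable shown above, without a code change; for example:

```python
import os

# Simulate a deployment that raises the cap (the in-tree default is 10)
os.environ["LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH"] = "128"

INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
print(INVALID_TOKEN_IDS_MAX_LENGTH)  # 128
```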

Comment on lines +210 to +211
invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
Contributor


medium

Creating new pinned memory tensors for invalid_token_ids and cu_invalid_token_num on every sampling step is inefficient due to the overhead of allocation and pinning. It is better to use g_pin_mem_manager.gen_from_list, which reuses pre-allocated pinned memory buffers, consistent with how other sampling parameters are handled in this function.

Suggested change
- invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
- cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
+ invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list(key="invalid_token_ids", data=invalid_token_ids, dtype=torch.int32)
+ cu_invalid_token_num_cpu = g_pin_mem_manager.gen_from_list(key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32)
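
The reuse pattern behind gen_from_list can be illustrated with a toy, torch-free buffer pool. The class below only mimics the real manager's interface and is not LightLLM's implementation:

```python
class ToyPinMemManager:
    """Keyed pool of reusable buffers: each key gets one buffer that is
    grown on demand and overwritten in place on every sampling step,
    instead of allocating (and pinning) fresh memory on each call."""

    def __init__(self):
        self._buffers = {}

    def gen_from_list(self, key, data):
        buf = self._buffers.get(key)
        if buf is None or len(buf) < len(data):
            # Allocate (or grow) once; later calls with this key reuse it.
            buf = [0] * max(len(data), 1)
            self._buffers[key] = buf
        buf[: len(data)] = data  # overwrite in place
        return buf[: len(data)]
```

Amortizing the allocation this way matters more with real pinned memory, where the pinning syscall is the expensive part and pinned buffers are what allow asynchronous host-to-device copies.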

shihaobai added 2 commits May 11, 2026 12:17
- Add InvalidTokenIds ctypes struct and shm field on SamplingParams,
  populated from request `logit_bias` keys.
- Plumb invalid_token_ids through py SamplingParams and InferSamplingParams,
  including vocab_size validation.
- Add apply_invalid_token_ids Triton kernel that masks given token ids to
  -inf, applied during sampling between penalty application and softmax.
- Move apply_penalty.py and apply_penalty_gpu_cache.py into a new
  triton_kernel/post_process/ subdirectory and add the new kernel there.
- Add unit test for the new kernel.
Hits /generate twice with a logit_bias map covering common English tokens
(via the Qwen3.5 tokenizer) and asserts none of the blocked ids appear in
the biased output, while the baseline produces them.
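
The check this test performs amounts to a set intersection between generated ids and blocked ids; a hypothetical helper capturing it:

```python
def assert_none_blocked(output_ids, blocked_ids):
    """Fail if any id the request tried to block shows up in the output."""
    leaked = sorted(set(output_ids) & set(blocked_ids))
    assert not leaked, f"blocked token ids appeared in output: {leaked}"

# The baseline run may emit blocked ids; the biased run must not.
assert_none_blocked([11, 37, 42], blocked_ids={5, 99})
```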
@shihaobai shihaobai force-pushed the feat/invalid-token-ids branch from bc90393 to ceebefd on May 11, 2026 12:23
@shihaobai shihaobai merged commit 8bcd28b into main May 11, 2026
1 check passed
@shihaobai shihaobai deleted the feat/invalid-token-ids branch May 11, 2026 12:29