feat: support invalid_token_ids in sampling params #1305
Conversation
Code Review
This pull request implements a feature to mask specific token IDs by setting their logits to negative infinity during the sampling stage. It introduces a new Triton kernel for efficient masking and updates the sampling parameter objects to handle these IDs. Feedback from the review identifies a critical syntax error in the parameter initialization and a logic error where all logit_bias entries were incorrectly treated as tokens to be masked. Additionally, the reviewer recommended using the existing pinned memory manager for better performance and suggested increasing the default limit on the number of invalid tokens.
```python
invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
self.invalid_token_ids = InvalidTokenIds()
self.invalid_token_ids.initialize(list[int](invalid_token_ids))
```
This initialization logic has several critical issues:
- Incorrect constructor: `list[int]` is a `GenericAlias` (a type hint), not a constructor; the call should be `list(invalid_token_ids)`.
- Logic error: the code populates `invalid_token_ids` by masking every key present in `logit_bias` to `-inf`. However, `logit_bias` is typically used both to boost (positive values) and to suppress (negative values) tokens, so this implementation incorrectly invalidates tokens the user intended to boost.
- Field inconsistency: the `invalid_token_ids` field added to `py_sampling_params.py` is ignored here, as the code only looks at `logit_bias` keys.
```diff
- invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys())
- self.invalid_token_ids = InvalidTokenIds()
- self.invalid_token_ids.initialize(list[int](invalid_token_ids))
+ # Initialize invalid_token_ids from the dedicated field or suppressed logit_bias
+ invalid_token_ids = kwargs.get("invalid_token_ids") or []
+ if not invalid_token_ids:
+     logit_bias = kwargs.get("logit_bias") or {}
+     invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
+ self.invalid_token_ids = InvalidTokenIds()
+ self.invalid_token_ids.initialize(list(invalid_token_ids))
```
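The suggested fallback can be exercised as a standalone sketch (hypothetical helper name; the `-100` threshold mirrors the suggestion above, treating only fully suppressed `logit_bias` entries as invalid):

```python
# Hypothetical helper mirroring the suggested initialization: prefer the
# dedicated invalid_token_ids field, otherwise fall back to strongly
# suppressed logit_bias entries (bias <= -100).
def resolve_invalid_token_ids(kwargs: dict) -> list[int]:
    invalid_token_ids = kwargs.get("invalid_token_ids") or []
    if not invalid_token_ids:
        logit_bias = kwargs.get("logit_bias") or {}
        invalid_token_ids = [int(k) for k, v in logit_bias.items() if v <= -100]
    return list(invalid_token_ids)

# Boosted token 7 is left alone; only the fully suppressed token 5 is masked.
print(resolve_invalid_token_ids({"logit_bias": {"5": -100, "7": 3.0}}))  # [5]
```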
```python
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
```
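Regarding the reviewer's point about the low default of 10: with this `os.getenv` pattern, the cap can already be raised without a code change by setting the environment variable before the module reads it. A minimal sketch:

```python
# Sketch: overriding the cap via the environment variable before the
# constant is read (10 is the default under review).
import os

os.environ["LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH"] = "1024"
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
print(INVALID_TOKEN_IDS_MAX_LENGTH)  # 1024
```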
```python
invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
```
Creating new pinned memory tensors for `invalid_token_ids` and `cu_invalid_token_num` on every sampling step is inefficient due to the overhead of allocation and page-pinning. It is better to use `g_pin_mem_manager.gen_from_list`, which reuses pre-allocated pinned memory buffers, consistent with how the other sampling parameters are handled in this function.
```diff
- invalid_token_ids_cpu = torch.tensor(invalid_token_ids, dtype=torch.int32, device="cpu", pin_memory=True)
- cu_invalid_token_num_cpu = torch.tensor(cu_invalid_token_num, dtype=torch.int32, device="cpu", pin_memory=True)
+ invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list(key="invalid_token_ids", data=invalid_token_ids, dtype=torch.int32)
+ cu_invalid_token_num_cpu = g_pin_mem_manager.gen_from_list(key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32)
```
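The buffer-reuse pattern behind `g_pin_mem_manager.gen_from_list` can be illustrated with a pure-Python analogue (a sketch only, not the project's actual implementation; the real manager works on pinned `torch` CPU tensors and would return views rather than copies):

```python
# Pure-Python analogue of a keyed, reusable buffer cache: each key gets one
# backing buffer, allocated once and grown geometrically, so the expensive
# allocate-and-pin step does not repeat on every sampling call.
import array


class PinMemManagerSketch:
    def __init__(self):
        self._buffers: dict[str, array.array] = {}

    def gen_from_list(self, key: str, data: list[int]) -> array.array:
        buf = self._buffers.get(key)
        if buf is None or len(buf) < len(data):
            # Grow to at least 2x the request to amortize reallocation.
            buf = array.array("i", [0]) * max(2 * len(data), 16)
            self._buffers[key] = buf
        buf[: len(data)] = array.array("i", data)
        # NOTE: slicing copies here; a real pinned-memory manager would
        # hand back a tensor view over the cached buffer instead.
        return buf[: len(data)]
```

On a second call with the same key and a fitting length, the backing buffer is reused instead of reallocated, which is the property the reviewer is asking for.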
- Add InvalidTokenIds ctypes struct and shm field on SamplingParams, populated from request `logit_bias` keys.
- Plumb invalid_token_ids through py SamplingParams and InferSamplingParams, including vocab_size validation.
- Add apply_invalid_token_ids Triton kernel that masks given token ids to -inf, applied during sampling between penalty application and softmax.
- Move apply_penalty.py and apply_penalty_gpu_cache.py into a new triton_kernel/post_process/ subdirectory and add the new kernel there.
- Add unit test for the new kernel.
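The kernel's core effect can be sketched in plain Python (reference semantics only; the actual implementation is a Triton kernel over GPU logit tensors):

```python
# Reference semantics of masking invalid token ids: set each masked logit to
# -inf, so softmax assigns it probability 0 and it can never be sampled.
import math


def apply_invalid_token_ids_ref(logits: list[float], invalid_token_ids: list[int]) -> list[float]:
    for tid in invalid_token_ids:
        logits[tid] = -math.inf
    return logits


logits = apply_invalid_token_ids_ref([0.1, 2.0, -0.5, 1.5], [1, 3])
print(logits)  # [0.1, -inf, -0.5, -inf]
```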
Hits /generate twice with a logit_bias map covering common English tokens (via the Qwen3.5 tokenizer) and asserts none of the blocked ids appear in the biased output, while the baseline produces them.
bc90393 to ceebefd