diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index a3a0e1b15c4..415626836d8 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -814,12 +814,24 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest):
                 >= self.max_seq_len)
 
     @staticmethod
-    def _meet_stop_token_criteria(request: LlmRequest):
+    def _meet_stop_token_criteria(request: LlmRequest, new_token: int):
         if request.py_stop_words_list:
             assert isinstance(
                 request.py_stop_words_list,
                 list), "request.py_stop_words_list should be a list"
             stop_words_list, prefix_sum = request.py_stop_words_list
+
+            # Determine max stop word length to decide optimization path
+            max_stop_word_length = prefix_sum[0] if prefix_sum else 0
+            for i in range(1, len(prefix_sum)):
+                word_length = prefix_sum[i] - prefix_sum[i - 1]
+                max_stop_word_length = max(max_stop_word_length, word_length)
+
+            # Fast path: all stop words are single tokens
+            if max_stop_word_length == 1:
+                return new_token in stop_words_list
+
+            # Slow path: at least one multi-token stop word exists
             tokens = request.get_tokens(0)
             offset = 0
             for i, offset_end in enumerate(prefix_sum):
@@ -844,7 +856,7 @@ def _handle_stop_criteria(self, request: LlmRequest,
             request.finish_by(FinishReason.LENGTH, self.BEAM)
             return True
 
-        if self._meet_stop_token_criteria(request):
+        if self._meet_stop_token_criteria(request, new_token):
             request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
             return True
 
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 37aac5a6b87..a5fb80ec88c 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -16,6 +16,7 @@ l0_a10:
   # ------------- PyTorch tests ---------------
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
+  - unittest/_torch/sampler/test_trtllm_sampler.py
   # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
   # test list either).
   - unittest/_torch/models/checkpoints/hf/test_weight_loader.py
diff --git a/tests/unittest/_torch/sampler/test_trtllm_sampler.py b/tests/unittest/_torch/sampler/test_trtllm_sampler.py
index 37227f9b53f..25510ae2f5f 100644
--- a/tests/unittest/_torch/sampler/test_trtllm_sampler.py
+++ b/tests/unittest/_torch/sampler/test_trtllm_sampler.py
@@ -12,8 +12,10 @@ def model_path():
     return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 
 
-def create_llm(model_dir):
-    """Create LLM with specific overlap scheduler setting"""
+def _create_llm_base(model_dir, enable_trtllm_sampler):
+    """Base LLM creation with configurable sampler."""
+    sampler_type = "TRTLLMSampler" if enable_trtllm_sampler else "TorchSampler"
+
     trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
 
     return LLM(
@@ -22,10 +24,20 @@ def create_llm(model_dir):
         trust_remote_code=True,
         enable_chunked_prefill=True,
         cuda_graph_config=CudaGraphConfig(),
+        sampler_type=sampler_type,
         kv_cache_config=trt_kv_cache_config,
-        max_num_tokens=
-        128  # Only one request longer than max_num_tokens is required to test chunked prefill
-    )
+        max_num_tokens=128
+    )  # Only one request longer than max_num_tokens is required to test chunked prefill
+
+
+def create_llm(model_dir):
+    """Create LLM with TRTLLMSampler."""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=True)
+
+
+def create_llm_with_torch_sampler(model_dir):
+    """Create LLM with TorchSampler."""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=False)
 
 
 @pytest.mark.high_cuda_memory
@@ -67,3 +79,69 @@ def test_trtllm_sampler(model_path):
     # Verify outputs are consistent
     for text, expected in zip(texts, expected_outputs):
         assert similar(text, expected), f"text: {text}, expected: {expected}"
+
+
+@pytest.mark.high_cuda_memory
+def test_trtllm_sampler_with_stop_token_ids(model_path):
+    """Test sampler with stop_token_ids (fast path optimization)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+    target_sentence = "The capital of France is Paris"
+
+    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+    target_tokens = tokenizer.encode(target_sentence, add_special_tokens=False)
+
+    # Use the first token after the prompt as the stop token
+    assert len(target_tokens) > len(
+        prompt_tokens), "Target must be longer than prompt"
+    stop_token_id = target_tokens[len(prompt_tokens)]
+
+    sampling_config = SamplingParams(max_tokens=100,
+                                     n=1,
+                                     stop_token_ids=[stop_token_id],
+                                     temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    output_tokens = tokenizer.encode(text, add_special_tokens=False)
+
+    llm.shutdown()
+    assert stop_token_id not in output_tokens, f"Output should not contain stop token {stop_token_id}"
+    assert len(output_tokens
+               ) < 10, "Should stop very early with first-token stop_token_id"
+
+
+@pytest.mark.high_cuda_memory
+def test_torch_sampler_with_multi_token_stop_words(model_path):
+    """Test TorchSampler with multi-token stop words (slow path)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+
+    # Use a string that will tokenize to multiple tokens
+    stop_string = "\n\n"
+    stop_tokens = tokenizer.encode(stop_string, add_special_tokens=False)
+
+    assert len(
+        stop_tokens
+    ) > 1, f"Stop string should be multi-token, got {len(stop_tokens)} tokens"
+
+    sampling_config = SamplingParams(
+        max_tokens=100,
+        n=1,
+        stop=[stop_string],  # Use 'stop' parameter for multi-token
+        temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    llm.shutdown()
+
+    assert len(text) > 0, "Should generate some text"
+    assert stop_string not in text, f"Stop string '{repr(stop_string)}' should not appear in the output"
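
For readers of this patch, the standalone Python sketch below restates the stop-word check that the sampler.py hunk adds to _meet_stop_token_criteria, so the fast/slow path split can be read and run in isolation. It is a minimal illustration, not the library code: the helper name meets_stop_token_criteria is invented for this note, the flattened (stop_words_list, prefix_sum) representation is taken from the diff (all stop-word token ids concatenated, with prefix_sum holding each word's end offset), and the suffix comparison in the slow path is assumed from the pre-existing loop that the hunk leaves in place.

from typing import List, Tuple


def meets_stop_token_criteria(tokens: List[int], new_token: int,
                              stop_words: Tuple[List[int], List[int]]) -> bool:
    """Sketch: return True if generation should stop on a stop word.

    Assumes `tokens` already includes `new_token` as its last element,
    mirroring how the sampler checks the full sequence.
    """
    stop_words_list, prefix_sum = stop_words

    # Length of the longest stop word decides which path applies.
    max_stop_word_length = 0
    offset = 0
    for offset_end in prefix_sum:
        max_stop_word_length = max(max_stop_word_length, offset_end - offset)
        offset = offset_end

    # Fast path: every stop word is a single token, so checking only the
    # newly sampled token is enough.
    if max_stop_word_length == 1:
        return new_token in stop_words_list

    # Slow path (assumed from the surrounding loop): compare each stop word
    # against the tail of the generated sequence.
    offset = 0
    for offset_end in prefix_sum:
        stop_word = stop_words_list[offset:offset_end]
        offset = offset_end
        if len(stop_word) <= len(tokens) and tokens[-len(stop_word):] == stop_word:
            return True
    return False


# Two stop words, [13] and [627, 198], flatten to ([13, 627, 198], [1, 3]);
# the 2-token word forces the slow path, which matches the sequence tail.
assert meets_stop_token_criteria([5, 627, 198], new_token=198,
                                 stop_words=([13, 627, 198], [1, 3]))
# A single 1-token stop word takes the fast path (plain membership check).
assert meets_stop_token_criteria([5, 6, 13], new_token=13,
                                 stop_words=([13], [1]))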