16 changes: 14 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -814,12 +814,24 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest):
>= self.max_seq_len)

@staticmethod
def _meet_stop_token_criteria(request: LlmRequest):
def _meet_stop_token_criteria(request: LlmRequest, new_token: int):
if request.py_stop_words_list:
assert isinstance(
request.py_stop_words_list,
list), "request.py_stop_words_list should be a list"
stop_words_list, prefix_sum = request.py_stop_words_list

# Determine max stop word length to decide optimization path
max_stop_word_length = prefix_sum[0] if prefix_sum else 0
for i in range(1, len(prefix_sum)):
word_length = prefix_sum[i] - prefix_sum[i - 1]
max_stop_word_length = max(max_stop_word_length, word_length)

# Fast path: all stop words are single tokens
if max_stop_word_length == 1:
return new_token in stop_words_list

# Slow path: at least one multi-token stop word exists
tokens = request.get_tokens(0)
offset = 0
for i, offset_end in enumerate(prefix_sum):
@@ -844,7 +856,7 @@ def _handle_stop_criteria(self, request: LlmRequest,
request.finish_by(FinishReason.LENGTH, self.BEAM)
return True

if self._meet_stop_token_criteria(request):
if self._meet_stop_token_criteria(request, new_token):
request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
return True

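The hunk above adds a fast path to the stop-word check: it first computes the longest stop word and, when every stop word is a single token, short-circuits to a membership test on the newly sampled token instead of scanning the generated sequence. Below is a standalone sketch of that logic, assuming the flattened `(stop_words_list, prefix_sum)` encoding implied by the diff (all stop-word token IDs concatenated, with `prefix_sum[i]` marking the end offset of stop word `i`); the helper name `check_stop_words` and the example token IDs are illustrative only, not part of the PR.

```python
from typing import List, Tuple


def check_stop_words(stop_words: Tuple[List[int], List[int]],
                     generated: List[int], new_token: int) -> bool:
    """Return True if the generated sequence now ends with any stop word."""
    stop_words_list, prefix_sum = stop_words

    # The longest stop word decides whether the single-token fast path applies.
    max_len = prefix_sum[0] if prefix_sum else 0
    for i in range(1, len(prefix_sum)):
        max_len = max(max_len, prefix_sum[i] - prefix_sum[i - 1])

    # Fast path: every stop word is a single token, so a membership test
    # on the newly sampled token is enough.
    if max_len == 1:
        return new_token in stop_words_list

    # Slow path: compare each stop word against the tail of the generated tokens.
    offset = 0
    for offset_end in prefix_sum:
        word = stop_words_list[offset:offset_end]
        offset = offset_end
        if len(word) <= len(generated) and generated[-len(word):] == word:
            return True
    return False


# Stop words [13] and [42, 7] flattened as [13, 42, 7] with prefix sums [1, 3].
assert check_stop_words(([13], [1]), [5, 13], 13)            # fast path
assert check_stop_words(([13, 42, 7], [1, 3]), [5, 9, 42, 7], 7)   # slow path
assert not check_stop_words(([13, 42, 7], [1, 3]), [5, 9, 42], 42)
```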
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -16,6 +16,7 @@ l0_a10:
# ------------- PyTorch tests ---------------
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
- unittest/_torch/sampler/test_trtllm_sampler.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
88 changes: 83 additions & 5 deletions tests/unittest/_torch/sampler/test_trtllm_sampler.py
@@ -12,8 +12,10 @@ def model_path():
return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"


def create_llm(model_dir):
"""Create LLM with specific overlap scheduler setting"""
def _create_llm_base(model_dir, enable_trtllm_sampler):
"""Base LLM creation with configurable sampler."""
sampler_type = "TRTLLMSampler" if enable_trtllm_sampler else "TorchSampler"

trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)

return LLM(
@@ -22,10 +24,20 @@ def create_llm(model_dir):
trust_remote_code=True,
enable_chunked_prefill=True,
cuda_graph_config=CudaGraphConfig(),
sampler_type=sampler_type,
kv_cache_config=trt_kv_cache_config,
max_num_tokens=
128 # Only one request longer than max_num_tokens is required to test chunked prefill
)
max_num_tokens=128
) # Only one request longer than max_num_tokens is required to test chunked prefill


def create_llm(model_dir):
"""Create LLM with specific overlap scheduler setting"""
return _create_llm_base(model_dir, enable_trtllm_sampler=True)


def create_llm_with_torch_sampler(model_dir):
"""Create LLM with TorchSampler."""
return _create_llm_base(model_dir, enable_trtllm_sampler=False)


@pytest.mark.high_cuda_memory
@@ -67,3 +79,69 @@ def test_trtllm_sampler(model_path):
# Verify outputs are consistent
for text, expected in zip(texts, expected_outputs):
assert similar(text, expected), f"text: {text}, expected: {expected}"


@pytest.mark.high_cuda_memory
def test_trtllm_sampler_with_stop_token_ids(model_path):
"""Test sampler with stop_token_ids (fast path optimization)."""

llm = create_llm_with_torch_sampler(model_path)
tokenizer = llm.tokenizer

prompt = "The capital of France is"
target_sentence = "The capital of France is Paris"

prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
target_tokens = tokenizer.encode(target_sentence, add_special_tokens=False)

# Use the first token after the prompt as the stop token
assert len(target_tokens) > len(
prompt_tokens), "Target must be longer than prompt"
stop_token_id = target_tokens[len(prompt_tokens)]

sampling_config = SamplingParams(max_tokens=100,
n=1,
stop_token_ids=[stop_token_id],
temperature=0.0)

outputs = llm.generate([prompt], sampling_params=sampling_config)
text = outputs[0].outputs[0].text

output_tokens = tokenizer.encode(text, add_special_tokens=False)

llm.shutdown()
assert stop_token_id not in output_tokens, f"Output should not contain stop token {stop_token_id}"
assert len(output_tokens
) < 10, "Should stop very early with first-token stop_token_id"


@pytest.mark.high_cuda_memory
def test_torch_sampler_with_multi_token_stop_words(model_path):
"""Test TorchSampler with multi-token stop words (slow path)."""

llm = create_llm_with_torch_sampler(model_path)
tokenizer = llm.tokenizer

prompt = "The capital of France is"

# Use a string that will tokenize to multiple tokens
stop_string = "\n\n"
stop_tokens = tokenizer.encode(stop_string, add_special_tokens=False)

assert len(
stop_tokens
) > 1, f"Stop string should be multi-token, got {len(stop_tokens)} tokens"

sampling_config = SamplingParams(
max_tokens=100,
n=1,
stop=[stop_string], # Use 'stop' parameter for multi-token
temperature=0.0)

outputs = llm.generate([prompt], sampling_params=sampling_config)
text = outputs[0].outputs[0].text

llm.shutdown()

assert len(text) > 0, "Should generate some text"
assert stop_string not in text, f"Stop string '{repr(stop_string)}' should not appear in the output"
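
Taken together, the two tests exercise both code paths from the public API: `stop_token_ids` with a single token hits the fast path, while a multi-token `stop` string forces the slow suffix check. A minimal usage sketch follows (not part of the PR; the model name and token ID are placeholders, and the tests above use a local TinyLlama checkpoint instead):

```python
from tensorrt_llm import LLM, SamplingParams

# Placeholder model reference; substitute a local checkpoint path as needed.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Single-token stop criterion -> fast path (membership test on the new token).
fast = SamplingParams(max_tokens=100, stop_token_ids=[13], temperature=0.0)

# Multi-token stop string -> slow path (suffix comparison over generated tokens).
slow = SamplingParams(max_tokens=100, stop=["\n\n"], temperature=0.0)

for params in (fast, slow):
    out = llm.generate(["The capital of France is"], sampling_params=params)
    print(out[0].outputs[0].text)

llm.shutdown()
```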