18 changes: 15 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -1,4 +1,4 @@
from copy import deepcopy
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@@ -327,7 +327,8 @@ def generation_logits(self) -> torch.Tensor | None:

@property
def log_probs(self) -> list[TokenLogprobs] | None:
return self._log_probs and self._log_probs.log_probs
return self._log_probs and hasattr(
self._log_probs, 'log_probs') and self._log_probs.log_probs

@property
def cum_log_probs(self) -> list[float] | None:
@@ -589,10 +590,21 @@ def create_response(self,
"""
result, is_final = super().create_serialized_result(
use_fast_logits, mpi_world_rank)

# Deep copy py_result._log_probs to avoid a race between IPC serialization of this response and the worker overwriting the newly generated log_probs in streaming mode.
if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
py_result = copy(self.py_result)
py_result._log_probs = deepcopy(self.py_result._log_probs)

for log_prob in self.py_result.log_probs:
log_prob.clear()
else:
py_result = self.py_result

return LlmResponse(
request_id=self.py_request_id
if self.is_child else self.parent_request_id,
result=LlmResult(result, self.py_result, is_final),
result=LlmResult(result, py_result, is_final),
client_id=self.py_client_id) if len(result) > 0 else None

@property
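For context, a rough sketch of the copy-on-write pattern introduced in create_response above. LogProbStorage, PyResult, and snapshot_for_streaming below are simplified, hypothetical stand-ins for the actual TRT-LLM classes, not their real definitions:

```python
from copy import copy, deepcopy
from dataclasses import dataclass, field


@dataclass
class LogProbStorage:  # hypothetical stand-in
    log_probs: list = field(default_factory=list)  # one list of token logprobs per beam


@dataclass
class PyResult:  # hypothetical stand-in
    _log_probs: LogProbStorage = field(default_factory=LogProbStorage)


def snapshot_for_streaming(py_result: PyResult) -> PyResult:
    # Shallow-copy the result but deep-copy the logprob storage, so the copy
    # handed to the response cannot be mutated by later generation steps.
    snapshot = copy(py_result)
    snapshot._log_probs = deepcopy(py_result._log_probs)
    # Clear the live storage so the next streamed chunk carries only new logprobs.
    for per_beam in py_result._log_probs.log_probs:
        per_beam.clear()
    return snapshot


# Usage: the snapshot keeps the data, the live storage is emptied for the next chunk.
live = PyResult(LogProbStorage(log_probs=[[-0.1, -0.2]]))
snap = snapshot_for_streaming(live)
assert snap._log_probs.log_probs == [[-0.1, -0.2]]
assert live._log_probs.log_probs == [[]]
```

Because the live storage is cleared after each snapshot, the deep copy only ever covers the logprobs generated since the previous streamed response, so its cost stays small.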
25 changes: 17 additions & 8 deletions tensorrt_llm/executor/result.py
@@ -272,6 +272,8 @@ def __init__(self,
self._done = False
self.metrics_dict = {}
self.trace_headers: Optional[dict[str, str]] = None
# The torch backend uses the TRTLLM sampler in beam search mode, which does not support returning logprobs incrementally.
self.use_trtllm_sampler = sampling_params.use_beam_search and sampling_params.best_of > 1

if ray_queue is not None:
if has_event_loop():
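As a quick illustration of when the flag above is set; SamplingParams here is a pared-down, hypothetical stand-in rather than the real tensorrt_llm class:

```python
from dataclasses import dataclass


@dataclass
class SamplingParams:  # hypothetical stand-in, not tensorrt_llm.SamplingParams
    use_beam_search: bool = False
    best_of: int = 1


def uses_trtllm_sampler(params: SamplingParams) -> bool:
    # Beam search with more than one beam is routed through the TRTLLM sampler,
    # which streams the full logprob list so far rather than only the new entries.
    return params.use_beam_search and params.best_of > 1


assert uses_trtllm_sampler(SamplingParams(use_beam_search=True, best_of=4))
assert not uses_trtllm_sampler(SamplingParams(use_beam_search=True, best_of=1))
```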
@@ -378,20 +380,27 @@ def _handle_sequence(self,
# each streamed response_tensors.log_probs[src_idx]
# contains a streamwise monotonically growing list of logprobs.
# so we need to accumulate only the new ones unique to that particular streamed response
assert output._last_logprobs_len <= len(
response_tensors.log_probs[src_idx]
), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
f"{len(response_tensors.log_probs[src_idx])})")
output.logprobs += response_tensors.log_probs[src_idx][
output._last_logprobs_len:]
if self.use_trtllm_sampler:
assert output._last_logprobs_len <= len(
response_tensors.log_probs[src_idx]
), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
f"{len(response_tensors.log_probs[src_idx])})")
output.logprobs += response_tensors.log_probs[src_idx][
output._last_logprobs_len:]
else:
output.logprobs += response_tensors.log_probs[src_idx]

# overcome some WAR in the cpp executor
if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
if finish_reasons[
src_idx] != tllm.FinishReason.CANCELLED and self.use_trtllm_sampler:
# The accumulated logprobs may run ahead of the generated tokens.
if len(output.logprobs) > output.length:
# LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
# Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
output.logprobs = output.logprobs[:output.length]
assert len(output.logprobs) == output.length
assert len(
output.logprobs
) == output.length, f"logprobs length: {len(output.logprobs)} != output.length: {output.length}"

if response_tensors.generation_logits is not None:
output.generation_logits = response_tensors.generation_logits[
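To summarize the two accumulation paths in _handle_sequence, a self-contained sketch; Output, accumulate_logprobs, and the cancelled parameter are hypothetical simplifications of the actual logic, under the assumption that _last_logprobs_len is advanced here:

```python
from dataclasses import dataclass, field


@dataclass
class Output:  # pared-down stand-in for the per-sequence completion output
    length: int = 0                      # tokens generated so far
    logprobs: list = field(default_factory=list)
    _last_logprobs_len: int = 0


def accumulate_logprobs(output: Output, chunk: list,
                        use_trtllm_sampler: bool, cancelled: bool = False) -> None:
    if use_trtllm_sampler:
        # The TRTLLM sampler streams a monotonically growing list, so only the
        # entries past the previously seen length are new.
        assert output._last_logprobs_len <= len(chunk)
        output.logprobs += chunk[output._last_logprobs_len:]
        output._last_logprobs_len = len(chunk)
        if not cancelled:
            # The shared storage may be updated by the worker before the result
            # is serialized, so extra entries are expected; keep one per token.
            if len(output.logprobs) > output.length:
                output.logprobs = output.logprobs[:output.length]
            assert len(output.logprobs) == output.length
    else:
        # Other samplers already send just the new logprobs for this chunk.
        output.logprobs += chunk


# Usage: two streamed chunks from the TRTLLM sampler, each carrying the full list.
out = Output()
out.length = 1
accumulate_logprobs(out, [-0.1], use_trtllm_sampler=True)
out.length = 2
accumulate_logprobs(out, [-0.1, -0.3, -0.9], use_trtllm_sampler=True)  # storage ran ahead
assert out.logprobs == [-0.1, -0.3]
```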