4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -69,7 +69,9 @@ void LlmRequest::createSerializedResult(
/// Note that there is some dependency on the order of operations in this method. Modify with care!
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
{
if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
auto const streamingInProgress = mIsStreaming
&& (mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE);
if (!(isFinished() || streamingInProgress))
{
return std::nullopt;
}
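The gate in createResult now also lets streaming requests in the GENERATION_TO_COMPLETE state produce a result. Below is a small, self-contained Python sketch of the new predicate; the enum and function are illustrative stand-ins for the bound LlmRequestState and the C++ logic, not actual TensorRT-LLM code.

```python
from enum import Enum, auto

class LlmRequestState(Enum):
    # Illustrative subset of the real state enum.
    CONTEXT_INIT = auto()
    GENERATION_IN_PROGRESS = auto()
    GENERATION_TO_COMPLETE = auto()
    GENERATION_COMPLETE = auto()

def should_create_result(is_finished: bool, is_streaming: bool,
                         state: LlmRequestState) -> bool:
    # Mirrors: isFinished() || (mIsStreaming && (IN_PROGRESS || TO_COMPLETE))
    streaming_in_progress = is_streaming and state in (
        LlmRequestState.GENERATION_IN_PROGRESS,
        LlmRequestState.GENERATION_TO_COMPLETE,
    )
    return is_finished or streaming_in_progress
```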
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
@@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
LlmRequestState>(),
nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE)
.def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -103,6 +103,7 @@ void initBindings(nb::module_& m)
.def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens))
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false)
.def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
.def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam"))
.def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens"))
.def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
@@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
LlmRequestState>(),
py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
"LlmRequestState.GENERATION_TO_COMPLETE"))
.def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m)
.def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
.def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
.def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
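The new will_complete_next_iteration binding mirrors GenLlmReq::willCompleteNextIteration on the Python side. A hedged usage sketch; the helper below is hypothetical and only the method name comes from the bindings above.

```python
def requests_finishing_next_iteration(active_requests):
    """Return the requests expected to emit their final token on the next step."""
    return [req for req in active_requests if req.will_complete_next_iteration()]
```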
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -153,7 +153,6 @@ def __init__(
self.llm_args.batch_wait_timeout_iters = 0
self.llm_args.batch_wait_max_tokens_ratio = 0.0
self.llm_args.max_num_tokens = seq_info.max_num_tokens
self.iter_counter = 0

# NOTE (lucaslie): not a declared member of the base class; required by PyExecutor...
self.max_beam_width = max_beam_width
19 changes: 11 additions & 8 deletions tensorrt_llm/_torch/expert_statistic.py
@@ -29,11 +29,15 @@ def create(rank_id: int):
rank_id, start, stop)

@staticmethod
def set_iter(iter_id: int) -> bool:
def should_record() -> bool:
if ExpertStatistic.expert_statistic_obj is not None:
return ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
else:
return False
return ExpertStatistic.expert_statistic_obj._should_record
return False

@staticmethod
def set_iter(iter_id: int) -> None:
if ExpertStatistic.expert_statistic_obj is not None:
ExpertStatistic.expert_statistic_obj._set_iter(iter_id)

@staticmethod
def set_layer(layer_id: int) -> None:
@@ -57,10 +61,10 @@ def __init__(self, rank_id: int, start: int, stop: int) -> None:
self._records = {}

@property
def should_record(self) -> bool:
def _should_record(self) -> bool:
return self.current_iter_id is not None and self.start <= self.current_iter_id < self.stop

def _set_iter(self, iter_id: int) -> bool:
def _set_iter(self, iter_id: int) -> None:
self.current_iter_id = iter_id
if iter_id == self.stop:
logger.info(
@@ -74,14 +78,13 @@ def _set_iter(self, iter_id: int) -> bool:
json.dump(self._meta_info, f)
safetensors.torch.save_file(
self._records, f"{path}/rank{self.rank_id}.safetensors")
return self.should_record

def _set_layer(self, layer: int) -> None:
self.current_layer = layer

def _maybe_add_info(self, expert_count: int,
token_selected_experts: torch.Tensor) -> None:
if not self.should_record:
if not self._should_record:
return

if self._meta_info is None:
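set_iter and should_record are now separate: set_iter only records the iteration id (and flushes the collected statistics once iter_id reaches stop), while should_record reports whether the current iteration falls inside the [start, stop) window. A minimal sketch of the intended call pattern, assuming ExpertStatistic.create(rank) has already run; the surrounding function is illustrative, not an actual call site.

```python
from tensorrt_llm._torch.expert_statistic import ExpertStatistic

def on_iteration_start(iter_id: int) -> None:
    # Record which iteration we are in; set_iter no longer returns a bool.
    ExpertStatistic.set_iter(iter_id)
    # Query the recording window separately, e.g. to skip CUDA graph capture
    # so per-expert token counts remain observable.
    if ExpertStatistic.should_record():
        print(f"capturing expert statistics for iteration {iter_id}")
```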
3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -164,7 +164,6 @@ def __del__(self):
def maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
iter_counter: int,
enable_spec_decode: bool,
attn_metadata: Any,
spec_metadata: Optional[Any] = None,
@@ -180,7 +179,7 @@
- The key for the graph, if applicable.
"""
# disable when collecting expert statistics
if ExpertStatistic.set_iter(iter_counter):
if ExpertStatistic.should_record():
return None, None, None

can_run_cuda_graph = batch.can_run_cuda_graph
26 changes: 14 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py
@@ -3,7 +3,8 @@

import torch

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
LlmRequestState)
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger

@@ -92,18 +93,19 @@ def __call__(
(1, beam_width, 1)))

for llm_req in generation_requests:
additional_outputs = llm_req.py_additional_outputs
if llm_req.state != LlmRequestState.GENERATION_COMPLETE:
additional_outputs = llm_req.py_additional_outputs

for name in additional_outputs:
outputs_begin = (output_index_with_context
if gather_context[name] else
output_index_without_context)
outputs_end = outputs_begin + beam_width

output_device_view = outputs[name][
outputs_begin:outputs_end].reshape(1, beam_width, -1)
llm_req.py_result.append_additional_generation_outputs(
name, output_device_view)
for name in additional_outputs:
outputs_begin = (output_index_with_context
if gather_context[name] else
output_index_without_context)
outputs_end = outputs_begin + beam_width

output_device_view = outputs[name][
outputs_begin:outputs_end].reshape(1, beam_width, -1)
llm_req.py_result.append_additional_generation_outputs(
name, output_device_view)

output_index_with_context += beam_width
output_index_without_context += beam_width
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/handle_logits.py
@@ -3,7 +3,8 @@

import torch

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
LlmRequestState)
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger

@@ -72,6 +73,9 @@

total_context_logits = num_context_logits_prefix_sum[-1]
for batch_index, llm_req in enumerate(generation_requests):
if llm_req.state == LlmRequestState.GENERATION_COMPLETE:
continue

logits_begin = total_context_logits + batch_index * beam_width
logits_end = logits_begin + beam_width

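This handler and HandleAdditionalOutputs above now apply the same guard: generation requests already in GENERATION_COMPLETE are skipped rather than having logits or extra outputs appended. A condensed, illustrative version of the shared pattern; the helper name and loop are hypothetical, only the import and state check come from the diff.

```python
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState

def iter_incomplete(generation_requests):
    """Yield only the requests that still need outputs appended."""
    for llm_req in generation_requests:
        if llm_req.state == LlmRequestState.GENERATION_COMPLETE:
            continue  # already complete: nothing more to append for this request
        yield llm_req
```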
65 changes: 42 additions & 23 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -43,19 +43,19 @@ class LogitsStorage:

def __init__(
self,
*,
seq_length: int,
use_device_memory=True,
should_exclude_last=False,
extra_token_for_overlap_scheduler=False,
use_chunked_generation_logits=False,
chunk_size=8
): # logic adapted from HandleGenerationLogits.cpp to use chunked transfer
if should_exclude_last:
if extra_token_for_overlap_scheduler:
# The overlap scheduler generates one extra token, so make sure there's
# memory for that extra +1.
seq_length += 1
self.seq_length = seq_length
self.use_device_memory = use_device_memory
self._should_exclude_last = should_exclude_last
self.use_chunked_generation_logits = use_chunked_generation_logits
self.chunk_size = chunk_size
self._logits_indices = []
@@ -126,14 +126,20 @@ def append(self, logits: torch.Tensor):
non_blocking=True)
self._logits_indices.append((position, new_position))

def get(self, all_logits: bool) -> torch.Tensor | None:
def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None:
"""Returns the used logits storage if there are any, otherwise, returns None.
When all_logits is True then all set logits are returned, otherwise, only the last logits are returned."""

Args:
all_logits: If True, return all logits; if False, return only the last chunk of logits.
exclude_last: If True, drop the entire last chunk. Requires at least 2 chunks to have been appended.
This is used when overlap scheduler is enabled to discard the extra iteration's logits.
"""
if self._storage is None:
return None

try:
last = -2 if self._should_exclude_last else -1
# When exclude_last=True, we expect at least 2 chunks and drop the whole last chunk
last = -2 if exclude_last else -1
start = 0 if all_logits else self._logits_indices[last][0]
end = self._logits_indices[last][1]
return self._storage[start:end]
@@ -175,9 +181,6 @@ def finalize_chunked_transfer(self):
if self.use_chunked_generation_logits and self._device_fragments:
self._transfer_chunk_to_host()

def set_exclude_last(self, should_exclude_last: bool) -> None:
self._should_exclude_last = should_exclude_last


class LogProbStorage:
beam_width: int = -1
@@ -225,6 +228,7 @@ class PyResult:
"""PyResult reimplements some features of `bindings.executor.Result` in Python"""

def __init__(self,
*,
prompt_len: int,
max_new_tokens: int,
use_device_memory=True,
@@ -240,16 +244,20 @@ def __init__(self,
assert chunk_size == 1, "chunk_size must be 1 in streaming mode"
self._streaming = streaming
self._chunk_size = chunk_size
self._exclude_last_generation_logits = exclude_last_generation_logits

# Note that in the C++ implementation both context logits and generation logits are stored in host memory.
# Here we only use host memory for generation logits when in chunked mode.
self._context_logits = LogitsStorage(
prompt_len, use_device_memory, use_chunked_generation_logits=False
seq_length=prompt_len,
use_device_memory=use_device_memory,
extra_token_for_overlap_scheduler=False,
use_chunked_generation_logits=False
) if return_context_logits else None
self._generation_logits = LogitsStorage(
max_new_tokens,
use_device_memory,
exclude_last_generation_logits,
seq_length=max_new_tokens,
use_device_memory=use_device_memory,
extra_token_for_overlap_scheduler=exclude_last_generation_logits,
use_chunked_generation_logits=use_chunked_generation_logits,
chunk_size=self._chunk_size) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
@@ -263,6 +271,10 @@ def __init__(self,
for name in additional_outputs
} if additional_outputs else None

def set_exclude_last_generation_logits(
self, exclude_last_generation_logits: bool):
self._exclude_last_generation_logits = exclude_last_generation_logits

def append_context_logits(self, context_logits: torch.Tensor):
if self._context_logits:
self._context_logits.append(context_logits)
@@ -309,7 +321,7 @@ def set_log_probs(self, log_probs: list[TokenLogprobs],
@property
def context_logits(self) -> torch.Tensor | None:
if self._context_logits is None or (storage := self._context_logits.get(
all_logits=True)) is None:
all_logits=True, exclude_last=False)) is None:
return None
return storage[:, 0] # remove beam_width axis for context

@@ -320,7 +332,9 @@ def generation_logits(self) -> torch.Tensor | None:
if not self._generation_logits:
return None

storage = self._generation_logits.get(all_logits=not self._streaming)
storage = self._generation_logits.get(
all_logits=not self._streaming,
exclude_last=self._exclude_last_generation_logits)
if storage is None:
return None
return storage.transpose(0, 1)
@@ -524,14 +538,14 @@ def __init__(
self.py_stop_words_list = stop_words_list

self.py_result = PyResult(
self.py_prompt_len,
self.py_max_new_tokens,
return_logits_device_memory,
self.streaming,
return_log_probs,
return_context_logits,
return_generation_logits,
exclude_last_generation_logits,
prompt_len=self.py_prompt_len,
max_new_tokens=self.py_max_new_tokens,
use_device_memory=return_logits_device_memory,
streaming=self.streaming,
return_log_probs=return_log_probs,
return_context_logits=return_context_logits,
return_generation_logits=return_generation_logits,
exclude_last_generation_logits=exclude_last_generation_logits,
use_chunked_generation_logits=self.py_use_chunked_generation_logits,
chunk_size=self.py_logits_chunk_size,
additional_outputs=additional_outputs)
@@ -545,6 +559,11 @@ def __init__(
else:
self._py_embedding_bias_1d = self.embedding_bias

def set_exclude_last_generation_logits(
self, exclude_last_generation_logits: bool):
self.py_result.set_exclude_last_generation_logits(
exclude_last_generation_logits)

@property
def cached_tokens(self) -> int:
return self._cached_tokens
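PyResult (and LogitsStorage) construction is now keyword-only, and the exclude-last behavior is no longer baked into LogitsStorage at construction time: the flag lives on PyResult and is consulted when generation_logits calls LogitsStorage.get(..., exclude_last=...). A hedged sketch of the new flow; the argument values below are illustrative, not taken from a real call site.

```python
from tensorrt_llm._torch.pyexecutor.llm_request import PyResult

result = PyResult(
    prompt_len=16,
    max_new_tokens=64,
    use_device_memory=True,
    streaming=False,
    return_log_probs=False,
    return_context_logits=False,
    return_generation_logits=True,
    exclude_last_generation_logits=False,
    use_chunked_generation_logits=False,
    chunk_size=8,
    additional_outputs=None,
)

# With the overlap scheduler, an extra iteration's logits may be appended; the
# request can flip the flag later and the generation_logits property will drop
# the last appended chunk via LogitsStorage.get(..., exclude_last=True).
result.set_exclude_last_generation_logits(True)
```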
3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -364,7 +364,6 @@ def __init__(
if self.use_mrope:
self.mrope_position_ids_cuda = torch.empty(
(3, 1, self.max_num_tokens), dtype=torch.int, device='cuda')
self.iter_counter = 0

# Pre-allocated buffers for draft model to avoid implicit synchronization
# These are used to build index tensors without creating tensors from Python lists
@@ -2572,7 +2571,6 @@ def forward(self,

maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests,
iter_counter=self.iter_counter,
enable_spec_decode=self.enable_spec_decode,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
@@ -2596,7 +2594,6 @@
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request)

self.iter_counter += 1
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
if not can_run_graph:
# Fallback to eager execution if graph was not used