From e5f9e046159eef4960328fe678ba6953ea9a9d09 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:20:30 +0000 Subject: [PATCH 1/9] [None][feat] Enable early exit with overlap scheduler - Update MicroBatchScheduler bindings to skip scheduling after GENERATION_TO_COMPLETE state. - Update PyExecutor to set GENERATION_TO_COMPLETE state for requests that will complete next iteration. - Fix _executor_loop_overlap to finish previous batch if current batch is empty. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/llmRequest.cpp | 5 +- .../nanobind/batch_manager/algorithms.cpp | 2 +- .../nanobind/batch_manager/bindings.cpp | 1 + .../pybind/batch_manager/algorithms.cpp | 4 +- .../pybind/batch_manager/bindings.cpp | 1 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 75 +++++++++++-------- 6 files changed, 53 insertions(+), 35 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index f8b74d7d48e..322ef196234 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -69,7 +69,10 @@ void LlmRequest::createSerializedResult( /// Note that there is some dependency on the order of operations in this method. Modify with care! 
std::optional LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank) { - if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS))) + if (!(isFinished() + || (mIsStreaming + && (mState == LlmRequestState::kGENERATION_IN_PROGRESS + || mState == LlmRequestState::kGENERATION_TO_COMPLETE)))) { return std::nullopt; } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp index 8e8b7d483e1..8ff7f2be0a7 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp @@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_ LlmRequestState>(), nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt, nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT, - nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE) + nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE) .def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"), nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime")) .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; }); diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index d6149755e3e..aaa9f0c4ecb 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -103,6 +103,7 @@ void initBindings(nb::module_& m) .def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens)) .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false) .def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens) + .def("will_complete_next_iteration", 
&GenLlmReq::willCompleteNextIteration) .def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam")) .def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens")) .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp index 9361c1bd565..5573931e238 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp @@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod LlmRequestState>(), py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt, py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"), - py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE, - "LlmRequestState.GENERATION_COMPLETE")) + py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE, + "LlmRequestState.GENERATION_TO_COMPLETE")) .def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"), py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime")) .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; }); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index ecaffdda6aa..f4091607aa1 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m) .def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens)) .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false) .def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens) + .def("will_complete_next_iteration", 
&GenLlmReq::willCompleteNextIteration) .def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam")) .def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens")) .def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 23ab0dbfa07..10d1808067f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -826,7 +826,7 @@ def _executor_loop_pp(self): self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( - f'has {len(self.active_requests)} active_request, ' + f'has {len(self.active_requests)} active_requests, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests' ) @@ -1085,7 +1085,7 @@ def _prepare_and_schedule_batch(self): self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( - f'has {len(self.active_requests)} active_request, ' + f'has {len(self.active_requests)} active_requests, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats @@ -1394,17 +1394,18 @@ def _executor_loop_overlap(self): self.guided_decoder.add_batch(scheduled_batch) self.guided_decoder.init_disagg_gen_requests() - previous_tensors = self.previous_batch and self.previous_batch.sample_state - # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. - # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, - # so we'll set the target model's input to None and skip updating the target requests after target model forward. 
- use_previous_draft_tokens = self.has_previous_draft_tokens - if self.drafter is not None and (self.use_spec_decode or - use_previous_draft_tokens): - target_inputs = self._handle_speculative_decoding( - scheduled_batch, previous_tensors, - previous_tensors_device) + previous_tensors = self.previous_batch and self.previous_batch.sample_state + # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. + # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, + # so we'll set the target model's input to None and skip updating the target requests after target model forward. + use_previous_draft_tokens = self.has_previous_draft_tokens + if self.drafter is not None and (self.use_spec_decode + or use_previous_draft_tokens): + target_inputs = self._handle_speculative_decoding( + scheduled_batch, previous_tensors, + previous_tensors_device) + if can_queue: # Use the draft_model's outputs if we've launched the draft model. # Otherwise, use the previous batch's outputs. 
if (target_inputs is not None @@ -1417,25 +1418,26 @@ def _executor_loop_overlap(self): batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device) - if self.previous_batch is not None: - self._update_requests(self.previous_batch.sample_state) + if self.previous_batch is not None: + self._update_requests(self.previous_batch.sample_state) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in self.previous_batch.sample_state.scheduled_requests.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in self.previous_batch.sample_state.scheduled_requests.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) - if self.drafter is not None and self.use_spec_decode: - # Cleanup previous draft resources used in the draft model - self.drafter.cleanup_previous_draft_resources() + if self.drafter is not None and self.use_spec_decode: + # Cleanup previous draft resources used in the draft model + self.drafter.cleanup_previous_draft_resources() + if can_queue: if self.guided_decoder is not None: # add_batch must be called again to have updated new tokens. 
self.guided_decoder.add_batch(scheduled_batch) @@ -1451,9 +1453,10 @@ def _executor_loop_overlap(self): scheduled_batch.context_requests ) if self.kv_cache_transceiver else [] - if self.previous_batch is not None: - self._process_previous_batch() + if self.previous_batch is not None: + self._process_previous_batch() + if can_queue: if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] @@ -2012,7 +2015,17 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests): request.context_chunk_size) request.move_to_next_context_chunk() if request.context_remaining_length == 0: - request.state = LlmRequestState.GENERATION_IN_PROGRESS + if not self.disable_overlap_scheduler and request.will_complete_next_iteration( + ): + request.state = LlmRequestState.GENERATION_TO_COMPLETE + else: + request.state = LlmRequestState.GENERATION_IN_PROGRESS + + for request in scheduled_requests.generation_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + if not self.disable_overlap_scheduler and request.will_complete_next_iteration( + ): + request.state = LlmRequestState.GENERATION_TO_COMPLETE def _update_request_states_star_attention( self, scheduled_requests: ScheduledRequests): From d79787c5bd1aa29a3a331f8712e1f90ccd09f89c Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Fri, 24 Oct 2025 13:33:08 +0000 Subject: [PATCH 2/9] [fix] Skip logits and additional outputs handling in extra iteration for overlap scheduler Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- .../pyexecutor/handle_additional_outputs.py | 26 ++++++++++--------- .../_torch/pyexecutor/handle_logits.py | 6 ++++- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py b/tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py index b9588f809cd..db2a7c0017b 100644 --- 
a/tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py +++ b/tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py @@ -3,7 +3,8 @@ import torch -from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, + LlmRequestState) from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger @@ -92,18 +93,19 @@ def __call__( (1, beam_width, 1))) for llm_req in generation_requests: - additional_outputs = llm_req.py_additional_outputs + if llm_req.state != LlmRequestState.GENERATION_COMPLETE: + additional_outputs = llm_req.py_additional_outputs - for name in additional_outputs: - outputs_begin = (output_index_with_context - if gather_context[name] else - output_index_without_context) - outputs_end = outputs_begin + beam_width - - output_device_view = outputs[name][ - outputs_begin:outputs_end].reshape(1, beam_width, -1) - llm_req.py_result.append_additional_generation_outputs( - name, output_device_view) + for name in additional_outputs: + outputs_begin = (output_index_with_context + if gather_context[name] else + output_index_without_context) + outputs_end = outputs_begin + beam_width + + output_device_view = outputs[name][ + outputs_begin:outputs_end].reshape(1, beam_width, -1) + llm_req.py_result.append_additional_generation_outputs( + name, output_device_view) output_index_with_context += beam_width output_index_without_context += beam_width diff --git a/tensorrt_llm/_torch/pyexecutor/handle_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_logits.py index cb940feb098..17c390735c9 100644 --- a/tensorrt_llm/_torch/pyexecutor/handle_logits.py +++ b/tensorrt_llm/_torch/pyexecutor/handle_logits.py @@ -3,7 +3,8 @@ import torch -from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, + LlmRequestState) from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger @@ -72,6 
+73,9 @@ def __call__( total_context_logits = num_context_logits_prefix_sum[-1] for batch_index, llm_req in enumerate(generation_requests): + if llm_req.state == LlmRequestState.GENERATION_COMPLETE: + continue + logits_begin = total_context_logits + batch_index * beam_width logits_end = logits_begin + beam_width From 8e143c78cb665c6c28dee4a7d6f812d94b6d7918 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Fri, 24 Oct 2025 13:35:05 +0000 Subject: [PATCH 3/9] [fix] Generation logits length for overlap scheduler early exit Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 65 ++++++++++++------- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 + 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index c525481fee3..01d3f35f876 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -43,19 +43,19 @@ class LogitsStorage: def __init__( self, + *, seq_length: int, use_device_memory=True, - should_exclude_last=False, + extra_token_for_overlap_scheduler=False, use_chunked_generation_logits=False, chunk_size=8 ): # logic adpted from HandleGenerationLogits.cpp to use chunked transfer - if should_exclude_last: + if extra_token_for_overlap_scheduler: # Exclude last logits is used when overlap scheduler is used, that generates one extra token, # so we should make sure there's memory for that extra +1. 
seq_length += 1 self.seq_length = seq_length self.use_device_memory = use_device_memory - self._should_exclude_last = should_exclude_last self.use_chunked_generation_logits = use_chunked_generation_logits self.chunk_size = chunk_size self._logits_indices = [] @@ -126,14 +126,20 @@ def append(self, logits: torch.Tensor): non_blocking=True) self._logits_indices.append((position, new_position)) - def get(self, all_logits: bool) -> torch.Tensor | None: + def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None: """Returns the used logits storage if there are any, otherwise, returns None. - When all_logits is True then all set logits are returned, otherwise, only the last logits are returned.""" + + Args: + all_logits: If True, return all logits; if False, return only the last chunk of logits. + exclude_last: If True, drop the entire last chunk. Requires at least 2 chunks to have been appended. + This is used when overlap scheduler is enabled to discard the extra iteration's logits. 
+ """ if self._storage is None: return None try: - last = -2 if self._should_exclude_last else -1 + # When exclude_last=True, we expect at least 2 chunks and drop the whole last chunk + last = -2 if exclude_last else -1 start = 0 if all_logits else self._logits_indices[last][0] end = self._logits_indices[last][1] return self._storage[start:end] @@ -175,9 +181,6 @@ def finalize_chunked_transfer(self): if self.use_chunked_generation_logits and self._device_fragments: self._transfer_chunk_to_host() - def set_exclude_last(self, should_exclude_last: bool) -> None: - self._should_exclude_last = should_exclude_last - class LogProbStorage: beam_width: int = -1 @@ -225,6 +228,7 @@ class PyResult: """PyResult reimplements some features of `bindings.executor.Result` in Python""" def __init__(self, + *, prompt_len: int, max_new_tokens: int, use_device_memory=True, @@ -240,16 +244,20 @@ def __init__(self, assert chunk_size == 1, "chunk_size must be 1 in streaming mode" self._streaming = streaming self._chunk_size = chunk_size + self._exclude_last_generation_logits = exclude_last_generation_logits # Note that in C++ implemnetation both context logits and generation logits are stored on host memory. # Here we only use host memory for generation logits if in chunked model. 
self._context_logits = LogitsStorage( - prompt_len, use_device_memory, use_chunked_generation_logits=False + seq_length=prompt_len, + use_device_memory=use_device_memory, + extra_token_for_overlap_scheduler=False, + use_chunked_generation_logits=False ) if return_context_logits else None self._generation_logits = LogitsStorage( - max_new_tokens, - use_device_memory, - exclude_last_generation_logits, + seq_length=max_new_tokens, + use_device_memory=use_device_memory, + extra_token_for_overlap_scheduler=exclude_last_generation_logits, use_chunked_generation_logits=use_chunked_generation_logits, chunk_size=self._chunk_size) if return_generation_logits else None self._log_probs = LogProbStorage() if return_log_probs else None @@ -263,6 +271,10 @@ def __init__(self, for name in additional_outputs } if additional_outputs else None + def set_exclude_last_generation_logits( + self, exclude_last_generation_logits: bool): + self._exclude_last_generation_logits = exclude_last_generation_logits + def append_context_logits(self, context_logits: torch.Tensor): if self._context_logits: self._context_logits.append(context_logits) @@ -309,7 +321,7 @@ def set_log_probs(self, log_probs: list[TokenLogprobs], @property def context_logits(self) -> torch.Tensor | None: if self._context_logits is None or (storage := self._context_logits.get( - all_logits=True)) is None: + all_logits=True, exclude_last=False)) is None: return None return storage[:, 0] # remove beam_width axis for context @@ -320,7 +332,9 @@ def generation_logits(self) -> torch.Tensor | None: if not self._generation_logits: return None - storage = self._generation_logits.get(all_logits=not self._streaming) + storage = self._generation_logits.get( + all_logits=not self._streaming, + exclude_last=self._exclude_last_generation_logits) if storage is None: return None return storage.transpose(0, 1) @@ -524,14 +538,14 @@ def __init__( self.py_stop_words_list = stop_words_list self.py_result = PyResult( - self.py_prompt_len, - 
self.py_max_new_tokens, - return_logits_device_memory, - self.streaming, - return_log_probs, - return_context_logits, - return_generation_logits, - exclude_last_generation_logits, + prompt_len=self.py_prompt_len, + max_new_tokens=self.py_max_new_tokens, + use_device_memory=return_logits_device_memory, + streaming=self.streaming, + return_log_probs=return_log_probs, + return_context_logits=return_context_logits, + return_generation_logits=return_generation_logits, + exclude_last_generation_logits=exclude_last_generation_logits, use_chunked_generation_logits=self.py_use_chunked_generation_logits, chunk_size=self.py_logits_chunk_size, additional_outputs=additional_outputs) @@ -545,6 +559,11 @@ def __init__( else: self._py_embedding_bias_1d = self.embedding_bias + def set_exclude_last_generation_logits( + self, exclude_last_generation_logits: bool): + self.py_result.set_exclude_last_generation_logits( + exclude_last_generation_logits) + @property def cached_tokens(self) -> int: return self._cached_tokens diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 10d1808067f..9d5bc663ca4 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2017,6 +2017,7 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests): if request.context_remaining_length == 0: if not self.disable_overlap_scheduler and request.will_complete_next_iteration( ): + request.set_exclude_last_generation_logits(False) request.state = LlmRequestState.GENERATION_TO_COMPLETE else: request.state = LlmRequestState.GENERATION_IN_PROGRESS @@ -2025,6 +2026,7 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests): if request.state != LlmRequestState.GENERATION_COMPLETE: if not self.disable_overlap_scheduler and request.will_complete_next_iteration( ): + request.set_exclude_last_generation_logits(False) request.state = 
LlmRequestState.GENERATION_TO_COMPLETE def _update_request_states_star_attention( From 7fb0ccdc6f72c6e4aa30172523b55d28c6b758a0 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Mon, 27 Oct 2025 12:16:42 +0000 Subject: [PATCH 4/9] Improve test_return_logits - Add loop over sequences in test_generate_with_return_logits and test_generate_async_with_return_logits. - Add assertion on sequence length in test_generate_with_return_logits and test_generate_async_with_return_logits. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- .../_torch/sampler/test_return_logits.py | 67 ++++++++++--------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/unittest/_torch/sampler/test_return_logits.py b/tests/unittest/_torch/sampler/test_return_logits.py index 3b6d133c91e..f2bf9d0b616 100644 --- a/tests/unittest/_torch/sampler/test_return_logits.py +++ b/tests/unittest/_torch/sampler/test_return_logits.py @@ -154,20 +154,23 @@ def test_generate_with_return_logits( else: assert output.context_logits is None - if gather_generation_logits: - gen_logits = output.outputs[0].generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == sampling_params.max_tokens - assert torch.argmax(gen_logits, - dim=1).tolist() == output.outputs[0].token_ids - else: - assert output.outputs[0].generation_logits is None + for sequence in output.outputs: + assert sequence.length == sampling_params.max_tokens + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == sampling_params.max_tokens + assert torch.argmax(gen_logits, + dim=1).tolist() == sequence.token_ids + else: + assert sequence.generation_logits is None - if return_log_probs: - assert len(output.outputs[0].logprobs) == sampling_params.max_tokens - else: - assert len(output.outputs[0].logprobs) == 0 + if 
return_log_probs: + assert len(sequence.logprobs) == sampling_params.max_tokens + else: + assert len(sequence.logprobs) == 0 @force_ampere # Save H100 resource @@ -218,22 +221,24 @@ def test_generate_async_with_return_logits( else: assert output.context_logits is None - if gather_generation_logits: - gen_logits = output.outputs[0].generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == 1 - try: - assert torch.argmax( - gen_logits, - dim=1).tolist()[0] == output.outputs[0].token_ids[-1] - except AssertionError: - # FIXME: Remove xfail once the bug is fixed - pytest.xfail("Known bug: https://nvbugs/5573238") - else: - assert output.outputs[0].generation_logits is None + for sequence in output.outputs: + assert sequence.length == idx + 1 + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == 1 + try: + assert torch.argmax( + gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1] + except AssertionError: + # FIXME: Remove xfail once the bug is fixed + pytest.xfail("Known bug: https://nvbugs/5573238") + else: + assert sequence.generation_logits is None - if return_log_probs: - assert len(output.outputs[0].logprobs) == idx + 1 - else: - assert len(output.outputs[0].logprobs) == 0 + if return_log_probs: + assert len(sequence.logprobs) == idx + 1 + else: + assert len(sequence.logprobs) == 0 From 2a26a189f5698c17cb69866ea05b323eb56882b9 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:12:30 +0000 Subject: [PATCH 5/9] [fix] Fix test_llm_api_connector - Reduce call_counts when using the overlap scheduler. 
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- .../defs/llmapi/test_llm_api_connector.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/integration/defs/llmapi/test_llm_api_connector.py b/tests/integration/defs/llmapi/test_llm_api_connector.py index ba624b299d0..0b01f0ac018 100644 --- a/tests/integration/defs/llmapi/test_llm_api_connector.py +++ b/tests/integration/defs/llmapi/test_llm_api_connector.py @@ -89,8 +89,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector, assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1 # With the overlap scheduler, we generate one extra token. - assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( - use_overlap_scheduler) + assert scheduler.build_connector_meta.call_count == NUM_TOKENS # We should have a single `SchedulerOutput` per forward pass. for i, call in enumerate(scheduler.build_connector_meta.call_args_list): @@ -110,8 +109,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector, assert len(scheduler_output.cached_requests[0].new_tokens) == 1 # We call `start_load_kv` once at the beginning of each forward pass. - assert worker.start_load_kv.call_count == NUM_TOKENS + int( - use_overlap_scheduler) + assert worker.start_load_kv.call_count == NUM_TOKENS # Only called once when the request is received. assert scheduler.get_num_new_matched_tokens.call_count == 1 @@ -120,10 +118,8 @@ def test_connector_simple(enforce_single_worker, model_with_connector, for call in worker.wait_for_layer_load.call_args_list) + 1 # Called num_layers * num_forward_passes times. 
- assert worker.wait_for_layer_load.call_count == num_layers * ( - NUM_TOKENS + int(use_overlap_scheduler)) - assert worker.save_kv_layer.call_count == num_layers * ( - NUM_TOKENS + int(use_overlap_scheduler)) + assert worker.wait_for_layer_load.call_count == num_layers * (NUM_TOKENS) + assert worker.save_kv_layer.call_count == num_layers * (NUM_TOKENS) for i, call in enumerate(worker.wait_for_layer_load.call_args_list): assert call.args[0] == i % num_layers @@ -131,8 +127,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector, for i, call in enumerate(worker.save_kv_layer.call_args_list): assert call.args[0] == i % num_layers - assert worker.wait_for_save.call_count == NUM_TOKENS + int( - use_overlap_scheduler) + assert worker.wait_for_save.call_count == NUM_TOKENS assert scheduler.request_finished.call_count == 1 @@ -239,9 +234,7 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector, scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil( NUM_INPUT_TOKENS / BLOCK_SIZE) - # Additional token when using the overlap scheduler. - assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int( - use_overlap_scheduler) + assert scheduler.build_connector_meta.call_count == NUM_TOKENS for i, call in enumerate(scheduler.build_connector_meta.call_args_list): sched_output = call.args[0] From 75fb0e39dd592d401ca8f2b6a1940766a886b184 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:29:57 +0000 Subject: [PATCH 6/9] Refactor test_llm_perf_metrics to use context manager for LLM instance - Updated the test to utilize a context manager for the LLM instance, improving resource management. - Maintained existing assertions to validate performance metrics. 
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- tests/unittest/llmapi/test_llm_pytorch.py | 46 ++++++++++++----------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 5a069567f1c..f2ccb04d7c9 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -182,28 +182,30 @@ def test_llm_reward_model(): @skip_ray def test_llm_perf_metrics(): - llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config) - sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True) - outputs = llm.generate(prompts, sampling_params) - assert outputs[0].outputs[0].request_perf_metrics is not None - - perf_metrics = outputs[0].outputs[0].request_perf_metrics - - timing_metrics = perf_metrics.timing_metrics - assert timing_metrics.arrival_time < timing_metrics.first_scheduled_time - assert timing_metrics.first_scheduled_time < timing_metrics.first_token_time - assert timing_metrics.first_token_time < timing_metrics.last_token_time - - kv_cache_metrics = perf_metrics.kv_cache_metrics - assert kv_cache_metrics.num_total_allocated_blocks == 1 - assert kv_cache_metrics.num_new_allocated_blocks == 1 - assert kv_cache_metrics.num_reused_blocks == 0 - assert kv_cache_metrics.num_missed_blocks == 1 - assert kv_cache_metrics.kv_cache_hit_rate == 0 - - assert perf_metrics.first_iter is not None - assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - 1 - assert perf_metrics.last_iter == perf_metrics.iter + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config) as llm: + sampling_params = SamplingParams(max_tokens=10, + return_perf_metrics=True) + outputs = llm.generate(prompts, sampling_params) + assert outputs[0].outputs[0].request_perf_metrics is not None + + perf_metrics = outputs[0].outputs[0].request_perf_metrics + + timing_metrics = 
perf_metrics.timing_metrics + assert timing_metrics.arrival_time < timing_metrics.first_scheduled_time + assert timing_metrics.first_scheduled_time < timing_metrics.first_token_time + assert timing_metrics.first_token_time < timing_metrics.last_token_time + + kv_cache_metrics = perf_metrics.kv_cache_metrics + assert kv_cache_metrics.num_total_allocated_blocks == 1 + assert kv_cache_metrics.num_new_allocated_blocks == 1 + assert kv_cache_metrics.num_reused_blocks == 0 + assert kv_cache_metrics.num_missed_blocks == 1 + assert kv_cache_metrics.kv_cache_hit_rate == 0 + + assert perf_metrics.first_iter is not None + assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - 1 + assert perf_metrics.last_iter == perf_metrics.iter @skip_ray From 48d5c1e83952c10d4b06da597bab094ea196c4c7 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Mon, 3 Nov 2025 11:37:18 +0000 Subject: [PATCH 7/9] [refactor] Move iter_counter handling to PyExecutor - Moved iter_counter in PyExecutor to ensure consistency in tracking iterations. - This allows tracking of iteration where scheduled requests are empty. 
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- .../_torch/auto_deploy/shim/ad_executor.py | 1 - tensorrt_llm/_torch/expert_statistic.py | 19 +++++++++++-------- .../_torch/pyexecutor/cuda_graph_runner.py | 3 +-- .../_torch/pyexecutor/model_engine.py | 3 --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 17 +++++++++++++---- 5 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 5e857fd636e..f818cb76bce 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -153,7 +153,6 @@ def __init__( self.llm_args.batch_wait_timeout_iters = 0 self.llm_args.batch_wait_max_tokens_ratio = 0.0 self.llm_args.max_num_tokens = seq_info.max_num_tokens - self.iter_counter = 0 # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor... self.max_beam_width = max_beam_width diff --git a/tensorrt_llm/_torch/expert_statistic.py b/tensorrt_llm/_torch/expert_statistic.py index 98dc127f4e0..34cdc647a65 100644 --- a/tensorrt_llm/_torch/expert_statistic.py +++ b/tensorrt_llm/_torch/expert_statistic.py @@ -29,11 +29,15 @@ def create(rank_id: int): rank_id, start, stop) @staticmethod - def set_iter(iter_id: int) -> bool: + def should_record() -> bool: if ExpertStatistic.expert_statistic_obj is not None: - return ExpertStatistic.expert_statistic_obj._set_iter(iter_id) - else: - return False + return ExpertStatistic.expert_statistic_obj._should_record + return False + + @staticmethod + def set_iter(iter_id: int) -> None: + if ExpertStatistic.expert_statistic_obj is not None: + ExpertStatistic.expert_statistic_obj._set_iter(iter_id) @staticmethod def set_layer(layer_id: int) -> None: @@ -57,10 +61,10 @@ def __init__(self, rank_id: int, start: int, stop: int) -> None: self._records = {} @property - def should_record(self) -> bool: + def 
_should_record(self) -> bool: return self.current_iter_id is not None and self.start <= self.current_iter_id < self.stop - def _set_iter(self, iter_id: int) -> bool: + def _set_iter(self, iter_id: int) -> None: self.current_iter_id = iter_id if iter_id == self.stop: logger.info( @@ -74,14 +78,13 @@ def _set_iter(self, iter_id: int) -> bool: json.dump(self._meta_info, f) safetensors.torch.save_file( self._records, f"{path}/rank{self.rank_id}.safetensors") - return self.should_record def _set_layer(self, layer: int) -> None: self.current_layer = layer def _maybe_add_info(self, expert_count: int, token_selected_experts: torch.Tensor) -> None: - if not self.should_record: + if not self._should_record: return if self._meta_info is None: diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 92869ca401b..d924f2ea457 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -164,7 +164,6 @@ def __del__(self): def maybe_get_cuda_graph( self, batch: ScheduledRequests, - iter_counter: int, enable_spec_decode: bool, attn_metadata: Any, spec_metadata: Optional[Any] = None, @@ -180,7 +179,7 @@ def maybe_get_cuda_graph( - The key for the graph, if applicable. 
""" # disable when doing statistic - if ExpertStatistic.set_iter(iter_counter): + if ExpertStatistic.should_record(): return None, None, None can_run_cuda_graph = batch.can_run_cuda_graph diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 42f71181b11..6bcd79a3746 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -364,7 +364,6 @@ def __init__( if self.use_mrope: self.mrope_position_ids_cuda = torch.empty( (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda') - self.iter_counter = 0 # Pre-allocated buffers for draft model to avoid implicit synchronization # These are used to build index tensors without creating tensors from Python lists @@ -2572,7 +2571,6 @@ def forward(self, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph( padded_requests, - iter_counter=self.iter_counter, enable_spec_decode=self.enable_spec_decode, attn_metadata=attn_metadata, spec_metadata=spec_metadata, @@ -2596,7 +2594,6 @@ def forward(self, new_tensors_device, cache_indirection_buffer, num_accepted_tokens_device, req_id_to_old_request) - self.iter_counter += 1 with with_shared_pool(self.cuda_graph_runner.get_graph_pool()): if not can_run_graph: # Fallback to eager execution if graph was not used diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 9d5bc663ca4..e2265554edf 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -11,6 +11,7 @@ import torch +from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds try: @@ -137,6 +138,7 @@ def __init__(self, self.peft_cache_config = peft_cache_config + self.iter_counter = 0 # profile config self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes( 
PROFILE_START_STOP_ENV_VAR_NAME) @@ -575,7 +577,7 @@ def profile_step(): formatted_timestamp = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") logger.info( - f"iter = {self.model_engine.iter_counter}, " + f"iter = {self.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" @@ -705,7 +707,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, stats.cpu_mem_usage = 0 stats.pinned_mem_usage = 0 - stats.iter = self.model_engine.iter_counter + stats.iter = self.iter_counter kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) @@ -1004,6 +1006,8 @@ def _executor_loop_pp(self): self.active_requests, previous_batch) + self.iter_counter += 1 + def wait_on_pp_send_handles(self, microbatch_id): if self.send_handles[microbatch_id] is not None: self.send_handles[microbatch_id].wait() @@ -1240,6 +1244,8 @@ def _executor_loop(self): iter_stats=iter_stats, iter_start_time=iter_start_time)) + self.iter_counter += 1 + def _prepare_draft_requests(self): try: # Set draft tokens here to make the KV cache manager @@ -1473,6 +1479,8 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() + self.iter_counter += 1 + def _accept_draft_tokens( self, scheduled_batch: ScheduledRequests, target_outputs: SampleStateTensors, @@ -1964,9 +1972,10 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0): def _forward_step(self, scheduled_requests, new_tensors_device: Optional[SampleStateTensors] = None): + ExpertStatistic.set_iter(self.iter_counter) @nvtx_range( - f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" + f"[Executor] _forward_step {self.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, 
{len(scheduled_requests.generation_requests)} gen reqs" ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer): @@ -2304,7 +2313,7 @@ def _handle_responses(self): # Skip active requests that are not scheduled if request.return_perf_metrics and request.py_decoding_iter >= 1: - request.update_perf_metrics(self.model_engine.iter_counter) + request.update_perf_metrics(self.iter_counter) request_done = False if request.py_decoding_iter == 1 or request.is_finished or \ From 135648f9cb2a6261678630fd1132bd72effec9a8 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Sun, 16 Nov 2025 13:34:06 +0000 Subject: [PATCH 8/9] fixup! [None][feat] Enable early exit with overlap scheduler Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/llmRequest.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index 322ef196234..e664021db0b 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -69,10 +69,9 @@ void LlmRequest::createSerializedResult( /// Note that there is some dependency on the order of operations in this method. Modify with care! 
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank) { - if (!(isFinished() - || (mIsStreaming - && (mState == LlmRequestState::kGENERATION_IN_PROGRESS - || mState == LlmRequestState::kGENERATION_TO_COMPLETE)))) + auto const streamingInProgress = mIsStreaming + && (mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE); + if (!(isFinished() || streamingInProgress)) { return std::nullopt; } From 6348945f234d46d18ddc56f56a871d911183e985 Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Sun, 16 Nov 2025 13:36:03 +0000 Subject: [PATCH 9/9] fixup! [None][feat] Enable early exit with overlap scheduler Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index e2265554edf..ee2d9cb1322 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1400,18 +1400,17 @@ def _executor_loop_overlap(self): self.guided_decoder.add_batch(scheduled_batch) self.guided_decoder.init_disagg_gen_requests() - previous_tensors = self.previous_batch and self.previous_batch.sample_state - # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. - # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, - # so we'll set the target model's input to None and skip updating the target requests after target model forward. 
- use_previous_draft_tokens = self.has_previous_draft_tokens - if self.drafter is not None and (self.use_spec_decode - or use_previous_draft_tokens): - target_inputs = self._handle_speculative_decoding( - scheduled_batch, previous_tensors, - previous_tensors_device) + previous_tensors = self.previous_batch and self.previous_batch.sample_state + # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. + # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, + # so we'll set the target model's input to None and skip updating the target requests after target model forward. + use_previous_draft_tokens = self.has_previous_draft_tokens + if self.drafter is not None and (self.use_spec_decode or + use_previous_draft_tokens): + target_inputs = self._handle_speculative_decoding( + scheduled_batch, previous_tensors, + previous_tensors_device) - if can_queue: # Use the draft_model's outputs if we've launched the draft model. # Otherwise, use the previous batch's outputs. if (target_inputs is not None