Description
System Info
OS: Ubuntu 22.04
CUDA version: 13.0
GPU model(s): NVIDIA A100-40G
Driver version: 580.65.06
TensorRT-LLM version: commit e689a73 (Oct 30, 2025)
Who can help?
No response
Information
- The official example scripts
- My own modified scripts

Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
trtllm-serve command
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-serve EXAONE-4.0-32B --backend pytorch --tp_size 4 --extra_llm_api_options config.yml
```
config.yml
```yaml
kv_cache_config:
  enable_block_reuse: false
  max_attention_window: [4096,4096,4096,131072]
enable_chunked_prefill: true
```
Expected behavior
The model without VSWA works as expected with TP=4.
However, with VSWA enabled (the config above), trtllm-serve should also run successfully; instead it fails with an OOM error during executor creation.
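For reference, a minimal sketch of what the "without VSWA" baseline config is assumed to look like (the same serve command and settings with the max_attention_window override removed; the actual baseline may differ):

```yaml
# Assumed baseline config for the "without VSWA" run (not copied from the
# report above): same settings, but no per-layer max_attention_window
# override, so every layer uses the full attention window.
kv_cache_config:
  enable_block_reuse: false
enable_chunked_prefill: true
```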
nvidia-smi (without VSWA)
```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   35C    P0             59W /  400W |   37837MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:0B:00.0 Off |                    0 |
| N/A   37C    P0             62W /  400W |   37539MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  |   00000000:48:00.0 Off |                    0 |
| N/A   33C    P0             55W /  400W |   37539MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   37C    P0             57W /  400W |   37395MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
```
actual behavior
RuntimeError: Executor creation failed due to insufficient GPU memory.
Error log
```
[10/30/2025-09:08:50] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=29, draft_len=0
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7002, 131072: 21058}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=28, draft_len=0
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7003, 131072: 21059}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=27, draft_len=0
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7004, 131072: 21060}
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=26, draft_len=0
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [E] Failed to initialize executor on rank 0: Executor creation failed due to insufficient GPU memory.
The following component could not be created: Additional executor resources (temporary for KV cache size estimation)
Total GPU memory (GiB): 39.49
Free GPU memory before component creation attempt (GiB): 2.28
Previously created components and free GPU memory before/after creation (GiB):
Model: 38.65 / 22.88
Guided decoder: 22.88 / 22.88
Sampler: 22.88 / 22.88
Initial KV cache (temporary for KV cache size estimation): 22.88 / 2.28
Drafter:
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 271, in __init__
    self.model_engine.warmup(self.resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 473, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 512, in warmup
    self._run_cuda_graph_warmup(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 592, in _run_cuda_graph_warmup
    self._capture_generation_cuda_graphs(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 652, in _capture_generation_cuda_graphs
    self.forward(batch,
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 84, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2311, in forward
    maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
                                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py", line 192, in maybe_get_cuda_graph
    attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/interface.py", line 320, in create_cuda_graph_metadata
    cuda_graph_metadata.__post_init__()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/tr
  File "/usr/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 147, in observe_creation_stage
    raise RuntimeError(explanation) from e
RuntimeError: Executor creation failed due to insufficient GPU memory.
The following component could not be created: Additional executor resources (temporary for KV cache size estimation)
Total GPU memory (GiB): 39.49
Free GPU memory before component creation attempt (GiB): 2.28
Previously created components and free GPU memory before/after creation (GiB):
Model: 38.65 / 22.88
Guided decoder: 22.88 / 22.88
Sampler: 22.88 / 22.88
Initial KV cache (temporary for KV cache size estimation): 22.88 / 2.28
Drafter: 2.28 / 2.28
Please refer to the TensorRT LLM documentation for information on how to control the memory usage through TensorRT LLM configuration options. Possible options include:
  Model: reduce max_num_tokens and/or shard the model weights across GPUs by enabling pipeline and/or tensor parallelism
  Sampler: reduce max_seq_len and/or max_attention_window_size
  Initial KV cache (temporary for KV cache size estimation): reduce max_num_tokens
  Drafter: reduce max_seq_len and/or max_draft_len
  Additional executor resources (temporary for KV cache size estimation): reduce max_num_tokens
[10/30/2025-09:08:52] [TRT-LLM] [I] get signal from executor worker
[10/30/2025-09:08:52] [TRT-LLM] [E] Executor worker initialization error: Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 592, in _run_cuda_graph_warmup
    self._capture_generation_cuda_graphs(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 652, in _capture_generation_cuda_graphs
    self.forward(batch,
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 84, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2311, in forward
    maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
                                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py", line 192, in maybe_get_cuda_graph
    attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/interface.py", line 320, in create_cuda_graph_metadata
    cuda_graph_metadata.__post_init__()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 634, in __post_init__
    self._post_init_with_buffers(self.cuda_graph_buffers)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 702, in _post_init_with_buffers
    self.kv_cache_block_offsets = get_empty(
                                  ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 667, in get_empty
    return buffers.get_buffer(tensor_shape, dtype, cache_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/memory_buffer_utils.py", line 99, in get_b
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 587, in create_py_executor
    with mem_monitor.observe_creation_stage(
  File "/usr/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 147, in observe_creation_stage
    raise RuntimeError(explanation) from e
RuntimeError: Executor creation failed due to insufficient GPU memory.
```
nvidia-smi (with VSWA)
```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   33C    P0             59W /  400W |   40413MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:0B:00.0 Off |                    0 |
| N/A   43C    P0            136W /  400W |   40311MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  |   00000000:48:00.0 Off |                    0 |
| N/A   39C    P0            128W /  400W |   40311MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   43C    P0            131W /  400W |   40167MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
```
additional notes
Is there a recommended value for the TRTLLM_WINDOW_SIZE_SHARES environment variable for this setup? (My current assumption about its usage is sketched below.)
Model: EXAONE-4.0-32B
Hardware: A100-40G
TP size: 4
max_attention_window: [4096,4096,4096,131072]
Related to #7923
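In case it helps, a rough mitigation sketch rather than a verified fix, under two assumptions: that kv_cache_config.free_gpu_memory_fraction (default 0.9) caps how much of the remaining free memory the KV-cache pool may take, and that TRTLLM_WINDOW_SIZE_SHARES takes one fraction per distinct window size, in ascending window order, summing to 1:

```yaml
# Mitigation sketch only (unverified on this setup): lower the share of free
# GPU memory handed to the KV-cache pool so the temporary KV-cache size
# estimation and CUDA-graph warmup have more headroom.
kv_cache_config:
  enable_block_reuse: false
  max_attention_window: [4096,4096,4096,131072]
  free_gpu_memory_fraction: 0.8
enable_chunked_prefill: true
# Assumed (unconfirmed) usage of the environment variable asked about above:
#   TRTLLM_WINDOW_SIZE_SHARES=0.2,0.8   # one share per distinct window size (4096, 131072)
```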