
[Bug]: EXAONE-4.0-32B with VSWA causes OOM on A100-40G x 4 #8802

@lkm2835

Description


System Info

OS: Ubuntu 22.04
CUDA version: 13.0
GPU model(s): NVIDIA A100-40G
Driver version: 580.65.06
TensorRT-LLM version: commit e689a73 (Oct 30, 2025)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

trtllm-serve command

CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-serve EXAONE-4.0-32B --backend pytorch --tp_size 4 --extra_llm_api_options config.yml

config.yml

kv_cache_config:
  enable_block_reuse: false
  max_attention_window: [4096,4096,4096,131072]
enable_chunked_prefill: true
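
For reference, a sketch of the same config with the KV cache memory fraction lowered; the free_gpu_memory_fraction value of 0.8 is only illustrative, and I have not verified that it avoids the OOM in the VSWA case.

kv_cache_config:
  enable_block_reuse: false
  max_attention_window: [4096,4096,4096,131072]
  # Illustrative only: cap the fraction of free GPU memory given to the KV
  # cache (default is 0.9) to leave more headroom for CUDA graph warmup.
  free_gpu_memory_fraction: 0.8
enable_chunked_prefill: true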

Expected behavior

The model runs as expected with TP=4 when VSWA is disabled.
With VSWA enabled, the same trtllm-serve launch should also succeed, but it fails with an OOM error instead.

nvidia-smi (without VSWA)

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   35C    P0             59W /  400W |   37837MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:0B:00.0 Off |                    0 |
| N/A   37C    P0             62W /  400W |   37539MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  |   00000000:48:00.0 Off |                    0 |
| N/A   33C    P0             55W /  400W |   37539MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   37C    P0             57W /  400W |   37395MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

Actual behavior

RuntimeError: Executor creation failed due to insufficient GPU memory.

Error log

[10/30/2025-09:08:50] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=29, draft_len=0
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7002, 131072: 21058}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=28, draft_len=0
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7003, 131072: 21059}
[10/30/2025-09:08:51] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=27, draft_len=0
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7029, 131072: 21085}
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] For VSWA case, we return the minimum of the number of free blocks for each window size: {4096: 7004, 131072: 21060}
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [I] Run generation-only CUDA graph warmup for batch size=26, draft_len=0
[10/30/2025-09:08:52] [TRT-LLM] [RANK 0] [E] Failed to initialize executor on rank 0: Executor creation failed due to insufficient GPU memory.

The following component could not be created: Additional executor resources (temporary for KV cache size estimation)
Total GPU memory (GiB): 39.49
Free GPU memory before component creation attempt (GiB): 2.28

Previously created components and free GPU memory before/after creation (GiB):
Model: 38.65 / 22.88
Guided decoder: 22.88 / 22.88
Sampler: 22.88 / 22.88
Initial KV cache (temporary for KV cache size estimation): 22.88 / 2.28
Drafter:
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 271, in __init__
    self.model_engine.warmup(self.resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 473, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 512, in warmup
    self._run_cuda_graph_warmup(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 592, in _run_cuda_graph_warmup
    self._capture_generation_cuda_graphs(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 652, in _capture_generation_cuda_graphs
    self.forward(batch,
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 84, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2311, in forward
    maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
                                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py", line 192, in maybe_get_cuda_graph
    attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/interface.py", line 320, in create_cuda_graph_metadata
    cuda_graph_metadata.__post_init__()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/tr
  File "/usr/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 147, in observe_creation_stage
    raise RuntimeError(explanation) from e
RuntimeError: Executor creation failed due to insufficient GPU memory.

The following component could not be created: Additional executor resources (temporary for KV cache size estimation)
Total GPU memory (GiB): 39.49
Free GPU memory before component creation attempt (GiB): 2.28

Previously created components and free GPU memory before/after creation (GiB):
Model: 38.65 / 22.88
Guided decoder: 22.88 / 22.88
Sampler: 22.88 / 22.88
Initial KV cache (temporary for KV cache size estimation): 22.88 / 2.28
Drafter: 2.28 / 2.28

Please refer to the TensorRT LLM documentation for information on how to control the memory usage through TensorRT LLM configuration options. Possible options include:
  Model: reduce max_num_tokens and/or shard the model weights across GPUs by enabling pipeline and/or tensor parallelism
  Sampler: reduce max_seq_len and/or max_attention_window_size
  Initial KV cache (temporary for KV cache size estimation): reduce max_num_tokens
  Drafter: reduce max_seq_len and/or max_draft_len
  Additional executor resources (temporary for KV cache size estimation): reduce max_num_tokens

[10/30/2025-09:08:52] [TRT-LLM] [I] get signal from executor worker
[10/30/2025-09:08:52] [TRT-LLM] [E] Executor worker initialization error: Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 592, in _run_cuda_graph_warmup
    self._capture_generation_cuda_graphs(resource_manager)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 652, in _capture_generation_cuda_graphs
    self.forward(batch,
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 84, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2311, in forward
    maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
                                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py", line 192, in maybe_get_cuda_graph
    attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/interface.py", line 320, in create_cuda_graph_metadata
    cuda_graph_metadata.__post_init__()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 634, in __post_init__
    self._post_init_with_buffers(self.cuda_graph_buffers)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 702, in _post_init_with_buffers
    self.kv_cache_block_offsets = get_empty(
                                  ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/attention_backend/trtllm.py", line 667, in get_empty
    return buffers.get_buffer(tensor_shape, dtype, cache_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/memory_buffer_utils.py", line 99, in get_b
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 587, in create_py_executor
    with mem_monitor.observe_creation_stage(
  File "/usr/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 147, in observe_creation_stage
    raise RuntimeError(explanation) from e
RuntimeError: Executor creation failed due to insufficient GPU memory.

nvidia-smi (with VSWA)

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   33C    P0             59W /  400W |   40413MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:0B:00.0 Off |                    0 |
| N/A   43C    P0            136W /  400W |   40311MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  |   00000000:48:00.0 Off |                    0 |
| N/A   39C    P0            128W /  400W |   40311MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   43C    P0            131W /  400W |   40167MiB /  40960MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

Additional notes

Is there a recommended value for the TRTLLM_WINDOW_SIZE_SHARES environment variable for this setup? (A guess at the format is sketched after the details below.)
Model: EXAONE-4.0-32B
Hardware: A100-40G
TP size: 4
max_attention_window: [4096,4096,4096,131072]
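
For illustration, this is how I am guessing TRTLLM_WINDOW_SIZE_SHARES would be set here. The format (comma-separated fractions, one per unique window size, summing to 1) and the 0.3/0.7 split are assumptions on my part, not values taken from the documentation.

# Assumed format: one fractional share per unique window size (4096, 131072 here).
# The split below is illustrative only, not a recommendation.
TRTLLM_WINDOW_SIZE_SHARES=0.3,0.7 \
CUDA_VISIBLE_DEVICES=0,1,2,3 trtllm-serve EXAONE-4.0-32B \
  --backend pytorch --tp_size 4 --extra_llm_api_options config.yml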

Related to #7923

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.

Metadata

Labels

Investigating, KV-Cache Management, Memory, Pytorch<NV>, bug, triaged
