[Bug]: Deepseek V3 FP4 LWS PD Crash on B200 on Long Prompt #8205

@bryangopal

Description


System Info

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA B200                    Off |   00000000:03:00.0 Off |                    0 |
| N/A   33C    P0            146W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA B200                    Off |   00000000:13:00.0 Off |                    0 |
| N/A   39C    P0            142W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA B200                    Off |   00000000:63:00.0 Off |                    0 |
| N/A   33C    P0            142W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA B200                    Off |   00000000:73:00.0 Off |                    0 |
| N/A   40C    P0            145W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA B200                    Off |   00000000:83:00.0 Off |                    0 |
| N/A   34C    P0            152W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA B200                    Off |   00000000:93:00.0 Off |                    0 |
| N/A   40C    P0            144W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA B200                    Off |   00000000:E3:00.0 Off |                    0 |
| N/A   34C    P0            142W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA B200                    Off |   00000000:F3:00.0 Off |                    0 |
| N/A   40C    P0            146W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Tue_May_27_02:21:03_PDT_2025
Cuda compilation tools, release 12.9, V12.9.86
Build cuda_12.9.r12.9/compiler.36037853_0
Python 3.12.3
Name: tensorrt_llm
Version: 1.1.0rc2
Summary: TensorRT-LLM: A TensorRT Toolbox for Large Language Models
Home-page: https://github.com/NVIDIA/TensorRT-LLM
Author: NVIDIA Corporation
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.12/dist-packages
Requires: accelerate, aenum, backoff, blake3, blobfile, build, click, click_option_group, colored, cuda-python, cytoolz, datasets, diffusers, einops, etcd3, evaluate, fastapi, flashinfer-python, h5py, jsonschema, lark, llguidance, matplotlib, meson, mpi4py, mpmath, msgspec, ninja, numpy, nvidia-cuda-nvrtc-cu12, nvidia-ml-py, nvidia-modelopt, nvidia-nccl-cu12, nvtx, omegaconf, onnx, onnx_graphsurgeon, openai, opencv-python-headless, optimum, ordered-set, pandas, peft, pillow, polygraphy, prometheus_client, prometheus_fastapi_instrumentator, protobuf, psutil, pulp, pydantic, pydantic-settings, pynvml, pyzmq, sentencepiece, setuptools, soundfile, StrEnum, tensorrt, tiktoken, torch, torchvision, transformers, triton, uvicorn, wheel, xgrammar
Required-by: 
---
Name: tensorrt
Version: 10.11.0.33
Summary: A high performance deep learning inference library
Home-page: https://github.com/nvidia/tensorrt
Author: NVIDIA Corporation
Author-email: 
License: Proprietary
Location: /usr/local/lib/python3.12/dist-packages
Requires: 
Required-by: tensorrt_llm
---
Name: torch
Version: 2.8.0a0+5228986c39.nv25.6
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /usr/local/lib/python3.12/dist-packages
Requires: filelock, fsspec, jinja2, networkx, setuptools, sympy, typing-extensions
Required-by: accelerate, flash_attn, flashinfer-python, lightning-thunder, nvidia-modelopt, nvidia-resiliency-ext, optimum, peft, tensorrt_llm, torchprofile, torchvision, transformer_engine, xgrammar

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running intranode 1P1D (one prefill, one decode worker) for DeepSeek V3 FP4 with MTP. Once the servers are warmed up, I send a single prompt of ~100k tokens, and the prefill worker crashes after about 5 seconds.

Prefill:

cat >./prefill-extra-llm-api-config.yml<<EOF
    enable_iter_perf_stats: true
    print_iter_log: false
    cuda_graph_config:
        max_batch_size: 32
        enable_padding: false
    moe_config:
        backend: TRTLLM
        max_num_tokens: 32768
    speculative_config:
        decoding_type: MTP
        num_nextn_predict_layers: 3
        #use_relaxed_acceptance_for_thinking: true
        # relaxed_topk: 3
        # relaxed_delta: 0.6
    disable_overlap_scheduler: true
    enable_autotuner: false
    kv_cache_config:
        free_gpu_memory_fraction: 0.4
        enable_block_reuse: true
        enable_partial_reuse: false
    enable_chunked_prefill: true
    scheduler_config:
        context_chunking_policy: EQUAL_PROGRESS
    cache_transceiver_config:
        backend: UCX
        max_tokens_in_buffer: 32768
EOF


export TORCHDYNAMO_DISABLE=1

trtllm-serve "${MODEL_NAME}" \
  --host 0.0.0.0 \
  --port "$PORT" \
  --backend pytorch \
  --max_batch_size 32 \
  --max_num_tokens 32768 \
  --max_seq_len 162000 \
  --tp_size 4 --ep_size 1 \
  --extra_llm_api_options ./prefill-extra-llm-api-config.yml \
  --log_level info

Decode:

cat >./decode-extra-llm-api-config.yml<<EOF
    enable_iter_perf_stats: true
    print_iter_log: false
    cuda_graph_config:
        max_batch_size: 32
        enable_padding: false
    moe_config:
        backend: TRTLLM
        max_num_tokens: 32768
    speculative_config:
        decoding_type: MTP
        num_nextn_predict_layers: 3
        #use_relaxed_acceptance_for_thinking: true
        # relaxed_topk: 3
        # relaxed_delta: 0.6
    disable_overlap_scheduler: false
    enable_autotuner: false
    kv_cache_config:
        free_gpu_memory_fraction: 0.5
        enable_block_reuse: true
        enable_partial_reuse: false
    enable_chunked_prefill: true
    cache_transceiver_config:
        backend: UCX
        max_tokens_in_buffer: 32768
EOF

export TORCHDYNAMO_DISABLE=1

trtllm-serve "${MODEL_NAME}" \
  --host 0.0.0.0 \
  --port "$PORT" \
  --backend pytorch \
  --max_batch_size 32 \
  --max_num_tokens 32768 \
  --max_seq_len 162000 \
  --tp_size 4 --ep_size 1 \
  --extra_llm_api_options ./decode-extra-llm-api-config.yml \
  --log_level info

Orchestrator:

cat >./orchestrator-config.yml<<EOF
    hostname: localhost
    port: 8000
    backend: pytorch
    context_servers:
        num_instances: ${PREFILL_COUNT}
        urls:
${PREFILL_URL_LINES}
    generation_servers:
        num_instances: ${DECODE_COUNT}
        urls:
${DECODE_URL_LINES}
EOF

trtllm-serve disaggregated -c orchestrator-config.yml
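To make the repro concrete, here is a minimal sketch of how the long-prompt request can be built. This assumes the orchestrator exposes the usual OpenAI-compatible `/v1/chat/completions` endpoint on port 8000; the model name and the one-word-per-token approximation are placeholders, not taken from the original report.

```python
import json

def build_long_prompt_payload(model: str, approx_tokens: int = 100_000) -> dict:
    """Build an OpenAI-style chat payload whose prompt is roughly
    `approx_tokens` tokens long (assuming ~1 token per repeated word)."""
    prompt = " ".join(["hello"] * approx_tokens)
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 64,
        "stream": False,
    }

# Hypothetical model name; substitute your served ${MODEL_NAME}.
payload = build_long_prompt_payload("deepseek-v3-fp4")
body = json.dumps(payload)
# POST `body` to http://localhost:8000/v1/chat/completions to trigger the crash.
```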

Expected behavior

The prefill worker should process the 100k-token prompt without crashing.

Actual behavior

[TensorRT-LLM][WARNING] CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:0 < targetNum:1, may use dynamic buffer, it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may be degraded
[TensorRT-LLM][WARNING] CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:0 < targetNum:1, may use dynamic buffer, it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may be degraded
[TensorRT-LLM][WARNING] CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:0 < targetNum:1, may use dynamic buffer, it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may be degraded
[TensorRT-LLM][WARNING] CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:0 < targetNum:1, may use dynamic buffer, it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may be degraded
[TensorRT-LLM][ERROR] tensorrt_llm::common::TllmException: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaFreeAsync(ptr, mCudaStream->get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmBuffers.h:132)
1       0x7f9a6b427875 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b33875) [0x7f9a6b427875]
2       0x7f9a6b84ffd0 virtual thunk to tensorrt_llm::runtime::GenericTensor<tensorrt_llm::runtime::CudaAllocatorAsync>::~GenericTensor() + 144
3       0x7f9a6b7f0df2 std::vector<std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::allocator<std::shared_ptr<tensorrt_llm::runtime::ITensor> > >::~vector() + 194
4       0x7f9a6b4473dc /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b533dc) [0x7f9a6b4473dc]
5       0x7f9a6c273145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
6       0x7f9a6c26c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
7       0x7f9a6c24db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
8       0x7f9f4c83eed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f9f4c83eed3]
9       0x7f9a6c26dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
10      0x7f9cc1352db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f9cc1352db4]
11      0x7f9f4c839aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f9f4c839aa4]
12      0x7f9f4c8c6c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f9f4c8c6c3c]
[TensorRT-LLM][ERROR] Exception in sendAndRemoveResponse: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaStreamSynchronize(get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaStream.h:83)
1       0x7f9a929d6b5b void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 139
2       0x7f9a6c2535f7 tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession&) + 1543
3       0x7f9a6c273145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
4       0x7f9a6c26c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
5       0x7f9a6c24db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
6       0x7f9f4c83eed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f9f4c83eed3]
7       0x7f9a6c26dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
8       0x7f9cc1352db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f9cc1352db4]
9       0x7f9f4c839aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f9f4c839aa4]
10      0x7f9f4c8c6c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f9f4c8c6c3c] 
[TensorRT-LLM][ERROR] tensorrt_llm::common::TllmException: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaFreeAsync(ptr, mCudaStream->get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmBuffers.h:132)
1       0x7fbfd4a27875 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b33875) [0x7fbfd4a27875]
2       0x7fbfd4e4ffd0 virtual thunk to tensorrt_llm::runtime::GenericTensor<tensorrt_llm::runtime::CudaAllocatorAsync>::~GenericTensor() + 144
3       0x7fbfd4df0df2 std::vector<std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::allocator<std::shared_ptr<tensorrt_llm::runtime::ITensor> > >::~vector() + 194
4       0x7fbfd4a473dc /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b533dc) [0x7fbfd4a473dc]
5       0x7fbfd5873145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
6       0x7fbfd586c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
7       0x7fbfd584db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
8       0x7fc4ae249ed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7fc4ae249ed3]
9       0x7fbfd586dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
10      0x7fc22ad54db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7fc22ad54db4]
11      0x7fc4ae244aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7fc4ae244aa4]
12      0x7fc4ae2d1c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7fc4ae2d1c3c]
[TensorRT-LLM][ERROR] Exception in sendAndRemoveResponse: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaStreamSynchronize(get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaStream.h:83)
1       0x7fbffbfd6b5b void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 139
2       0x7fbfd58535f7 tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession&) + 1543
3       0x7fbfd5873145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
4       0x7fbfd586c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
5       0x7fbfd584db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
6       0x7fc4ae249ed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7fc4ae249ed3]
7       0x7fbfd586dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
8       0x7fc22ad54db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7fc22ad54db4]
9       0x7fc4ae244aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7fc4ae244aa4]
10      0x7fc4ae2d1c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7fc4ae2d1c3c] 
[TensorRT-LLM][ERROR] tensorrt_llm::common::TllmException: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaFreeAsync(ptr, mCudaStream->get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmBuffers.h:132)
1       0x7eff5a227875 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b33875) [0x7eff5a227875]
2       0x7eff5a64ffd0 virtual thunk to tensorrt_llm::runtime::GenericTensor<tensorrt_llm::runtime::CudaAllocatorAsync>::~GenericTensor() + 144
3       0x7eff5a5f0df2 std::vector<std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::allocator<std::shared_ptr<tensorrt_llm::runtime::ITensor> > >::~vector() + 194
4       0x7eff5a2473dc /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b533dc) [0x7eff5a2473dc]
5       0x7eff5b073145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
6       0x7eff5b06c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
7       0x7eff5b04db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
8       0x7f043360fed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f043360fed3]
9       0x7eff5b06dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
10      0x7f01b0546db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f01b0546db4]
11      0x7f043360aaa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f043360aaa4]
12      0x7f0433697c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f0433697c3c]
[TensorRT-LLM][ERROR] Exception in sendAndRemoveResponse: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaStreamSynchronize(get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaStream.h:83)
1       0x7eff817d6b5b void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 139
2       0x7eff5b0535f7 tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession&) + 1543
3       0x7eff5b073145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
4       0x7eff5b06c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
5       0x7eff5b04db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
6       0x7f043360fed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f043360fed3]
7       0x7eff5b06dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
8       0x7f01b0546db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f01b0546db4]
9       0x7f043360aaa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f043360aaa4]
10      0x7f0433697c3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f0433697c3c] 
[TensorRT-LLM][ERROR] tensorrt_llm::common::TllmException: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaFreeAsync(ptr, mCudaStream->get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmBuffers.h:132)
1       0x7f4314827875 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b33875) [0x7f4314827875]
2       0x7f4314c4ffd0 virtual thunk to tensorrt_llm::runtime::GenericTensor<tensorrt_llm::runtime::CudaAllocatorAsync>::~GenericTensor() + 144
3       0x7f4314bf0df2 std::vector<std::shared_ptr<tensorrt_llm::runtime::ITensor>, std::allocator<std::shared_ptr<tensorrt_llm::runtime::ITensor> > >::~vector() + 194
4       0x7f43148473dc /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b533dc) [0x7f43148473dc]
5       0x7f4315673145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
6       0x7f431566c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
7       0x7f431564db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
8       0x7f47ee4b5ed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f47ee4b5ed3]
9       0x7f431566dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
10      0x7f456a7a2db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f456a7a2db4]
11      0x7f47ee4b0aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f47ee4b0aa4]
12      0x7f47ee53dc3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f47ee53dc3c]
[TensorRT-LLM][ERROR] Exception in sendAndRemoveResponse: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaStreamSynchronize(get()): an illegal memory access was encountered (/src/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaStream.h:83)
1       0x7f433bdd6b5b void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 139
2       0x7f43156535f7 tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession&) + 1543
3       0x7f4315673145 tensorrt_llm::batch_manager::DataResponder::Impl::response() + 2213
4       0x7f431566c81d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void> >::_M_invoke(std::_Any_data const&) + 45
5       0x7f431564db5d std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) + 45
6       0x7f47ee4b5ed3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ed3) [0x7f47ee4b5ed3]
7       0x7f431566dba8 std::__future_base::_Async_state_impl<std::thread::_Invoker<std::tuple<void (tensorrt_llm::batch_manager::DataResponder::Impl::*)() noexcept, tensorrt_llm::batch_manager::DataResponder::Impl*> >, void>::_M_run() + 248
8       0x7f456a7a2db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7f456a7a2db4]
9       0x7f47ee4b0aa4 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9caa4) [0x7f47ee4b0aa4]
10      0x7f47ee53dc3c /usr/lib/x86_64-linux-gnu/libc.so.6(+0x129c3c) [0x7f47ee53dc3c] 

Additional notes

When I warm up the server by sending queries of increasing length (1000, 10k, 20k, 40k, 80k, 100k tokens), the bug does not appear. I suspect it has to do with attention workspace reallocation.
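The warm-up workaround described above can be sketched as follows. This is a hypothetical helper, not from the original report: it only generates dummy prompts at the stated lengths (again approximating ~1 token per repeated word); actually sending them to the server is left to the caller.

```python
# Progressively longer prompts, so any attention workspace reallocation
# happens in several small jumps rather than one large one.
WARMUP_TOKEN_LENGTHS = [1_000, 10_000, 20_000, 40_000, 80_000, 100_000]

def warmup_prompts(lengths=WARMUP_TOKEN_LENGTHS):
    """Yield (target_length, prompt) pairs, one dummy prompt per length."""
    for n in lengths:
        yield n, " ".join(["hello"] * n)

prompts = list(warmup_prompts())
# Send each prompt in order before issuing the real 100k-token request.
```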

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.

Metadata

Assignees: No one assigned

Labels: Customized kernels<NV> (specialized/modified CUDA kernels in TRT-LLM, beyond standard TRT), bug (something isn't working)
