
Error in dispatch_kv_cache_events_thread: 'NoneType' object has no attribute 'new_value' #4666

@Rb-Ach

System Info

Description

When I enable KV-cache events (via a nonzero event_buffer_max_size) and run inference, the background event‐dispatch thread crashes and then the generate() call itself raises a RequestError. The root cause appears to be that one of the event diffs has priority=None, and the serializer assumes data.new_value always exists.
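
The failing path in the traceback reduces to the pattern below. This is a minimal, self-contained sketch (the EventDiff class and its fields are illustrative stand-ins, not the actual tensorrt_llm definitions): a serializer that unconditionally reads data.new_value crashes on a None priority diff, while a None guard would not.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EventDiff:
    """Hypothetical stand-in for the diff objects serialized in tensorrt_llm/_utils.py."""
    old_value: int
    new_value: int


def event_diff_to_json(data: Optional[EventDiff]) -> Optional[dict]:
    # The traceback shows the priority diff arriving as None, so a guard
    # like this (instead of dereferencing data.new_value directly) avoids
    # the AttributeError.
    if data is None:
        return None
    return {"old_value": data.old_value, "new_value": data.new_value}


print(event_diff_to_json(None))              # -> None (guarded)
print(event_diff_to_json(EventDiff(5, 10)))  # -> {'old_value': 5, 'new_value': 10}
```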

Environment
• tensorrt-llm version: 0.20.0rc3
• Python: 3.10.12
• CUDA Toolkit: 12.8
• GPU: NVIDIA A100
• OS: Ubuntu 22.04 LTS

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import KvCacheConfig
from transformers import AutoTokenizer


def main():
    pytorch_config = PyTorchConfig(autotuner_enabled=False, kv_cache_dtype="auto")
    kv_cache_cfg = KvCacheConfig(
        enable_block_reuse=True,
        event_buffer_max_size=1024 * 1024,  # > 0 to spawn the dispatch thread
        host_cache_size=20_000_000_000,
    )
    llm = LLM(
        model="mistralai/Mistral-7B-v0.1",
        pytorch_backend_config=pytorch_config,
        kv_cache_config=kv_cache_cfg,
    )
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
    sampling_params = SamplingParams(temperature=0.001, top_p=0.001, max_tokens=1)

    # single prompt to trigger at least one event
    prompt = "Hello world"
    tokens = tokenizer(prompt, return_tensors=None)["input_ids"]
    llm.generate(tokens, sampling_params=sampling_params)

    # attempt to pull events
    events = llm.get_kv_cache_events(timeout=0.1)
    print(events)

    llm.close()


if __name__ == "__main__":
    main()
```
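
Until this is fixed, a possible workaround (a sketch, based on the comment above that the dispatch thread is only spawned for a nonzero buffer size) is to leave event_buffer_max_size at 0 so the thread never starts; this sidesteps the crash but also disables KV-cache event reporting:

```python
# Workaround sketch: no event buffer -> no dispatch thread -> no serializer crash.
# The trade-off is that llm.get_kv_cache_events() will have nothing to return.
kv_cache_cfg = KvCacheConfig(
    enable_block_reuse=True,
    event_buffer_max_size=0,
    host_cache_size=20_000_000_000,
)
```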

Expected behavior

The generate() call should complete cleanly, and llm.get_kv_cache_events() should return the serialized KV-cache events without crashing the dispatch thread.

Actual behavior

```
:1184: FutureWarning: The cuda.cuda module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.driver module instead.
:1184: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/transformers/utils/hub.py:105: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
2025-05-26 13:18:44,874 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[TensorRT-LLM] TensorRT-LLM version: 0.20.0rc3
→ Loading dataset …
rank 0 using MpiPoolSession to spawn MPI processes
:1184: FutureWarning: The cuda.cuda module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.driver module instead.
:1184: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/transformers/utils/hub.py:105: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
warnings.warn(
2025-05-26 13:18:59,479 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[TensorRT-LLM] TensorRT-LLM version: 0.20.0rc3
[TensorRT-LLM][INFO] Refreshed the MPI local session
[TensorRT-LLM][INFO] Engine version 0.20.0rc3 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Refreshed the MPI local session
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
[TensorRT-LLM][INFO] Rank 0 is using GPU 0
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 2048
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 2048
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 32768
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (32768) * 32
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 0
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 8192 = min(maxSequenceLen - 1, maxNumTokens) since context FMHA and usePackedInput are enabled
[TensorRT-LLM][INFO] TRTGptModel If model type is encoder, maxInputLen would be reset in trtEncoderModel to maxInputLen: min(maxSequenceLen, maxNumTokens).
[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT
[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: None
[TensorRT-LLM][INFO] Loaded engine size: 13854 MiB
[TensorRT-LLM][INFO] Engine load time 10066 ms
[TensorRT-LLM][INFO] Inspecting the engine to identify potential runtime issues...
[TensorRT-LLM][INFO] The profiling verbosity of the engine does not allow this analysis to proceed. Re-build the engine with 'detailed' profiling verbosity to get more diagnostics.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1152.01 MiB for execution context memory.
[TensorRT-LLM][INFO] gatherContextLogits: 0
[TensorRT-LLM][INFO] gatherGenerationLogits: 0
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 13828 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1.49 GB GPU memory for runtime buffers.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 2.10 GB GPU memory for decoder.
[TensorRT-LLM][INFO] Memory usage when calculating max tokens in paged kv cache: total: 39.39 GiB, available: 20.19 GiB, extraCostMemory: 0.00 GiB
[TensorRT-LLM][INFO] Number of blocks in KV cache primary pool: 4653
[TensorRT-LLM][INFO] Number of blocks in KV cache secondary pool: 4768, onboard blocks to primary memory before reuse: true
[TensorRT-LLM][INFO] before Create KVCacheManager cacheTransPreAllocaSize:0
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 1024 [window size=32768]
[TensorRT-LLM][INFO] Number of tokens per block: 32.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 18.18 GiB for max tokens in paged KV cache (148896).
→ Prompt 0: 335 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 1.71it/s]
→ Prompt 1: 22 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 13.04it/s]
→ Prompt 2: 28 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 13.17it/s]
→ Prompt 3: 43 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 13.10it/s]

Processed requests: 100%|██████████| 1/1 [00:00<00:00, 13.38it/s]
→ Prompt 1181: 591 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 4.83it/s]
→ Prompt 1182: 8 tokens
Processed requests: 100%|██████████| 1/1 [00:00<00:00, 13.36it/s]
→ Prompt 1183: 2 tokens
[05/26/2025-13:21:17] [TRT-LLM] [E] worker.py: Error in _iteration_result_task: 'NoneType' object has no attribute 'new_value'
[05/26/2025-13:21:17] [TRT-LLM] [E] Error in thread dispatch_kv_cache_events_thread: 'NoneType' object has no attribute 'new_value'
Traceback (most recent call last):
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/llmapi/utils.py", line 267, in run
    if not task(**self.kwargs):
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 313, in dispatch_kv_cache_events_task
    return self._iteration_result_task(
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 270, in _iteration_result_task
    raise e
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 248, in _iteration_result_task
    res = result_serializer(results)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 315, in <lambda>
    lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 942, in serialize
    return cls.to_json_str(events)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 958, in to_json_str
    "data": event_serialize_func(event.data),
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 1013, in _updated_to_json
    KVCacheEventSerializer._event_diff_to_json(data.priority)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 1020, in _event_diff_to_json
    "new_value": data.new_value,
AttributeError: 'NoneType' object has no attribute 'new_value'

Processed requests: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/rebaia/gpu_overlay/tensorrt/test.py", line 82, in <module>
    main()
  File "/home/rebaia/gpu_overlay/tensorrt/test.py", line 62, in main
    outputs = llm.generate(token_ids, sampling_params=sampling_params)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/llmapi/llm.py", line 253, in generate
    future.result()
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/result.py", line 477, in result
    self._result_step(timeout)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/result.py", line 459, in _result_step
    self._handle_response(response)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/result.py", line 354, in _handle_response
    GenerationResultBase._handle_response(self, response)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/result.py", line 324, in _handle_response
    handler(response.error_msg)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/executor.py", line 260, in _handle_background_error
    raise RequestError(error)
tensorrt_llm.executor.utils.RequestError: 'NoneType' object has no attribute 'new_value'
[05/26/2025-13:21:18] [TRT-LLM] [E] worker.py: Error in _iteration_result_task: 'NoneType' object has no attribute 'new_value'
[05/26/2025-13:21:18] [TRT-LLM] [E] Error in thread dispatch_kv_cache_events_thread: 'NoneType' object has no attribute 'new_value'
Traceback (most recent call last):
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/llmapi/utils.py", line 267, in run
    if not task(**self.kwargs):
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 313, in dispatch_kv_cache_events_task
    return self._iteration_result_task(
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 270, in _iteration_result_task
    raise e
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 248, in _iteration_result_task
    res = result_serializer(results)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/executor/worker.py", line 315, in <lambda>
    lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 942, in serialize
    return cls.to_json_str(events)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 958, in to_json_str
    "data": event_serialize_func(event.data),
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 1013, in _updated_to_json
    KVCacheEventSerializer._event_diff_to_json(data.priority)
  File "/home/rebaia/miniforge3/envs/rt/lib/python3.10/site-packages/tensorrt_llm/_utils.py", line 1020, in _event_diff_to_json
    "new_value": data.new_value,
AttributeError: 'NoneType' object has no attribute 'new_value'
```

Additional notes

The script runs on a node allocated via Slurm, with the following environment setup:

```bash
export LD_LIBRARY_PATH=$HOME/nvidia_compat/usr/local/cuda-12.8/compat
export TORCH_CUDA_ARCH_LIST="8.0;8.6"
#export TLLM_LOG_LEVEL=DEBUG
#export NCCL_DEBUG=INFO
#export TLLM_KV_EVENT_DEBUG=1

mpirun --oversubscribe -n 1 python test.py
```
