[Bug]: DeepGEMM does not work with CUDA Graph #19722

@chaunceyjiang

Description

Your current environment

The output of python collect_env.py
Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version                : Could not collect
CMake version                : version 3.22.1
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.7.0+cu126
Is debug build               : False
CUDA used to build PyTorch   : 12.6
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.11 | packaged by Anaconda, Inc. | (main, Jun  5 2025, 13:09:17) [GCC 11.2.0] (64-bit runtime)
Python platform              : Linux-5.15.0-140-generic-x86_64-with-glibc2.35

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 12.6.85
CUDA_MODULE_LOADING set to   : LAZY
GPU models and configuration :
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version        : 565.57.01
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               192
On-line CPU(s) list:                  0-191
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8468
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   48
Socket(s):                            2
Stepping:                             8
Frequency boost:                      enabled
CPU max MHz:                          2101.0000
CPU min MHz:                          800.0000
BogoMIPS:                             4200.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            4.5 MiB (96 instances)
L1i cache:                            3 MiB (96 instances)
L2 cache:                             192 MiB (96 instances)
L3 cache:                             210 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-47,96-143
NUMA node1 CPU(s):                    48-95,144-191
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-cufile-cu12==1.11.1.6
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pyzmq==26.4.0
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchvision==0.22.0
[pip3] transformers==4.52.4
[pip3] triton==3.3.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
Neuron SDK Version           : N/A
vLLM Version                 : 0.9.2.dev57+gc68698b32 (git sha: c68698b32)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
  	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	NIC0	NIC1	NIC2	NIC3	NIC4	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NV18	PIX	NODE	SYS	SYS	SYS	0-47,96-143	0		N/A
GPU1	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NODE	PIX	SYS	SYS	SYS	0-47,96-143	0		N/A
GPU2	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NODE	NODE	SYS	SYS	SYS	0-47,96-143	0		N/A
GPU3	NV18	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NODE	NODE	SYS	SYS	SYS	0-47,96-143	0		N/A
GPU4	NV18	NV18	NV18	NV18	 X 	NV18	NV18	NV18	SYS	SYS	PIX	NODE	NODE	48-95,144-191	1		N/A
GPU5	NV18	NV18	NV18	NV18	NV18	 X 	NV18	NV18	SYS	SYS	NODE	PIX	NODE	48-95,144-191	1		N/A
GPU6	NV18	NV18	NV18	NV18	NV18	NV18	 X 	NV18	SYS	SYS	NODE	NODE	NODE	48-95,144-191	1		N/A
GPU7	NV18	NV18	NV18	NV18	NV18	NV18	NV18	 X 	SYS	SYS	NODE	NODE	PIX	48-95,144-191	1		N/A
NIC0	PIX	NODE	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	SYS	SYS	SYS
NIC1	NODE	PIX	NODE	NODE	SYS	SYS	SYS	SYS	NODE	 X 	SYS	SYS	SYS
NIC2	SYS	SYS	SYS	SYS	PIX	NODE	NODE	NODE	SYS	SYS	 X 	NODE	NODE
NIC3	SYS	SYS	SYS	SYS	NODE	PIX	NODE	NODE	SYS	SYS	NODE	 X 	NODE
NIC4	SYS	SYS	SYS	SYS	NODE	NODE	NODE	PIX	SYS	SYS	NODE	NODE	 X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_bond_0

==============================
     Environment Variables
==============================
LD_LIBRARY_PATH=/opt/ucx/lib:/opt/nixl/lib/x86_64-linux-gnu/:/opt/nvshmem/lib:/usr/local/cuda/lib64:
VLLM_LOGGING_LEVEL=DEBUG
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root
VLLM_WORKER_MULTIPROC_METHOD=spawn
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

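The following command fails during the engine's profiling run, when Dynamo (torch.compile) tries to trace DeepGEMM's `make_2d_tma_desc` helper:

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1 -tp 8 -pp 2 --enable-expert-parallel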
(RayWorkerWrapper pid=3944622)
(RayWorkerWrapper pid=3173774, ip=10.254.20.30) INFO 06-17 02:49:25 [gpu_model_runner.py:1659] Model loading took 41.0417 GiB and 15.405849 seconds
(RayWorkerWrapper pid=3944613) DEBUG 06-17 02:49:26 [decorators.py:204] Start compiling function <code object forward at 0x28ce15c0, file "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 653>
ERROR 06-17 02:49:27 [core.py:516] EngineCore failed to start.
ERROR 06-17 02:49:27 [core.py:516] Traceback (most recent call last):
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/engine/core.py", line 507, in run_engine_core
ERROR 06-17 02:49:27 [core.py:516]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/engine/core.py", line 391, in __init__
ERROR 06-17 02:49:27 [core.py:516]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/engine/core.py", line 83, in __init__
ERROR 06-17 02:49:27 [core.py:516]     self._initialize_kv_caches(vllm_config)
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/engine/core.py", line 141, in _initialize_kv_caches
ERROR 06-17 02:49:27 [core.py:516]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 06-17 02:49:27 [core.py:516]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
ERROR 06-17 02:49:27 [core.py:516]     output = self.collective_rpc("determine_available_memory")
ERROR 06-17 02:49:27 [core.py:516]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/executor/executor_base.py", line 332, in collective_rpc
ERROR 06-17 02:49:27 [core.py:516]     return self._run_workers(method, *args, **(kwargs or {}))
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/executor/ray_distributed_executor.py", line 522, in _run_workers
ERROR 06-17 02:49:27 [core.py:516]     ray_worker_outputs = ray.get(ray_worker_outputs)
ERROR 06-17 02:49:27 [core.py:516]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
ERROR 06-17 02:49:27 [core.py:516]     return fn(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
ERROR 06-17 02:49:27 [core.py:516]     return func(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/ray/_private/worker.py", line 2822, in get
ERROR 06-17 02:49:27 [core.py:516]     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
ERROR 06-17 02:49:27 [core.py:516]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/ray/_private/worker.py", line 930, in get_objects
ERROR 06-17 02:49:27 [core.py:516]     raise value.as_instanceof_cause()
ERROR 06-17 02:49:27 [core.py:516] ray.exceptions.RayTaskError(Unsupported): ray::RayWorkerWrapper.execute_method() (pid=3173771, ip=10.254.20.30, actor_id=33e1551b6ca3537112d3cf8601000000, repr=<vllm.executor.ray_utils.RayWorkerWrapper object at 0x7f9986b9df10>)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/worker/worker_base.py", line 623, in execute_method
ERROR 06-17 02:49:27 [core.py:516]     raise e
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/worker/worker_base.py", line 614, in execute_method
ERROR 06-17 02:49:27 [core.py:516]     return run_method(self, method, args, kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/utils.py", line 2680, in run_method
ERROR 06-17 02:49:27 [core.py:516]     return func(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-17 02:49:27 [core.py:516]     return func(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/worker/gpu_worker.py", line 205, in determine_available_memory
ERROR 06-17 02:49:27 [core.py:516]     self.model_runner.profile_run()
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/worker/gpu_model_runner.py", line 2048, in profile_run
ERROR 06-17 02:49:27 [core.py:516]     hidden_states = self._dummy_run(self.max_num_tokens)
ERROR 06-17 02:49:27 [core.py:516]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-17 02:49:27 [core.py:516]     return func(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/v1/worker/gpu_model_runner.py", line 1883, in _dummy_run
ERROR 06-17 02:49:27 [core.py:516]     outputs = model(
ERROR 06-17 02:49:27 [core.py:516]               ^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-17 02:49:27 [core.py:516]     return self._call_impl(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-17 02:49:27 [core.py:516]     return forward_call(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 714, in forward
ERROR 06-17 02:49:27 [core.py:516]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 06-17 02:49:27 [core.py:516]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/compilation/decorators.py", line 239, in __call__
ERROR 06-17 02:49:27 [core.py:516]     output = self.compiled_callable(*args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
ERROR 06-17 02:49:27 [core.py:516]     raise e.with_traceback(None) from None
ERROR 06-17 02:49:27 [core.py:516] torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
ERROR 06-17 02:49:27 [core.py:516]   Explanation: Dynamo does not know how to trace the builtin `None.cuuint64_t.__new__.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
ERROR 06-17 02:49:27 [core.py:516]   Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
ERROR 06-17 02:49:27 [core.py:516]   Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
ERROR 06-17 02:49:27 [core.py:516]
ERROR 06-17 02:49:27 [core.py:516]   Developer debug context: module: None, qualname: cuuint64_t.__new__, skip reason: <missing reason>
ERROR 06-17 02:49:27 [core.py:516]
ERROR 06-17 02:49:27 [core.py:516]
ERROR 06-17 02:49:27 [core.py:516] from user code:
ERROR 06-17 02:49:27 [core.py:516]    File "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 672, in forward
ERROR 06-17 02:49:27 [core.py:516]     hidden_states, residual = layer(positions, hidden_states, residual)
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 592, in forward
ERROR 06-17 02:49:27 [core.py:516]     hidden_states = self.mlp(hidden_states)
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 160, in forward
ERROR 06-17 02:49:27 [core.py:516]     final_hidden_states = self.experts(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1357, in forward
ERROR 06-17 02:49:27 [core.py:516]     return self.forward_impl(hidden_states, router_logits)
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1449, in forward_impl
ERROR 06-17 02:49:27 [core.py:516]     final_hidden_states = self.quant_method.apply(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/quantization/fp8.py", line 881, in apply
ERROR 06-17 02:49:27 [core.py:516]     return self.fused_experts(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1169, in fused_experts
ERROR 06-17 02:49:27 [core.py:516]     return deep_gemm_moe_fp8(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py", line 222, in deep_gemm_moe_fp8
ERROR 06-17 02:49:27 [core.py:516]     return fn(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 439, in forward
ERROR 06-17 02:49:27 [core.py:516]     self.fused_experts.apply(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/vllm/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py", line 147, in apply
ERROR 06-17 02:49:27 [core.py:516]     dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/DeepGEMM/deep_gemm/jit_kernels/m_grouped_gemm.py", line 74, in m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
ERROR 06-17 02:49:27 [core.py:516]     tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedContiguous, lhs, m, k, k, block_m, block_k, num_groups)
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/DeepGEMM/deep_gemm/jit_kernels/runtime.py", line 96, in make_2d_tma_a_desc
ERROR 06-17 02:49:27 [core.py:516]     return make_2d_tma_desc(t,
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/DeepGEMM/deep_gemm/jit_kernels/runtime.py", line 87, in make_2d_tma_desc
ERROR 06-17 02:49:27 [core.py:516]     gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
ERROR 06-17 02:49:27 [core.py:516]   File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py", line 151, in instantiate_user_defined_class_object
ERROR 06-17 02:49:27 [core.py:516]     obj = cls.__new__(cls, *args, **kwargs)
ERROR 06-17 02:49:27 [core.py:516]
ERROR 06-17 02:49:27 [core.py:516] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
2025-06-17 02:49:27,809	ERROR worker.py:421 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::RayWorkerWrapper.execute_method() (pid=3173777, ip=10.254.20.30, actor_id=126e75153843d5454ae507a801000000, repr=<vllm.executor.ray_utils.RayWorkerWrapper object at 0x7f87ffe81cd0>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/worker/worker_base.py", line 623, in execute_method
    raise e
  File "/data/kebe/vllm/vllm/worker/worker_base.py", line 614, in execute_method
    return run_method(self, method, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/utils.py", line 2680, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/v1/worker/gpu_worker.py", line 205, in determine_available_memory
    self.model_runner.profile_run()
  File "/data/kebe/vllm/vllm/v1/worker/gpu_model_runner.py", line 2048, in profile_run
    hidden_states = self._dummy_run(self.max_num_tokens)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/v1/worker/gpu_model_runner.py", line 1883, in _dummy_run
    outputs = model(
              ^^^^^^
  File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/model_executor/models/deepseek_v2.py", line 714, in forward
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/vllm/vllm/compilation/decorators.py", line 239, in __call__
    output = self.compiled_callable(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/kebe/conda/envs/vllm-dev/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
  Explanation: Dynamo does not know how to trace the builtin `None.cuuint64_t.__new__.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).

It works correctly when `--enforce-eager` is added, which avoids the torch.compile path where the error above is raised:

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1 vllm serve /data/deepseek-ai/DeepSeek-R1 -tp 8 -pp 2 --enable-expert-parallel --enforce-eager
INFO 06-17 03:06:48 [launcher.py:37] Route: /v1/rerank, Methods: POST
INFO 06-17 03:06:48 [launcher.py:37] Route: /v2/rerank, Methods: POST
INFO 06-17 03:06:48 [launcher.py:37] Route: /invocations, Methods: POST
INFO 06-17 03:06:48 [launcher.py:37] Route: /metrics, Methods: GET
INFO:     Started server process [3967552]
INFO:     Waiting for application startup.
INFO:     Application startup complete.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests