From 62996028106d20680e42f3faffdc326a0cc52140 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 2 Oct 2025 07:15:27 +0000 Subject: [PATCH 1/9] Support PaddlePaddle with compatible API and tvm-ffi --- flashinfer/fp4_quantization.py | 9 +++++--- flashinfer/fused_moe/core.py | 41 +++++++++++++++++++++++++--------- flashinfer/jit/core.py | 4 ++++ flashinfer/jit/cpp_ext.py | 5 ++++- flashinfer/utils.py | 21 +++++++++++------ 5 files changed, 59 insertions(+), 21 deletions(-) diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 24adbc76ff..b94a9720a4 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -180,7 +180,8 @@ def fp4_quantize_sm100( - Scale factors tensor with shape determined by layout and sf_vec_size """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) out_val = torch.empty( (*input.shape[:-1], input.shape[-1] // 2), dtype=torch.uint8, @@ -567,9 +568,11 @@ def fp4_quantize( assert input.shape[-1] % sf_vec_size == 0 if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) # get input device sm version - major, minor = get_compute_capability(input.device) + # major, minor = get_compute_capability(input.device) + major, minor = get_compute_capability(input.place) x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100( input, global_scale, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 1b78275719..0f96f5f712 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,6 +20,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi from ..autotuner import ( AutoTuner, @@ -331,11 +335,15 @@ def __init__( use_mxfp8_act_scaling, ) + def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype): + dtype_str = str(dtype).split(".", 1)[-1] + return tvm_ffi.dtype(dtype_str) + if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[instance_key] = module.init( - x_dtype, - weight_dtype, - output_dtype, + paddle_dtype_to_tvm_ffi_dtype(x_dtype), + paddle_dtype_to_tvm_ffi_dtype(weight_dtype), + paddle_dtype_to_tvm_ffi_dtype(output_dtype), use_deepseek_fp8_block_scale, use_w4_group_scaling, use_mxfp8_act_scaling, @@ -433,7 +441,8 @@ def cutlass_fused_moe( enable_pdl: Optional[bool] = None, ) -> List[torch.Tensor]: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) @@ -491,17 +500,22 @@ def cutlass_fused_moe( else moe_runner.fused_moe_runner.run_moe ) num_active_experts_per_node = torch.empty( - (1,), dtype=torch.int32, device=input.device + # (1,), dtype=torch.int32, device=input.device + (1,), + dtype=torch.int32, + device=input.place, ) experts_to_token_score = torch.empty( (fc2_expert_weights.shape[0], input.shape[0]), dtype=torch.float32, - device=input.device, + # device=input.device, + device=input.place, ) active_expert_global_ids = torch.empty( (fc2_expert_weights.shape[0],), dtype=torch.int32, - device=input.device, + # device=input.device, + device=input.place, ) min_latency_output = ( [ @@ -772,7 +786,8 @@ def 
cutlass_fused_moe( ) if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) num_rows = input.shape[0] if min_latency_mode: @@ -781,10 +796,16 @@ def cutlass_fused_moe( output_shape = (num_rows, hidden_size) if output is None: - output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + # output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + output = torch.empty(output_shape, dtype=output_dtype, device=input.place) else: check_shape_dtype_device( - output, output_shape, output_dtype, input.device, "output" + # output, output_shape, output_dtype, input.device, "output" + output, + output_shape, + output_dtype, + input.place, + "output", ) return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe( diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index a51597b9bd..992e401ffe 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,6 +1,10 @@ import dataclasses import logging import os +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi from contextlib import nullcontext from datetime import datetime from pathlib import Path diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index fb0c40c00e..562774abf7 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -10,7 +10,10 @@ from pathlib import Path from typing import List, Optional -import tvm_ffi +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi import torch from . import env as jit_env diff --git a/flashinfer/utils.py b/flashinfer/utils.py index e015010c83..2ea1a155e1 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -16,13 +16,12 @@ import functools import math +import os from enum import Enum from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union import torch import torch.version -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version from .jit.spdlog import gen_spdlog_module @@ -243,6 +242,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @functools.cache def get_compute_capability(device: torch.device) -> Tuple[int, int]: + return torch.device.cuda.get_device_capability(device.gpu_device_id()) if device.type != "cuda": raise ValueError("device must be a cuda device") return torch.cuda.get_device_capability(device.index) @@ -261,7 +261,13 @@ def _check_cached_qkv_data_type( ) -if TorchVersion(torch_version) < TorchVersion("2.4"): +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + + +if use_paddle_compatible_api() or torch.torch_version.TorchVersion( + torch.torch_version.__version__ +) < torch.torch_version.TorchVersion("2.4"): def register_custom_op( name: str, @@ -516,7 +522,7 @@ def check_shape_dtype_device( expected_device: Optional[torch.device], name: str, ) -> None: - if expected_shape and x.shape != torch.Size(expected_shape): + if expected_shape and tuple(x.shape) != torch.Size(expected_shape): raise ValueError( f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}" ) @@ -524,7 +530,8 @@ def check_shape_dtype_device( raise ValueError( f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}" ) - if expected_device and x.device != expected_device: + # if expected_device and x.device != expected_device: + if 
expected_device and x.place != expected_device: raise ValueError( f"Invalid device of {name}: expected {expected_device}, got {x.device}" ) @@ -560,8 +567,8 @@ def set_log_level(lvl_str: str) -> None: @functools.cache def device_support_pdl(device: torch.device) -> bool: - if device.type != "cuda": - return False + # if device.type != "cuda": + # return False major, _ = get_compute_capability(device) return major >= 9 From eb73e797b593a88169f7bc52f31d9ee21b1647f4 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 13 Oct 2025 02:36:22 +0000 Subject: [PATCH 2/9] remove torch from requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a4e391c38d..c993a6b70c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,5 @@ nvidia-ml-py packaging>=24.2 requests tabulate -torch +# torch tqdm From a038d3856dc821283dea6215cf291254b99a3140 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 23 Oct 2025 05:01:38 +0000 Subject: [PATCH 3/9] remove changes about import tvm_ffi --- flashinfer/fused_moe/core.py | 8 +++----- flashinfer/jit/core.py | 4 ---- flashinfer/jit/cpp_ext.py | 5 +---- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 0f96f5f712..2dab7ca49a 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,10 +20,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi from ..autotuner import ( AutoTuner, @@ -335,7 +331,9 @@ def __init__( use_mxfp8_act_scaling, ) - def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype): + def paddle_dtype_to_tvm_ffi_dtype(dtype): + import tvm_ffi + dtype_str = str(dtype).split(".", 1)[-1] return tvm_ffi.dtype(dtype_str) diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 992e401ffe..a51597b9bd 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,10 +1,6 @@ import dataclasses import logging import os -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi from contextlib import nullcontext from datetime import datetime from pathlib import Path diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 562774abf7..fb0c40c00e 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -10,10 +10,7 @@ from pathlib import Path from typing import List, Optional -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi +import tvm_ffi import torch from . 
import env as jit_env From fbefce5f7b7422a4045ab303a31323a81911e1d5 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 2 Oct 2025 07:15:27 +0000 Subject: [PATCH 4/9] Support PaddlePaddle with compatible API and tvm-ffi --- flashinfer/fp4_quantization.py | 9 +++++--- flashinfer/fused_moe/core.py | 41 +++++++++++++++++++++++++--------- flashinfer/jit/core.py | 4 ++++ flashinfer/jit/cpp_ext.py | 5 ++++- flashinfer/utils.py | 21 +++++++++++------ 5 files changed, 59 insertions(+), 21 deletions(-) diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index 29127f06ac..6f238382c4 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -180,7 +180,8 @@ def fp4_quantize_sm100( - Scale factors tensor with shape determined by layout and sf_vec_size """ if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) out_val = torch.empty( (*input.shape[:-1], input.shape[-1] // 2), dtype=torch.uint8, @@ -669,9 +670,11 @@ def fp4_quantize( assert input.shape[-1] % sf_vec_size == 0 if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) # get input device sm version - major, minor = get_compute_capability(input.device) + # major, minor = get_compute_capability(input.device) + major, minor = get_compute_capability(input.place) x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100( input, global_scale, diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2ce2a8b6d0..e7e8d7cda9 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,6 +20,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi from ..autotuner import ( AutoTuner, @@ -350,11 +354,15 @@ def __init__( ) self.activation_type = activation_type + def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype): + dtype_str = str(dtype).split(".", 1)[-1] + return tvm_ffi.dtype(dtype_str) + if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[instance_key] = module.init( - x_dtype, - weight_dtype, - output_dtype, + paddle_dtype_to_tvm_ffi_dtype(x_dtype), + paddle_dtype_to_tvm_ffi_dtype(weight_dtype), + paddle_dtype_to_tvm_ffi_dtype(output_dtype), use_deepseek_fp8_block_scale, use_w4_group_scaling, use_mxfp8_act_scaling, @@ -454,7 +462,8 @@ def cutlass_fused_moe( activation_type: ActivationType = ActivationType.Swiglu, ) -> List[torch.Tensor]: if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) @@ -513,17 +522,22 @@ def cutlass_fused_moe( else moe_runner.fused_moe_runner.run_moe ) num_active_experts_per_node = torch.empty( - (1,), dtype=torch.int32, device=input.device + # (1,), dtype=torch.int32, device=input.device + (1,), + dtype=torch.int32, + device=input.place, ) experts_to_token_score = torch.empty( (fc2_expert_weights.shape[0], input.shape[0]), dtype=torch.float32, - device=input.device, + # device=input.device, + device=input.place, ) active_expert_global_ids = torch.empty( (fc2_expert_weights.shape[0],), dtype=torch.int32, - device=input.device, + # device=input.device, + device=input.place, ) 
min_latency_output = ( [ @@ -799,7 +813,8 @@ def cutlass_fused_moe( ) if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.place) num_rows = input.shape[0] if min_latency_mode: @@ -808,10 +823,16 @@ def cutlass_fused_moe( output_shape = (num_rows, hidden_size) if output is None: - output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + # output = torch.empty(output_shape, dtype=output_dtype, device=input.device) + output = torch.empty(output_shape, dtype=output_dtype, device=input.place) else: check_shape_dtype_device( - output, output_shape, output_dtype, input.device, "output" + # output, output_shape, output_dtype, input.device, "output" + output, + output_shape, + output_dtype, + input.place, + "output", ) return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe( diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 2eec7ac2ce..d2fbf5687a 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,6 +1,10 @@ import dataclasses import logging import os +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi from contextlib import nullcontext from datetime import datetime from pathlib import Path diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 2c3a56d92b..0dab4fc584 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -10,7 +10,10 @@ from pathlib import Path from typing import List, Optional -import tvm_ffi +import paddle + +with paddle.compat.use_torch_proxy_guard(enable=False): + import tvm_ffi import torch from . import env as jit_env diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 936d08380c..f48d98dbe3 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -16,13 +16,12 @@ import functools import math +import os from enum import Enum from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union import torch import torch.version -from torch.torch_version import TorchVersion -from torch.torch_version import __version__ as torch_version from .jit.spdlog import gen_spdlog_module @@ -249,6 +248,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @functools.cache def get_compute_capability(device: torch.device) -> Tuple[int, int]: + return torch.device.cuda.get_device_capability(device.gpu_device_id()) if device.type != "cuda": raise ValueError("device must be a cuda device") return torch.cuda.get_device_capability(device.index) @@ -267,7 +267,13 @@ def _check_cached_qkv_data_type( ) -if TorchVersion(torch_version) < TorchVersion("2.4"): +def use_paddle_compatible_api() -> bool: + return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + + +if use_paddle_compatible_api() or torch.torch_version.TorchVersion( + torch.torch_version.__version__ +) < torch.torch_version.TorchVersion("2.4"): def register_custom_op( name: str, @@ -522,7 +528,7 @@ def check_shape_dtype_device( expected_device: Optional[torch.device], name: str, ) -> None: - if expected_shape and x.shape != torch.Size(expected_shape): + if expected_shape and tuple(x.shape) != torch.Size(expected_shape): raise ValueError( f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}" ) @@ -530,7 +536,8 @@ def check_shape_dtype_device( raise ValueError( f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}" ) - if expected_device and x.device != expected_device: + # if 
expected_device and x.device != expected_device: + if expected_device and x.place != expected_device: raise ValueError( f"Invalid device of {name}: expected {expected_device}, got {x.device}" ) @@ -566,8 +573,8 @@ def set_log_level(lvl_str: str) -> None: @functools.cache def device_support_pdl(device: torch.device) -> bool: - if device.type != "cuda": - return False + # if device.type != "cuda": + # return False major, _ = get_compute_capability(device) return major >= 9 From 95f1bf52a00d82537bb4721dbcf75a90f3a5bf05 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 13 Oct 2025 02:36:22 +0000 Subject: [PATCH 5/9] remove torch from requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a31b6ebdc8..a71e497d28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,5 @@ nvidia-ml-py packaging>=24.2 requests tabulate -torch +# torch tqdm From 955aedf4bd70739b7ee1179ed4addbc0f4121d04 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 23 Oct 2025 05:01:38 +0000 Subject: [PATCH 6/9] remove changes about import tvm_ffi --- flashinfer/fused_moe/core.py | 8 +++----- flashinfer/jit/core.py | 4 ---- flashinfer/jit/cpp_ext.py | 5 +---- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index e7e8d7cda9..14d1170f01 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -20,10 +20,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi from ..autotuner import ( AutoTuner, @@ -354,7 +350,9 @@ def __init__( ) self.activation_type = activation_type - def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype): + def paddle_dtype_to_tvm_ffi_dtype(dtype): + import tvm_ffi + dtype_str = str(dtype).split(".", 1)[-1] return tvm_ffi.dtype(dtype_str) diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index d2fbf5687a..2eec7ac2ce 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,10 +1,6 @@ import dataclasses import logging import os -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi from contextlib import nullcontext from datetime import datetime from pathlib import Path diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 0dab4fc584..2c3a56d92b 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -10,10 +10,7 @@ from pathlib import Path from typing import List, Optional -import paddle - -with paddle.compat.use_torch_proxy_guard(enable=False): - import tvm_ffi +import tvm_ffi import torch from . 
import env as jit_env From 7476b5a28acf36355fbe724620c79cab85b7ee41 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 4 Nov 2025 11:17:42 +0000 Subject: [PATCH 7/9] remove dtype conversion --- flashinfer/fused_moe/core.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 14d1170f01..a34b89597d 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -350,17 +350,11 @@ def __init__( ) self.activation_type = activation_type - def paddle_dtype_to_tvm_ffi_dtype(dtype): - import tvm_ffi - - dtype_str = str(dtype).split(".", 1)[-1] - return tvm_ffi.dtype(dtype_str) - if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[instance_key] = module.init( - paddle_dtype_to_tvm_ffi_dtype(x_dtype), - paddle_dtype_to_tvm_ffi_dtype(weight_dtype), - paddle_dtype_to_tvm_ffi_dtype(output_dtype), + x_dtype, + weight_dtype, + output_dtype, use_deepseek_fp8_block_scale, use_w4_group_scaling, use_mxfp8_act_scaling, From 71fce0a27aabbe70dbeb3f429a091687c66f68cf Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 14 Nov 2025 11:37:56 +0800 Subject: [PATCH 8/9] 1.fix fp4gemm pd api dismatch 2.fix cudagraph error of nvfp4 3.fix attention pd api dismatch --- flashinfer/autotuner.py | 2 +- flashinfer/decode.py | 68 +++++++++---------- flashinfer/fp4_quantization.py | 33 +++++----- flashinfer/gemm.py | 116 ++++++++++++++++----------------- flashinfer/prefill.py | 102 ++++++++++++++--------------- flashinfer/quantization.py | 8 +-- flashinfer/utils.py | 4 +- 7 files changed, 168 insertions(+), 165 deletions(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 902f659a1d..e8e1816e32 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -741,7 +741,7 @@ def _create_tensor_like( but with dimensions specified by the dims parameter. 
""" dtype = origin_tensor.dtype - device = origin_tensor.device + device = origin_tensor.place shapes = [] for d in dims: if isinstance(d, StaticDim): diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 467152af38..aa98ee54e0 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -320,7 +320,7 @@ def single_decode_with_kv_cache_with_jit_module( window_left: int = -1, return_lse: bool = False, ): - device = q.device + device = q.place tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, device) o = torch.empty_like(q) if return_lse: @@ -483,7 +483,7 @@ def single_decode_with_kv_cache( """ _check_pos_encoding_mode(pos_encoding_mode) _check_kv_layout(kv_layout) - tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) + tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.place) head_dim = q.shape[-1] if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -501,7 +501,7 @@ def single_decode_with_kv_cache( lse = None if return_lse: - lse = torch.empty((num_qo_heads,), dtype=torch.float32, device=q.device) + lse = torch.empty((num_qo_heads,), dtype=torch.float32, device=q.place) if use_tensor_cores: out = torch.empty_like(q.unsqueeze(0)) @@ -527,7 +527,7 @@ def single_decode_with_kv_cache( TensorLayout[kv_layout].value, window_left, None, # packed_custom_mask - _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + _get_cache_alibi_slopes_buf(num_qo_heads, q.place), logits_soft_cap, sm_scale, None, # scale_q, not supported yet @@ -557,7 +557,7 @@ def single_decode_with_kv_cache( tmp, out, lse, - _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + _get_cache_alibi_slopes_buf(num_qo_heads, q.place), TensorLayout[kv_layout].value, window_left, logits_soft_cap, @@ -722,7 +722,7 @@ def __init__( self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device + self.device = float_workspace_buffer.place self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) @@ -730,7 +730,7 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True, - device="cpu", + device="cuda", ) self._kv_lens_buffer: Optional[torch.Tensor] = None if backend == "trtllm-gen": @@ -771,7 +771,7 @@ def __init__( self._qo_indptr_buf = torch.arange( self._fixed_batch_size + 1, dtype=torch.int32, - device=float_workspace_buffer.device, + device=float_workspace_buffer.place, ) self._backend = backend @@ -803,7 +803,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, - device="cpu", + device="cuda", pin_memory=True, ) @@ -934,7 +934,7 @@ def plan( last_page_len, non_blocking=non_blocking ) self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=(indices.device == self.device) and non_blocking + indices, non_blocking=(indices.place == self.device) and non_blocking ) else: self._paged_kv_indptr_buf = indptr.to( @@ -1224,7 +1224,7 @@ def run( * logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``. 
""" if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) if self._kv_layout == "NHD": page_size = k_cache.shape[1] @@ -1262,17 +1262,17 @@ def run( if return_lse: if lse is None: lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + (q.size(0), q.size(1)), dtype=torch.float32, device=q.place ) else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.size(0), q.size(1)), torch.float32, q.place, "lse" ) if out is None: out = torch.empty_like(q) else: - check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out") + check_shape_dtype_device(out, q.shape, q.dtype, q.place, "out") if self._backend == "trtllm-gen": q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2)) @@ -1303,7 +1303,7 @@ def run( run_args += [ None, # packed_custom_mask None, # mask_indptr_buf - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(q.shape[1], q.place), None, # maybe_prefix_len_ptr None, # maybe_token_pos_in_items_ptr None, # maybe_max_item_len_ptr @@ -1356,7 +1356,7 @@ def run( run_args.extend(list(args)) else: run_args += [ - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(q.shape[1], q.place), logits_soft_cap, sm_scale, rope_scale, @@ -1530,7 +1530,7 @@ def __init__( Only needed when ``use_cuda_graph`` is ``True``. """ self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device + self.device = float_workspace_buffer.place self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) @@ -1538,7 +1538,7 @@ def __init__( (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True, - device="cpu", + device="cuda", ) if use_cuda_graph: @@ -1596,7 +1596,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, - device="cpu", + device="cuda", pin_memory=True, ) @@ -1779,7 +1779,7 @@ def run( """ # MLA decode kernel supports SM80 only - major, minor = get_compute_capability(q_nope.device) + major, minor = get_compute_capability(q_nope.place) device_arch = major * 10 + minor if device_arch != 80: raise GPUArchitectureError( @@ -1807,7 +1807,7 @@ def run( out = torch.empty_like(q_nope, device=device) else: check_shape_dtype_device( - out, q_nope.shape, q_nope.dtype, q_nope.device, "out" + out, q_nope.shape, q_nope.dtype, q_nope.place, "out" ) if return_lse: @@ -1822,7 +1822,7 @@ def run( lse, (q_nope.size(0), q_nope.size(1)), q_nope.dtype, - q_nope.device, + q_nope.place, "lse", ) self._cached_module.run( @@ -1883,7 +1883,7 @@ def _paged_run( if out is None: out = torch.empty_like(query) if self._sm_count is None: - self._sm_count = get_device_sm_count(query.device) + self._sm_count = get_device_sm_count(query.place) bmm1_scale = ( bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale @@ -2125,7 +2125,7 @@ def trtllm_batch_decode_with_kv_cache( out : Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. 
""" - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + enable_pdl = device_support_pdl(query.place) if enable_pdl is None else enable_pdl if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache @@ -2141,7 +2141,7 @@ def trtllm_batch_decode_with_kv_cache( k_cache, v_cache = kv_cache.unbind(dim=1) run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): assert query.dtype == torch.float8_e4m3fn, ( @@ -2169,10 +2169,10 @@ def trtllm_batch_decode_with_kv_cache( round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), ) out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.place ) o_sf_start_index = 0 - out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.place) else: raise ValueError(f"Invalid out: {out}") @@ -2180,13 +2180,13 @@ def trtllm_batch_decode_with_kv_cache( assert isinstance(out, torch.Tensor) # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") + check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.place, "out") check_shape_dtype_device( out_scale_factor, fp4_out_scale_shape, torch.float8_e4m3fn, - query.device, + query.place, "out_scale_factor", ) @@ -2211,7 +2211,7 @@ def trtllm_batch_decode_with_kv_cache( out = out if out is not None else torch.empty_like(query, dtype=out_dtype) if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): raise ValueError(f"Unsupported out_dtype: {out_dtype}") - check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") + check_shape_dtype_device(out, query.shape, out_dtype, query.place, "out") else: raise ValueError(f"Invalid out_dtype: {out_dtype}") @@ -2349,9 +2349,9 @@ def trtllm_batch_decode_with_kv_cache_mla( - Currently, only fp8 tensor core operation supports this mode. When both are provided, the dynamic scale factor tensors will be used. 
""" - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + enable_pdl = device_support_pdl(query.place) if enable_pdl is None else enable_pdl run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) block_size = kv_cache.size(-2) if ( @@ -2371,14 +2371,14 @@ def trtllm_batch_decode_with_kv_cache_mla( if out is None: out_shape = query.shape[:-1] + (kv_lora_rank,) - out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) + out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.place) else: batch_size, _, num_q_heads, _ = query.shape check_shape_dtype_device( out, [batch_size, num_q_heads, kv_lora_rank], torch.bfloat16, - query.device, + query.place, "out", ) diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index b94a9720a4..be2d21c8d9 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -185,17 +185,20 @@ def fp4_quantize_sm100( out_val = torch.empty( (*input.shape[:-1], input.shape[-1] // 2), dtype=torch.uint8, - device=input.device, + device=input.place, ) - m = input.numel() // input.shape[-1] - k = input.shape[-1] + # m = input.numel() // input.shape[-1] + # k = input.shape[-1] + m = torch.shape(input)[0] + k = torch.shape(input)[1] if is_sf_swizzled_layout: out_sf_size = _compute_swizzled_layout_sf_size( m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128 ) else: out_sf_size = m * k // sf_vec_size - out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) + out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.place) + module.fp4_quantize( input, global_scale, @@ -235,7 +238,7 @@ def mxfp4_dequantize_host( out = torch.empty( (weight.shape[0], weight.shape[1] * 2), dtype=torch.float32, - device=weight.device, + device=weight.place, ) module.mxfp4_dequantize_host( weight, @@ -277,7 +280,7 @@ def block_scale_interleave_sm100( out = torch.empty( (num_experts * expert_out_size,), dtype=torch.uint8, - device=unswizzled_sf.device, + device=unswizzled_sf.place, ) module.block_scale_interleave_sm100(unswizzled_sf, out) return out @@ -341,12 +344,12 @@ def fp4_batched_quantize_sm100( out_val = torch.empty( (b, m, k // 2), dtype=torch.uint8, - device=input.device, + device=input.place, ) out_sf = torch.empty( (b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)), dtype=torch.uint8, - device=input.device, + device=input.place, ) module.fp4_batched_quantize( input, @@ -423,12 +426,12 @@ def silu_and_mul_nvfp4_batched_quantize_sm100( out_val = torch.empty( (b, m, k // 4), dtype=torch.uint8, - device=input.device, + device=input.place, ) out_sf = torch.empty( (b, _compute_swizzled_layout_sf_size(m, k // (2 * sf_vec_size), 128)), dtype=torch.uint8, - device=input.device, + device=input.place, ) module.silu_and_mul_nvfp4_batched_quantize( input, @@ -610,7 +613,7 @@ def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: f"Input dtype must be uint8, got {unswizzled_sf.dtype}" ) - major, minor = get_compute_capability(unswizzled_sf.device) + major, minor = get_compute_capability(unswizzled_sf.place) device_arch = f"{major * 10 + minor}" return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100( @@ -670,7 +673,7 @@ def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch. 
""" row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - return input_tensor[row_indices.to(input_tensor.device)] + return input_tensor[row_indices.to(input_tensor.place)] def shuffle_matrix_sf_a( @@ -691,7 +694,7 @@ def shuffle_matrix_sf_a( row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + w_shuffled = input_tensor[row_indices.to(input_tensor.place)] # 128x4 return block_scale_interleave(w_shuffled) @@ -798,7 +801,7 @@ def mxfp4_dequantize(a_fp4, a_sf): return e2m1_and_ufp8sf_scale_to_float( a_fp4.cpu().view(torch.uint8), a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), + torch.tensor([1.0], device=a_fp4.place), 32, 0, True, @@ -853,7 +856,7 @@ def nvfp4_batched_quantize( - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - Scale factors tensor with shape determined by layout and sf_vec_size """ - major, minor = get_compute_capability(a.device) + major, minor = get_compute_capability(a.place) device_arch = f"{major * 10 + minor}" a_fp4, a_sf = get_fp4_quantization_module(device_arch).fp4_batched_quantize_sm100( a, diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index ffb2e4825f..3f9a562a68 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -357,8 +357,8 @@ def fp8_gemm_sm100( runners = [] # No e5m2 for cutlass is_e5m2 = a.dtype == torch.float8_e5m2 or b.dtype == torch.float8_e5m2 - is_sm_supported = _match_sm_version(a.device, ["100", "103", "110"]) - is_sm120_supported = _match_sm_version(a.device, ["120", "121"]) + is_sm_supported = _match_sm_version(a.place, ["100", "103", "110"]) + is_sm120_supported = _match_sm_version(a.place, ["120", "121"]) if "cutlass" in runner_names and not is_e5m2: if is_sm_supported: @@ -558,7 +558,7 @@ def forward( # from [m,k]x[k,n]+[n,] to [n,k]x[k,m]+[n,] gemm_fn = module.tgv_gemm c = torch.empty( - (a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device + (a.shape[0], b.shape[1]), dtype=a.dtype, device=a.place ) gemm_fn(b.t(), a.t(), bias, tactic, c, pdl) return c.t() @@ -601,7 +601,7 @@ def tgv_gemm_sm100( - Tensor b is expected to be in column-major layout (transposed from typical PyTorch row-major) """ # Verify SM100 architecture support - if not _match_sm_version(a.device, ["100", "103"]): + if not _match_sm_version(a.place, ["100", "103"]): raise ValueError("TGV GEMM requires SM100, SM103 architecture") # Verify dtype support @@ -616,7 +616,7 @@ def tgv_gemm_sm100( ) runners = [] - use_sm_100f = is_sm100f_supported(a.device) + use_sm_100f = is_sm100f_supported(a.place) runners.append(get_tgv_gemm_sm10x_module(a.dtype, use_sm_100f).tgv_gemm_runner()) tuner = AutoTuner.get() @@ -718,7 +718,7 @@ def launch_compute_sm80_group_gemm_args( seg_indptr: torch.Tensor, weight_indices: Optional[torch.Tensor] = None, ): - device = x.device + device = x.place prob_type = torch.int32 # problem sizes -> int ptr_type = torch.int64 # pointers -> int64_t ld_type = torch.int64 # strides -> int64_t @@ -780,7 +780,7 @@ def launch_compute_sm90_group_gemm_args( seg_indptr: torch.Tensor, weight_indices: Optional[torch.Tensor] = None, ): - device = x.device + device = x.place prob_type = torch.int32 # problem sizes -> int ptr_type = torch.int64 # pointers -> int64_t stride_type = torch.int64 # strides -> int64_t @@ -896,7 +896,7 @@ def __init__( segment GEMM kernels. Encouraged size is 128MB. 
""" self._int_workspace_buffer = torch.empty( - (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device + (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.place ) self._float_workspace_buffer = float_workspace_buffer self.backend = backend @@ -1003,18 +1003,18 @@ def run( else: out_dtype = x.dtype out = torch.zeros( - (cumulative_batch_size, d_out), dtype=out_dtype, device=x.device + (cumulative_batch_size, d_out), dtype=out_dtype, device=x.place ) else: if out.shape != (cumulative_batch_size, d_out): raise ValueError( f"Output tensor shape mismatch, expected {cumulative_batch_size, d_out}, got {out.shape}" ) - empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device) - empty_y_data = torch.empty(0, dtype=out.dtype, device=out.device) + empty_x_data = torch.empty(0, dtype=x.dtype, device=x.place) + empty_y_data = torch.empty(0, dtype=out.dtype, device=out.place) if self.backend == "auto": - backend = determine_gemm_backend(x.device) + backend = determine_gemm_backend(x.place) else: backend = self.backend @@ -1308,10 +1308,10 @@ def execute_cudnn_gemm_fp4_graph( if workspace_buffer.numel() < graph.get_workspace_size(): workspace_buffer = torch.empty( - graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + graph.get_workspace_size(), device=a.place, dtype=torch.uint8 ) - stream = torch.cuda.current_stream(a.device) + stream = torch.cuda.current_stream(a.place) graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) @@ -1408,12 +1408,12 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( UIDs.O_UID.value: c_final, } - stream = torch.cuda.current_stream(a.device) + stream = torch.cuda.current_stream(a.place) cudnn_handle = _get_cudnn_handle(stream) if workspace.numel() < graph.get_workspace_size(): workspace = torch.empty( - graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + graph.get_workspace_size(), device=a.place, dtype=torch.uint8 ) graph.execute(variant_pack, workspace, handle=cudnn_handle) @@ -1451,7 +1451,7 @@ def _cudnn_gemm_fp8( _torch_data_type_to_cudnn_data_type(a.dtype), _torch_data_type_to_cudnn_data_type(b.dtype), _torch_data_type_to_cudnn_data_type(torch_out_dtype), - a.device, + a.place, ) execute_cudnn_gemm_with_per_tensor_q_graph( @@ -1615,7 +1615,7 @@ def mm_fp8( ) out = torch.empty( (m, n), - device=a.device, + device=a.place, dtype=out_dtype, ) else: @@ -1628,9 +1628,9 @@ def mm_fp8( raise ValueError( f"Output shape mismatch. Expected {a.shape[0], b.shape[1]}, got {out.shape}." ) - if out.device != a.device: + if out.place != a.place: raise ValueError( - f"Output device mismatch. Expected {a.device}, got {out.device}." + f"Output device mismatch. Expected {a.place}, got {out.place}." 
) if out_dtype is not None and out.dtype != out_dtype: raise ValueError( @@ -1746,8 +1746,8 @@ def mm_fp4( ) if alpha is not None and alpha.dtype != torch.float: raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") - if alpha is not None and alpha.numel() != 1: - raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") + # if alpha is not None and alpha.numel() != 1: + # raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") if out_dtype not in (torch.bfloat16, torch.float16): raise ValueError( @@ -1761,14 +1761,14 @@ def mm_fp4( raise ValueError("mxfp4 supports block_size = 32.") if backend != "trtllm" and use_8x4_sf_layout: raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): + if backend == "trtllm" and _match_sm_version(a.place, ["110"]): raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.") if backend != "cudnn" and not use_nvfp4: raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") if ( backend == "cudnn" and not use_nvfp4 - and _match_sm_version(a.device, ["120"]) + and _match_sm_version(a.place, ["120"]) and cudnn.backend_version() < 91400 ): raise LibraryError( @@ -1779,12 +1779,12 @@ def mm_fp4( if out is None: out = torch.empty( (a.shape[0], b.shape[1]), - device=a.device, + device=a.place, dtype=out_dtype, ) workspace_buffer = _get_cache_buf( - "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "mm_fp4_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) if backend == "cudnn": @@ -1816,7 +1816,7 @@ def mm_fp4( cudnn.data_type.FP4_E2M1, _torch_data_type_to_cudnn_data_type(out_dtype), block_size, - a.device, + a.place, alpha is not None, use_nvfp4, ) @@ -1850,7 +1850,7 @@ def mm_fp4( b_descale = b_descale.view(torch.uint8) # Dispatch to the correct module based on device architecture - major, _ = get_compute_capability(a.device) + major, _ = get_compute_capability(a.place) if major == 12: gemm_module = get_gemm_sm120_module_cutlass_fp4() else: @@ -1931,12 +1931,12 @@ def bmm_fp8( if out is None: out = torch.empty( (A.shape[0], A.shape[1], B.shape[2]), - device=A.device, + device=a.place, dtype=dtype, ) workspace_buffer = _get_cache_buf( - "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device + "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) if backend == "cudnn": @@ -2029,11 +2029,11 @@ def gemm_fp8_nt_groupwise( ----- The ``m`` should be padded to a multiple of 4 before calling this function, to accommodate the kernel's requirement. """ - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): + if backend == "trtllm" and _match_sm_version(a.place, ["110"]): raise ValueError("TRTLLM FP8 GEMM is not supported on SM110.") workspace_buffer = _get_cache_buf( - "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "gemm_fp8_nt_groupwise_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) if a.ndim != 2 or b.ndim != 2: raise ValueError(f"Shape mismatch. a.shape = {a.shape}, b.shape = {b.shape}") @@ -2060,24 +2060,24 @@ def gemm_fp8_nt_groupwise( out = torch.empty( a.shape[0], b.shape[0], - device=a.device, + device=a.place, dtype=out_dtype, ) if backend == "cutlass": - if not _match_sm_version(a.device, ["100", "103", "110", "120", "121"]): + if not _match_sm_version(a.place, ["100", "103", "110", "120", "121"]): raise ValueError( "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121 in cutlass backend." 
) elif backend == "trtllm": - if not _match_sm_version(a.device, ["100", "103"]): + if not _match_sm_version(a.place, ["100", "103"]): raise ValueError( "gemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend." ) if backend == "cutlass": assert scale_major_mode is not None - if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): + if is_sm120a_supported(a.place) or is_sm121a_supported(a.place): # SM120/121 doesn't use mma_sm parameter get_gemm_sm120_module().gemm_fp8_nt_groupwise( workspace_buffer, @@ -2089,7 +2089,7 @@ def gemm_fp8_nt_groupwise( *scale_granularity_mnk, scale_major_mode, ) - elif is_sm100a_supported(a.device): + elif is_sm100a_supported(a.place): get_gemm_sm100_module().gemm_fp8_nt_groupwise( workspace_buffer, a, @@ -2102,7 +2102,7 @@ def gemm_fp8_nt_groupwise( mma_sm, ) else: - raise ValueError(f"Unsupported device for FP8 GEMM: {a.device}") + raise ValueError(f"Unsupported device for FP8 GEMM: {a.place}") elif backend == "trtllm": assert scale_granularity_mnk == (1, 128, 128) assert a.shape[1] >= 256 @@ -2357,23 +2357,23 @@ def group_gemm_fp8_nt_groupwise( to accommodate the kernel's requirement. """ if ( - not is_sm100a_supported(a.device) - and not is_sm120a_supported(a.device) - and not is_sm121a_supported(a.device) + not is_sm100a_supported(a.place) + and not is_sm120a_supported(a.place) + and not is_sm121a_supported(a.place) ): raise ValueError( "gemm_fp8_nt_groupwise is only supported on SM100, SM120, and SM121." ) - if not (_match_sm_version(a.device, ["100", "103", "110", "120", "121"])): + if not (_match_sm_version(a.place, ["100", "103", "110", "120", "121"])): raise ValueError( "gemm_fp8_nt_groupwise is only supported on SM100, SM103, SM110, SM120, or SM121." ) int_workspace_buffer = _get_cache_buf( - "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_fp8_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) float_workspace_buffer = _get_cache_buf( - "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_fp8_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] @@ -2405,12 +2405,12 @@ def group_gemm_fp8_nt_groupwise( out_shape = (a.shape[0], n) if out is None: - out = torch.empty(out_shape, dtype=out_dtype, device=a.device) + out = torch.empty(out_shape, dtype=out_dtype, device=a.place) else: assert out.shape == out_shape assert out.dtype == out_dtype - if is_sm120a_supported(a.device) or is_sm121a_supported(a.device): + if is_sm120a_supported(a.place) or is_sm121a_supported(a.place): # it has correctness issues for num_groups > 1 if num_groups > 1: raise RuntimeError( @@ -2431,7 +2431,7 @@ def group_gemm_fp8_nt_groupwise( *scale_granularity_mnk, scale_major_mode, ) - elif is_sm100a_supported(a.device): + elif is_sm100a_supported(a.place): get_gemm_sm100_module().group_gemm_fp8_nt_groupwise( int_workspace_buffer, float_workspace_buffer, @@ -2449,7 +2449,7 @@ def group_gemm_fp8_nt_groupwise( ) else: raise ValueError( - f"group_gemm_fp8_nt_groupwise requires SM100, SM120, or SM121, but got {a.device}" + f"group_gemm_fp8_nt_groupwise requires SM100, SM120, or SM121, but got {a.place}" ) return out @@ -2523,12 +2523,12 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( to accommodate the kernel's requirement. 
""" int_workspace_buffer = _get_cache_buf( - "group_gemm_mxfp4_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.device + "group_gemm_mxfp4_nt_groupwise_int_workspace", DEFAULT_WORKSPACE_SIZE, a.place ) float_workspace_buffer = _get_cache_buf( "group_gemm_mxfp4_nt_groupwise_float_workspace", DEFAULT_WORKSPACE_SIZE, - a.device, + a.place, ) assert a.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] @@ -2563,7 +2563,7 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( out_shape = (a.shape[0], n) if out is None: - out = torch.empty(out_shape, dtype=out_dtype, device=a.device) + out = torch.empty(out_shape, dtype=out_dtype, device=a.place) else: assert out.shape == out_shape assert out.dtype == out_dtype @@ -2600,12 +2600,12 @@ def pad_indptr_to_multiple_of_4( batch_size = m_indptr.shape[0] - 1 m = m_indptr[1:] - m_indptr[:-1] m = m + 3 - (m + 3) % 4 - padded_m_indptr = torch.cat((torch.zeros((1,), device=m.device, dtype=m.dtype), m)) + padded_m_indptr = torch.cat((torch.zeros((1,), device=m.place, dtype=m.dtype), m)) padded_m_indptr = padded_m_indptr.cumsum(dim=0, dtype=padded_m_indptr.dtype) - m_rank = torch.zeros((m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.device) + m_rank = torch.zeros((m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.place) padded_m_rank = torch.zeros( - (m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.device + (m_indptr[-1],), dtype=m_indptr.dtype, device=m_indptr.place ) compute_padding_mapping[(batch_size,)]( @@ -2735,14 +2735,14 @@ def group_deepgemm_fp8_nt_groupwise( """ from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_contiguous - if not _match_sm_version(a.device, ["100", "103"]): + if not _match_sm_version(a.place, ["100", "103"]): raise ValueError( "m_grouped_fp8_gemm_nt_contiguous is only supported on SM100, SM100, SM103." ) if out is None: out_dtype = out_dtype or torch.bfloat16 - out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device) + out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.place) m_grouped_fp8_gemm_nt_contiguous( (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk @@ -2868,7 +2868,7 @@ def batch_deepgemm_fp8_nt_groupwise( """ from flashinfer.deep_gemm import m_grouped_fp8_gemm_nt_masked - if not _match_sm_version(a.device, ["100", "103"]): + if not _match_sm_version(a.place, ["100", "103"]): raise ValueError( "m_grouped_fp8_gemm_nt_masked is only supported on SM100, SM103." 
) @@ -2876,7 +2876,7 @@ def batch_deepgemm_fp8_nt_groupwise( if out is None: out_dtype = out_dtype or torch.bfloat16 out = torch.empty( - a.shape[0], a.shape[1], b.shape[1], dtype=out_dtype, device=a.device + a.shape[0], a.shape[1], b.shape[1], dtype=out_dtype, device=a.place ) m_grouped_fp8_gemm_nt_masked( diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index ac73ca9871..ed6dcaa1d2 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -201,7 +201,7 @@ def _paged_run( out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) if out is None: out = torch.empty_like(query) bmm1_scale = ( @@ -884,7 +884,7 @@ def single_prefill_with_kv_cache_with_jit_module( window_left: int = -1, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - device = q.device + device = q.place tmp = _get_cache_buf( "single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, device=device ) @@ -1102,7 +1102,7 @@ def single_prefill_with_kv_cache( """ _check_pos_encoding_mode(pos_encoding_mode) _check_kv_layout(kv_layout) - tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) + tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.place) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -1127,7 +1127,7 @@ def single_prefill_with_kv_cache( lse = None if return_lse: - lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device) + lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.place) if is_float8(q): # FP8 quant enabled, do sanity check: @@ -1137,15 +1137,15 @@ def single_prefill_with_kv_cache( assert q.dtype == k.dtype == v.dtype assert q.shape[-1] == k.shape[-1] == v.shape[-1] if scale_q is None: - scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.device) + scale_q = torch.ones(q.shape[1], dtype=torch.float32, device=q.place) if scale_k is None: - scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.device) + scale_k = torch.ones(k.shape[1], dtype=torch.float32, device=q.place) if scale_v is None: - scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.device) + scale_v = torch.ones(v.shape[1], dtype=torch.float32, device=q.place) if backend == "auto": backend = determine_attention_backend( - q.device, + q.place, PosEncodingMode[pos_encoding_mode].value, use_fp16_qk_reduction, packed_custom_mask is not None, # use_custom_mask @@ -1156,7 +1156,7 @@ def single_prefill_with_kv_cache( # o_dtype should be provided for FP8 attention if o_dtype is None: o_dtype = q.dtype - out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.device) + out = torch.empty(q.shape[:-1] + v.shape[-1:], dtype=o_dtype, device=q.place) module = get_single_prefill_module( backend, @@ -1182,7 +1182,7 @@ def single_prefill_with_kv_cache( TensorLayout[kv_layout].value, window_left, packed_custom_mask, - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(q.shape[1], q.place), logits_soft_cap, sm_scale, scale_q, @@ -1423,7 +1423,7 @@ def __init__( self._float_workspace_buffer.numel() * self._float_workspace_buffer.element_size() ) - self.device = float_workspace_buffer.device + self.device = float_workspace_buffer.place self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None if backend in ["fa3", "auto", "trtllm-gen"]: # NOTE(Zihao): assume maximum accumulate kv length is 
16M @@ -1444,7 +1444,7 @@ def __init__( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, - device="cpu", + device="cuda", pin_memory=True, ) self._use_cuda_graph = use_cuda_graph @@ -1516,7 +1516,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, - device="cpu", + device="cuda", pin_memory=True, ) @@ -1765,7 +1765,7 @@ def plan( ) self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_( paged_kv_indices, - non_blocking=(paged_kv_indices.device == self.device) and non_blocking, + non_blocking=(paged_kv_indices.place == self.device) and non_blocking, ) if packed_custom_mask is not None: @@ -1779,7 +1779,7 @@ def plan( ) self._custom_mask_buf[: len(packed_custom_mask)].copy_( packed_custom_mask, - non_blocking=(packed_custom_mask.device == self.device) + non_blocking=(packed_custom_mask.place == self.device) and non_blocking, ) # NOTE(Zihao): mask_indptr has the same length as qo_indptr @@ -1844,7 +1844,7 @@ def plan( vector_sparse_indptr_host = torch.cat( [ torch.tensor( - [0], dtype=torch.int32, device=kv_lens_arr_host.device + [0], dtype=torch.int32, device=kv_lens_arr_host.place ), torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), ], @@ -2036,7 +2036,7 @@ def run( * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) _check_cached_qkv_data_type( q, k_cache, self._cached_q_data_type, self._cached_kv_data_type @@ -2072,20 +2072,20 @@ def run( if return_lse: if lse is None: lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + (q.size(0), q.size(1)), dtype=torch.float32, device=q.place ) else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.size(0), q.size(1)), torch.float32, q.place, "lse" ) if out is None: out = torch.empty( - q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device + q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.place ) else: check_shape_dtype_device( - out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" + out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.place, "out" ) if self._custom_mask_buf is not None: @@ -2169,7 +2169,7 @@ def run( run_args += [ self._custom_mask_buf, self._mask_indptr_buf, - _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_cache_alibi_slopes_buf(q.shape[1], q.place), self._prefix_len_ptr, self._token_pos_in_items_ptr, self._max_item_len_ptr, @@ -2420,7 +2420,7 @@ def __init__( self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer - self.device = float_workspace_buffer.device + self.device = float_workspace_buffer.place self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) @@ -2428,7 +2428,7 @@ def __init__( self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True, - device="cpu", + device="cuda", ) self._use_cuda_graph = use_cuda_graph if use_cuda_graph: @@ -2482,7 +2482,7 @@ def reset_workspace_buffer( self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, - device="cpu", + device="cuda", pin_memory=True, ) @@ -2727,7 +2727,7 @@ def plan( if self._backend == 
"cutlass": # insert qo_indptr.device to 9th position (0-indexed) of get_module_args new_get_module_args = ( - get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:] + get_module_args[:9] + (qo_indptr.place,) + get_module_args[9:] ) self._cached_module = get_fmha_module(*new_get_module_args) else: @@ -2872,7 +2872,7 @@ def run( * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. """ if enable_pdl is None: - enable_pdl = device_support_pdl(q.device) + enable_pdl = device_support_pdl(q.place) _check_cached_qkv_data_type( q, k, self._cached_q_data_type, self._cached_kv_data_type ) @@ -2893,19 +2893,19 @@ def run( if return_lse: if lse is None: lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + (q.size(0), q.size(1)), dtype=torch.float32, device=q.place ) else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, (q.size(0), q.size(1)), torch.float32, q.place, "lse" ) if out is None: out = torch.empty( - q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device + q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.place ) else: check_shape_dtype_device( - out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out" + out, q.shape[:-1] + v.shape[-1:], q.dtype, q.place, "out" ) if self._backend == "cutlass": out, lse = fmha_varlen( @@ -3017,19 +3017,19 @@ def fmha_varlen_plan( causal: bool, ): num_ctas = torch.cuda.get_device_properties( - qo_segment_offsets.device + qo_segment_offsets.place ).multi_processor_count work_indptr = torch.empty( - num_ctas + 1, device=qo_segment_offsets.device, dtype=torch.int32 + num_ctas + 1, device=qo_segment_offsets.place, dtype=torch.int32 ) qo_tile_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 + 131072, device=qo_segment_offsets.place, dtype=torch.int32 ) head_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 + 131072, device=qo_segment_offsets.place, dtype=torch.int32 ) batch_indices = torch.empty( - 131072, device=qo_segment_offsets.device, dtype=torch.int32 + 131072, device=qo_segment_offsets.place, dtype=torch.int32 ) module.plan( qo_segment_offsets, @@ -3100,7 +3100,7 @@ def fmha_varlen( return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: workspace_buffer = _get_cache_buf( - "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.device + "fmha_varlen_cutlass_workspace", 32 * 1024 * 1024, q.place ) module = get_fmha_module( q.dtype, @@ -3112,7 +3112,7 @@ def fmha_varlen( PosEncodingMode.NONE.value, False, # use_sliding_window False, # use_logits_soft_cap - q.device, + q.place, ) nnz_qo, num_qo_heads, head_dim_qk = q.shape @@ -3143,13 +3143,13 @@ def fmha_varlen( qo_total_len + max(max_qo_len, 128), num_qo_heads, head_dim_vo, - device=q.device, + device=q.place, dtype=q.dtype, )[max(max_qo_len, 128) :] if lse is None and return_lse: lse = torch.empty( - qo_total_len, num_qo_heads, device=q.device, dtype=torch.float32 + qo_total_len, num_qo_heads, device=q.place, dtype=torch.float32 ) module.run( @@ -3261,23 +3261,23 @@ def trtllm_ragged_attention_deepseek( ) if enable_pdl is None: - enable_pdl = device_support_pdl(query.device) + enable_pdl = device_support_pdl(query.place) run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) if out is None: out = torch.empty( query.shape[0], query.shape[1], value.shape[2], - device=query.device, + 
device=query.place, dtype=query.dtype, ) if return_lse and lse is None: lse = torch.empty( query.shape[0], query.shape[1], - device=query.device, + device=query.place, dtype=torch.float32, ) @@ -3381,7 +3381,7 @@ def trtllm_batch_context_with_kv_cache( """ if enable_pdl is None: - enable_pdl = device_support_pdl(query.device) + enable_pdl = device_support_pdl(query.place) if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache @@ -3397,7 +3397,7 @@ def trtllm_batch_context_with_kv_cache( k_cache, v_cache = kv_cache.unbind(dim=1) run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context - sm_count = get_device_sm_count(query.device) + sm_count = get_device_sm_count(query.place) if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)): assert query.dtype == torch.float8_e4m3fn, ( @@ -3425,10 +3425,10 @@ def trtllm_batch_context_with_kv_cache( round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4), ) out_scale_factor = torch.empty( - fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device + fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.place ) o_sf_start_index = 0 - out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device) + out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.place) else: raise ValueError(f"Invalid out: {out}") @@ -3436,13 +3436,13 @@ def trtllm_batch_context_with_kv_cache( assert isinstance(out, torch.Tensor) # Use uint8 as the container dtype to compliant with next fp4 gemm. - check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out") + check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.place, "out") check_shape_dtype_device( out_scale_factor, fp4_out_scale_shape, torch.float8_e4m3fn, - query.device, + query.place, "out_scale_factor", ) @@ -3467,7 +3467,7 @@ def trtllm_batch_context_with_kv_cache( out = out if out is not None else torch.empty_like(query, dtype=out_dtype) if out_dtype not in (query.dtype, torch.float16, torch.bfloat16): raise ValueError(f"Unsupported out_dtype: {out_dtype}") - check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out") + check_shape_dtype_device(out, query.shape, out_dtype, query.place, "out") else: raise ValueError(f"Invalid out_dtype: {out_dtype}") diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index 810b1f2ae1..8778b22e75 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -30,7 +30,7 @@ def get_quantization_module(): @register_custom_op("flashinfer::packbits", mutates_args=()) def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: - device = x.device + device = x.place x = x.to(torch.bool) y = torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=device) get_quantization_module().packbits(x, bitorder, y) @@ -39,7 +39,7 @@ def _packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: @register_fake_op("flashinfer::packbits") def _fake_packbits(x: torch.Tensor, bitorder: str) -> torch.Tensor: - return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.device) + return torch.empty((x.size(0) + 7) // 8, dtype=torch.uint8, device=x.place) def packbits(x: torch.Tensor, bitorder: str = "big") -> torch.Tensor: @@ -124,11 +124,11 @@ def segment_packbits( """ seglen = indptr[1:] - indptr[:-1] packed_len = (seglen + 7) // 8 - indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.device) + indptr_new = torch.zeros(len(indptr), dtype=indptr.dtype, device=indptr.place) indptr_new[1:] = torch.cumsum(packed_len, 0) 
output_nnzs = indptr_new[-1].item() - device = x.device + device = x.place indptr = indptr.to(torch.int32) indptr_new = indptr_new.to(torch.int32) y = torch.empty(output_nnzs, dtype=torch.uint8, device=device) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 2ea1a155e1..63d84c46c2 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -182,7 +182,7 @@ def _unpack_paged_kv_cache( def get_alibi_slopes(n_heads: int) -> torch.Tensor: n = 2 ** math.floor(math.log2(n_heads)) - m_0 = 2.0 ** (-8.0 / n) + m_0 = torch.tensor(2.0 ** (-8.0 / n)) m = torch.pow(m_0, torch.arange(1, 1 + n)) if n < n_heads: m_hat_0 = 2.0 ** (-4.0 / n) @@ -262,7 +262,7 @@ def _check_cached_qkv_data_type( def use_paddle_compatible_api() -> bool: - return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + return os.environ.get("PADDLE_COMPATIBLE_API", "1").lower() in ["1", "on", "true"] if use_paddle_compatible_api() or torch.torch_version.TorchVersion( From 93eeb2fa7de8e705f7ddaa1bc69639c5afda9a02 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 14 Nov 2025 12:05:33 +0800 Subject: [PATCH 9/9] fix cudagraph --- flashinfer/gemm.py | 4 ++-- flashinfer/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 009e288067..4a49bfc4fe 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1728,8 +1728,8 @@ def _check_mm_fp4_problem_size( ) if alpha is not None and alpha.dtype != torch.float: raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") - if alpha is not None and alpha.numel() != 1: - raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") + # if alpha is not None and alpha.numel() != 1: + # raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") if out_dtype not in (torch.bfloat16, torch.float16): raise ValueError( diff --git a/flashinfer/utils.py b/flashinfer/utils.py index f3ee8e2871..1ad26fd698 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -268,7 +268,7 @@ def _check_cached_qkv_data_type( def use_paddle_compatible_api() -> bool: - return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"] + return os.environ.get("PADDLE_COMPATIBLE_API", "1").lower() in ["1", "on", "true"] if use_paddle_compatible_api() or torch.torch_version.TorchVersion( @@ -1014,7 +1014,7 @@ def wrapper(*args, **kwargs): if tensor_arg is not None: # Get compute capability from the first tensor # Assume all tensors are on the same device/capability - major, minor = get_compute_capability(tensor_arg.device) + major, minor = get_compute_capability(tensor_arg.place) capability = major * 10 + minor if not is_backend_supported(backend, capability):
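A note on the recurring `.device` -> `.place` substitution in the hunks above: Paddle tensors (and the tensors handed out by the compatible API) expose their placement through `.place` rather than `.device`, which is why every call site feeding a device into `device_support_pdl`, `get_device_sm_count`, `torch.empty`, or `check_shape_dtype_device` is rewritten. Purely as an illustration of that pattern, and not something this patch introduces, the attribute choice could be centralized in one accessor. The helper name `_tensor_device` below is hypothetical; its environment check simply mirrors the `use_paddle_compatible_api()` gate touched in flashinfer/utils.py.

    import os

    def _tensor_device(t):
        # Hypothetical helper, not part of this patch: return whichever
        # placement attribute the active frontend actually provides, so
        # call sites need not hard-code `.place` or `.device`.
        if os.environ.get("PADDLE_COMPATIBLE_API", "1").lower() in ["1", "on", "true"]:
            return t.place   # Paddle / compatible-API tensors carry a Place
        return t.device      # plain PyTorch tensors carry a torch.device

    # Example usage (sketch): device_support_pdl(query.place) could instead be
    # written as device_support_pdl(_tensor_device(query)) and stay valid for
    # both frontends.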