2 changes: 1 addition & 1 deletion flashinfer/autotuner.py
@@ -741,7 +741,7 @@ def _create_tensor_like(
but with dimensions specified by the dims parameter.
"""
dtype = origin_tensor.dtype
device = origin_tensor.device
device = origin_tensor.device
shapes = []
for d in dims:
if isinstance(d, StaticDim):
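For context on the hunk above: `_create_tensor_like` mirrors the origin tensor's dtype and device while taking its shape from the `dims` spec. A minimal sketch of that pattern, using hypothetical `StaticDim`/`DynamicDim` stand-ins (their field names are assumptions for illustration, not the autotuner's real classes):

```python
import torch
from dataclasses import dataclass

@dataclass
class StaticDim:       # assumed: dimension fixed at plan time
    val: int

@dataclass
class DynamicDim:      # assumed: dimension bounded by a runtime maximum
    max_val: int

def create_tensor_like(origin: torch.Tensor, dims) -> torch.Tensor:
    """Allocate an empty tensor with origin's dtype/device and a dims-resolved shape."""
    shape = [d.val if isinstance(d, StaticDim) else d.max_val for d in dims]
    return torch.empty(shape, dtype=origin.dtype, device=origin.device)
```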
68 changes: 34 additions & 34 deletions flashinfer/decode.py
@@ -320,7 +320,7 @@ def single_decode_with_kv_cache_with_jit_module(
window_left: int = -1,
return_lse: bool = False,
):
device = q.device
device = q.device
tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, device)
o = torch.empty_like(q)
if return_lse:
@@ -483,7 +483,7 @@ def single_decode_with_kv_cache(
"""
_check_pos_encoding_mode(pos_encoding_mode)
_check_kv_layout(kv_layout)
tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
head_dim = q.shape[-1]
if logits_soft_cap is None:
logits_soft_cap = 0.0
@@ -501,7 +501,7 @@ def single_decode_with_kv_cache(

lse = None
if return_lse:
lse = torch.empty((num_qo_heads,), dtype=torch.float32, device=q.device)
lse = torch.empty((num_qo_heads,), dtype=torch.float32, device=q.device)

if use_tensor_cores:
out = torch.empty_like(q.unsqueeze(0))
@@ -527,7 +527,7 @@ def single_decode_with_kv_cache(
TensorLayout[kv_layout].value,
window_left,
None, # packed_custom_mask
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
logits_soft_cap,
sm_scale,
None, # scale_q, not supported yet
@@ -557,7 +557,7 @@ def single_decode_with_kv_cache(
tmp,
out,
lse,
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
TensorLayout[kv_layout].value,
window_left,
logits_soft_cap,
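For reference, `single_decode_with_kv_cache` in the hunks above allocates its scratch and ALiBi-slope buffers on `q`'s device internally; the caller only passes the tensors. A hedged usage sketch (shapes follow the NHD layout described in the docstring; defaults may differ between releases):

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 2048
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Single-request decode attention; scratch buffers come from _get_cache_buf on q's device.
o = flashinfer.single_decode_with_kv_cache(q, k, v, kv_layout="NHD")
print(o.shape)  # (num_qo_heads, head_dim)
```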
@@ -722,15 +722,15 @@ def __init__(

self._kv_layout = kv_layout
self._float_workspace_buffer = float_workspace_buffer
self.device = float_workspace_buffer.device
self.device = float_workspace_buffer.device
self._int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
)
self._pin_memory_int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,),
dtype=torch.uint8,
pin_memory=True,
device="cpu",
device="cpu",
)
self._kv_lens_buffer: Optional[torch.Tensor] = None
if backend == "trtllm-gen":
@@ -771,7 +771,7 @@ def __init__(
self._qo_indptr_buf = torch.arange(
self._fixed_batch_size + 1,
dtype=torch.int32,
device=float_workspace_buffer.device,
device=float_workspace_buffer.device,
)
self._backend = backend

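In the CUDA-graph path above, every decode request contributes exactly one query token, so `_qo_indptr_buf` can be pre-filled once with `arange` over the fixed batch size and reused across graph replays. Illustrative values:

```python
import torch

fixed_batch_size = 4
# qo_indptr[i] is where request i's query tokens start; one token per request in decode.
qo_indptr = torch.arange(fixed_batch_size + 1, dtype=torch.int32, device="cuda")
print(qo_indptr.tolist())  # [0, 1, 2, 3, 4]
```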
@@ -803,7 +803,7 @@ def reset_workspace_buffer(
self._pin_memory_int_workspace_buffer = torch.empty(
self._int_workspace_buffer.shape,
dtype=self._int_workspace_buffer.dtype,
device="cpu",
device="cpu",
pin_memory=True,
)

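Both `__init__` and `reset_workspace_buffer` keep a page-locked host mirror of the int workspace; `pin_memory=True` is only valid for CPU tensors, and pinning is what lets later host-to-device copies of plan metadata run asynchronously. A minimal sketch of that staging pattern (sizes are illustrative):

```python
import torch

device = torch.device("cuda")

# Device-side workspace and its pinned host mirror.
int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
host_staging = torch.empty(
    8 * 1024 * 1024, dtype=torch.uint8, device="cpu", pin_memory=True
)  # pinned (page-locked) memory must live on the host

# Fill plan metadata on the host, then issue an async host-to-device copy.
host_staging[:16].copy_(torch.arange(16, dtype=torch.uint8))
int_workspace[:16].copy_(host_staging[:16], non_blocking=True)
torch.cuda.synchronize()  # make sure the copy has landed before reusing the host buffer
```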
@@ -934,7 +934,7 @@ def plan(
last_page_len, non_blocking=non_blocking
)
self._paged_kv_indices_buf[: len(indices)].copy_(
indices, non_blocking=(indices.device == self.device) and non_blocking
indices, non_blocking=(indices.device == self.device) and non_blocking
)
else:
self._paged_kv_indptr_buf = indptr.to(
@@ -1224,7 +1224,7 @@ def run(
* logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``.
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(q.device)
enable_pdl = device_support_pdl(q.device)
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
if self._kv_layout == "NHD":
page_size = k_cache.shape[1]
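`device_support_pdl` above gates programmatic dependent launch (PDL) on the GPU generation. A plausible minimal implementation, assuming a Hopper (SM90) threshold rather than quoting the library's actual check:

```python
import torch

def device_support_pdl(device: torch.device) -> bool:
    # Assumption: programmatic dependent launch needs Hopper (SM90) or newer.
    if device.type != "cuda":
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9
```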
@@ -1262,17 +1262,17 @@ def run(
if return_lse:
if lse is None:
lse = torch.empty(
(q.size(0), q.size(1)), dtype=torch.float32, device=q.device
(q.size(0), q.size(1)), dtype=torch.float32, device=q.device
)
else:
check_shape_dtype_device(
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
)

if out is None:
out = torch.empty_like(q)
else:
check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")

if self._backend == "trtllm-gen":
q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1303,7 +1303,7 @@ def run(
run_args += [
None, # packed_custom_mask
None, # mask_indptr_buf
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
None, # maybe_prefix_len_ptr
None, # maybe_token_pos_in_items_ptr
None, # maybe_max_item_len_ptr
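`_get_cache_alibi_slopes_buf` caches one slope per query head on the tensor's device. A sketch of the standard ALiBi slope schedule such a buffer would hold, assuming the usual 2^(-8i/n) form for power-of-two head counts:

```python
import torch

def alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    # Geometric schedule: 2^(-8/n), 2^(-16/n), ..., 2^(-8) for i = 1..n.
    i = torch.arange(1, num_heads + 1, dtype=torch.float32, device=device)
    return torch.pow(2.0, -8.0 * i / num_heads)
```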
@@ -1356,7 +1356,7 @@ def run(
run_args.extend(list(args))
else:
run_args += [
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
logits_soft_cap,
sm_scale,
rope_scale,
@@ -1530,15 +1530,15 @@ def __init__(
Only needed when ``use_cuda_graph`` is ``True``.
"""
self._float_workspace_buffer = float_workspace_buffer
self.device = float_workspace_buffer.device
self.device = float_workspace_buffer.device
self._int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
)
self._pin_memory_int_workspace_buffer = torch.empty(
(8 * 1024 * 1024,),
dtype=torch.uint8,
pin_memory=True,
device="cpu",
device="cpu",
)

if use_cuda_graph:
@@ -1596,7 +1596,7 @@ def reset_workspace_buffer(
self._pin_memory_int_workspace_buffer = torch.empty(
self._int_workspace_buffer.shape,
dtype=self._int_workspace_buffer.dtype,
device="cpu",
device="cpu",
pin_memory=True,
)

@@ -1779,7 +1779,7 @@ def run(
"""

# MLA decode kernel supports SM80 only
major, minor = get_compute_capability(q_nope.device)
major, minor = get_compute_capability(q_nope.device)
device_arch = major * 10 + minor
if device_arch != 80:
raise GPUArchitectureError(
@@ -1807,7 +1807,7 @@ def run(
out = torch.empty_like(q_nope, device=device)
else:
check_shape_dtype_device(
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
)

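The SM80 gate above compares `major * 10 + minor` from `get_compute_capability` against 80. A rough equivalent with plain PyTorch, using `RuntimeError` as a stand-in for the library's `GPUArchitectureError`:

```python
import torch

major, minor = torch.cuda.get_device_capability(torch.device("cuda"))
device_arch = major * 10 + minor  # e.g. 80 on A100, 90 on H100
if device_arch != 80:
    raise RuntimeError(f"MLA decode kernel requires SM80, got SM{device_arch}")
```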
if return_lse:
@@ -1822,7 +1822,7 @@ def run(
lse,
(q_nope.size(0), q_nope.size(1)),
q_nope.dtype,
q_nope.device,
q_nope.device,
"lse",
)
self._cached_module.run(
@@ -1883,7 +1883,7 @@ def _paged_run(
if out is None:
out = torch.empty_like(query)
if self._sm_count is None:
self._sm_count = get_device_sm_count(query.device)
self._sm_count = get_device_sm_count(query.device)

bmm1_scale = (
bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
@@ -2125,7 +2125,7 @@ def trtllm_batch_decode_with_kv_cache(
out : Union[torch.Tensor, FP4Tensor]
output torch.Tensor or FP4Tensor.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
@@ -2141,7 +2141,7 @@ def trtllm_batch_decode_with_kv_cache(
k_cache, v_cache = kv_cache.unbind(dim=1)

run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)
sm_count = get_device_sm_count(query.device)

if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)):
assert query.dtype == torch.float8_e4m3fn, (
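As the hunk above shows, `kv_cache` may arrive either as a `(k_cache, v_cache)` tuple or as one stacked tensor whose second dimension separates K and V; `unbind(dim=1)` splits the stacked form into views without copying. Illustrative shapes (the exact paged layout is an assumption here):

```python
import torch

num_pages, page_size, num_kv_heads, head_dim = 64, 16, 8, 128
stacked = torch.randn(
    num_pages, 2, num_kv_heads, page_size, head_dim,
    dtype=torch.float16, device="cuda",
)
k_cache, v_cache = stacked.unbind(dim=1)  # two views, no copy
print(k_cache.shape)  # (num_pages, num_kv_heads, page_size, head_dim)
```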
@@ -2169,24 +2169,24 @@ def trtllm_batch_decode_with_kv_cache(
round_up(query.shape[1] * query.shape[2] // o_sf_vec_size, 4),
)
out_scale_factor = torch.empty(
fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
fp4_out_scale_shape, dtype=torch.float8_e4m3fn, device=query.device
)
o_sf_start_index = 0
out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
out = torch.empty(fp4_out_shape, dtype=torch.uint8, device=query.device)
else:
raise ValueError(f"Invalid out: {out}")

assert out_dtype == "nvfp4"
assert isinstance(out, torch.Tensor)

# Use uint8 as the container dtype to be compliant with the next fp4 gemm.
check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")

check_shape_dtype_device(
out_scale_factor,
fp4_out_scale_shape,
torch.float8_e4m3fn,
query.device,
query.device,
"out_scale_factor",
)

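The nvfp4 path stores one fp8 scale factor per `o_sf_vec_size` output elements and rounds both scale-tensor dimensions up for alignment. A small worked example of that shape arithmetic; the 128-row rounding and `o_sf_vec_size = 16` are assumptions for illustration:

```python
def round_up(x: int, m: int) -> int:
    """Round x up to the nearest multiple of m."""
    return (x + m - 1) // m * m

num_tokens, num_heads, head_dim = 7, 32, 128
o_sf_vec_size = 16  # assumed scale-factor group size
rows = round_up(num_tokens, 128)                           # 128
cols = round_up(num_heads * head_dim // o_sf_vec_size, 4)  # 256
print(rows, cols)
```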
@@ -2211,7 +2211,7 @@ def trtllm_batch_decode_with_kv_cache(
out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
if out_dtype not in (query.dtype, torch.float16, torch.bfloat16):
raise ValueError(f"Unsupported out_dtype: {out_dtype}")
check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
else:
raise ValueError(f"Invalid out_dtype: {out_dtype}")

@@ -2349,9 +2349,9 @@ def trtllm_batch_decode_with_kv_cache_mla(
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)
sm_count = get_device_sm_count(query.device)

block_size = kv_cache.size(-2)
if (
@@ -2371,14 +2371,14 @@ def trtllm_batch_decode_with_kv_cache_mla(

if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
query.device,
"out",
)

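The final hunk either allocates `out` as bfloat16 with shape `[batch_size, num_q_heads, kv_lora_rank]` or validates a caller-provided tensor against that spec. A hedged sketch of the same allocate-or-validate contract (mirroring only what the validation branch checks):

```python
import torch
from typing import Optional

def alloc_or_check_out(
    query: torch.Tensor, kv_lora_rank: int, out: Optional[torch.Tensor]
) -> torch.Tensor:
    batch_size, _, num_q_heads, _ = query.shape
    expected = (batch_size, num_q_heads, kv_lora_rank)
    if out is None:
        return torch.empty(expected, dtype=torch.bfloat16, device=query.device)
    assert tuple(out.shape) == expected
    assert out.dtype == torch.bfloat16 and out.device == query.device
    return out
```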