diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 775ef628d2d..18da8b76b07 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -29,6 +29,7 @@ class CacheConfig:
     """A dataclass to hold information how to configure the cache."""

     dtype: Optional[torch.dtype] = None
+    mamba_dtype: Optional[torch.dtype] = None


 class SequenceInfo:
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
index ccd24e7ec00..0908e7c9fb1 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
@@ -325,7 +325,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx)
+        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
         return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

     @classmethod
@@ -339,6 +339,9 @@ def get_cache_initializers(
         num_heads = hs_fake.shape[-2]
         head_dim = hs_fake.shape[-1]

+        # dtype from node itself
+        dtype = source_attn_node.meta["val"].dtype
+
         # Infer state size by assuming B has shape [b, s, n_groups * ssm_state_size]
         # During runtime we pass [b, s, n_groups, ssm_state_size]; both give the same last dim product.
         if B_fake.ndim >= 4:
@@ -354,7 +357,7 @@ def _get_ssm_cache(si: SequenceInfo):
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=cache_config.mamba_dtype or dtype,
             )

         return {"ssm_state_cache": _get_ssm_cache}
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index 64b62419162..630702895da 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -1,27 +1,14 @@
-from typing import List, Tuple
+from typing import List

 import torch
-from torch._ops import OpOverloadPacket
-from torch.fx import Node

 # Triton kernels
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
 from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
 from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

-from ...utils.node_utils import extract_op_args
-from ..attention_interface import (
-    AttentionDescriptor,
-    AttentionLayout,
-    AttentionRegistry,
-    BufferInitializerDict,
-    CacheConfig,
-    CacheInitializerDict,
-    Constant,
-    MHACallable,
-    PrepareMetadataCallable,
-    SequenceInfo,
-)
+from ..attention_interface import AttentionRegistry, MHACallable
+from .torch_backend_mamba import TorchBackendSSM


 @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
@@ -202,70 +189,7 @@ def _triton_cached_ssm_fake(


 @AttentionRegistry.register("triton_ssm")
-class TritonBackendSSM(AttentionDescriptor):
-    @classmethod
-    def is_paged(cls) -> bool:
-        return True
-
-    @classmethod
-    def get_attention_layout(cls) -> AttentionLayout:
-        # Hidden states follow [b, s, n, d]
-        return "bsnd"
-
-    @classmethod
-    def get_num_qkv_args(cls) -> int:
-        # torch_ssm_transform signature has 7 node/state arguments
-        return 7
-
-    @classmethod
-    def get_source_attention_op(cls) -> OpOverloadPacket:
-        # Keep source op unchanged (used for uncached pre-export)
-        return torch.ops.auto_deploy.torch_ssm
-
+class TritonBackendSSM(TorchBackendSSM):
     @classmethod
     def get_cached_attention_op(cls) -> MHACallable:
         return torch.ops.auto_deploy.triton_cached_ssm
-
-    @classmethod
-    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
-        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
-
-    @classmethod
-    def get_cache_initializers(
-        cls, source_attn_node: Node, cache_config: CacheConfig
-    ) -> CacheInitializerDict:
-        # Shapes from fake tensors
-        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
-        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]
-
-        num_heads = hs_fake.shape[-2]
-        head_dim = hs_fake.shape[-1]
-
-        if B_fake.ndim >= 4:
-            ssm_state_size = B_fake.shape[-1]
-        else:
-            ssm_state_size = max(1, B_fake.shape[-1])
-
-        def _get_ssm_cache(si: SequenceInfo):
-            return torch.empty(
-                si.max_batch_size,
-                num_heads,
-                head_dim,
-                ssm_state_size,
-                device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
-            )
-
-        return {"ssm_state_cache": _get_ssm_cache}
-
-    @classmethod
-    def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
-        return {}
-
-    @classmethod
-    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
-        time_step_limit, chunk_size = extract_op_args(
-            source_attn_node, "time_step_limit", "chunk_size"
-        )
-        return [time_step_limit, chunk_size]
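Reviewer note (not part of the patch): a minimal sketch of the SSM cache dtype resolution introduced above, under the assumption that `cache_config.mamba_dtype or dtype` is the full selection logic; `resolve_ssm_cache_dtype` and the trimmed-down `CacheConfig` below are hypothetical stand-ins for illustration, while the real initializer allocates the full `ssm_state_cache` tensor. The patch also makes `TritonBackendSSM` inherit everything except `get_cached_attention_op` from `TorchBackendSSM`, so this resolution applies to both backends.

    # Illustrative only: mirrors `dtype=cache_config.mamba_dtype or dtype` from the diff.
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class CacheConfig:
        dtype: Optional[torch.dtype] = None
        mamba_dtype: Optional[torch.dtype] = None


    def resolve_ssm_cache_dtype(cache_config: CacheConfig, node_dtype: torch.dtype) -> torch.dtype:
        # Prefer the mamba-specific override; otherwise fall back to the dtype recorded
        # on the source attention node's fake tensor (meta["val"].dtype), not to
        # cache_config.dtype or the hidden-states fake tensor as before.
        return cache_config.mamba_dtype or node_dtype


    # The KV-cache dtype no longer leaks into the SSM state cache ...
    assert resolve_ssm_cache_dtype(CacheConfig(dtype=torch.float16), torch.bfloat16) is torch.bfloat16
    # ... unless mamba_dtype is set explicitly.
    assert resolve_ssm_cache_dtype(CacheConfig(mamba_dtype=torch.float32), torch.bfloat16) is torch.float32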