@@ -29,6 +29,7 @@ class CacheConfig:
"""A dataclass to hold information how to configure the cache."""

dtype: Optional[torch.dtype] = None
mamba_dtype: Optional[torch.dtype] = None


class SequenceInfo:
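
Note on the CacheConfig change above: the new mamba_dtype field gives SSM/Mamba state caches their own dtype override, separate from the KV-cache dtype. A minimal sketch of the intended fallback behavior; the _CacheConfigSketch and _resolve_ssm_cache_dtype names are illustrative only, not part of the codebase:

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class _CacheConfigSketch:
        dtype: Optional[torch.dtype] = None        # dtype for regular KV caches
        mamba_dtype: Optional[torch.dtype] = None  # optional override for SSM state caches


    def _resolve_ssm_cache_dtype(cfg: _CacheConfigSketch, node_dtype: torch.dtype) -> torch.dtype:
        # Mirrors the `cache_config.mamba_dtype or dtype` fallback used in the diff below.
        return cfg.mamba_dtype or node_dtype


    assert _resolve_ssm_cache_dtype(_CacheConfigSketch(), torch.bfloat16) is torch.bfloat16
    assert _resolve_ssm_cache_dtype(_CacheConfigSketch(mamba_dtype=torch.float32), torch.bfloat16) is torch.float32
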
@@ -325,7 +325,7 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx)
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

@classmethod
@@ -339,6 +339,9 @@ def get_cache_initializers(
num_heads = hs_fake.shape[-2]
head_dim = hs_fake.shape[-1]

# dtype from node itself
dtype = source_attn_node.meta["val"].dtype

# Infer state size by assuming B has shape [b, s, n_groups * ssm_state_size]
# During runtime we pass [b, s, n_groups, ssm_state_size]; both give the same last dim product.
if B_fake.ndim >= 4:
@@ -354,7 +357,7 @@ def _get_ssm_cache(si: SequenceInfo):
head_dim,
ssm_state_size,
device=si.device,
dtype=cache_config.dtype or hs_fake.dtype,
dtype=cache_config.mamba_dtype or dtype,
)

return {"ssm_state_cache": _get_ssm_cache}
@@ -1,27 +1,14 @@
from typing import List, Tuple
from typing import List

import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

# Triton kernels
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

from ...utils.node_utils import extract_op_args
from ..attention_interface import (
AttentionDescriptor,
AttentionLayout,
AttentionRegistry,
BufferInitializerDict,
CacheConfig,
CacheInitializerDict,
Constant,
MHACallable,
PrepareMetadataCallable,
SequenceInfo,
)
from ..attention_interface import AttentionRegistry, MHACallable
from .torch_backend_mamba import TorchBackendSSM


@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
@@ -202,70 +189,7 @@ def _triton_cached_ssm_fake(


@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
@classmethod
def is_paged(cls) -> bool:
return True

@classmethod
def get_attention_layout(cls) -> AttentionLayout:
# Hidden states follow [b, s, n, d]
return "bsnd"

@classmethod
def get_num_qkv_args(cls) -> int:
# torch_ssm_transform signature has 7 node/state arguments
return 7

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
# Keep source op unchanged (used for uncached pre-export)
return torch.ops.auto_deploy.torch_ssm

class TritonBackendSSM(TorchBackendSSM):
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.triton_cached_ssm

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4

@classmethod
def get_cache_initializers(
cls, source_attn_node: Node, cache_config: CacheConfig
) -> CacheInitializerDict:
# Shapes from fake tensors
hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]

num_heads = hs_fake.shape[-2]
head_dim = hs_fake.shape[-1]

if B_fake.ndim >= 4:
ssm_state_size = B_fake.shape[-1]
else:
ssm_state_size = max(1, B_fake.shape[-1])

def _get_ssm_cache(si: SequenceInfo):
return torch.empty(
si.max_batch_size,
num_heads,
head_dim,
ssm_state_size,
device=si.device,
dtype=cache_config.dtype or hs_fake.dtype,
)

return {"ssm_state_cache": _get_ssm_cache}

@classmethod
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
return {}

@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
time_step_limit, chunk_size = extract_op_args(
source_attn_node, "time_step_limit", "chunk_size"
)
return [time_step_limit, chunk_size]
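
The net effect of this file's change: TritonBackendSSM now derives from TorchBackendSSM and only overrides the cached attention op, inheriting the layout, prepare-metadata op, cache initializers, and constants from the torch backend. A minimal sketch of the pattern with hypothetical class names and string placeholders for the ops:

    class TorchBackendSSMSketch:
        @classmethod
        def get_cached_attention_op(cls):
            return "auto_deploy::torch_cached_ssm"  # placeholder for the torch-backed op

        @classmethod
        def get_prepare_metadata_op(cls):
            # (op, number of returned tensors): seq_len, seq_start, slot_idx, use_initial_states
            return "auto_deploy::torch_ssm_prepare_metadata", 4


    class TritonBackendSSMSketch(TorchBackendSSMSketch):
        # Only the cached attention op differs; metadata, cache initializers,
        # and constants are inherited from the torch backend.
        @classmethod
        def get_cached_attention_op(cls):
            return "auto_deploy::triton_cached_ssm"


    assert TritonBackendSSMSketch.get_prepare_metadata_op() == (
        "auto_deploy::torch_ssm_prepare_metadata",
        4,
    )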