4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -75,8 +75,8 @@ transforms:
detect_sharding:
stage: sharding
simple_shard_only: false
use_sharding_from_factory: false
support_partial_config: false
use_sharding_from_factory: true
support_partial_config: true
sharding_dims: ['tp', 'ep', 'bmm']
requires_shape_prop: true
# TODO: (hg) need to ensure run_shape_prop after sharding.
50 changes: 49 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -9,8 +9,14 @@
from einops import rearrange
from transformers import AutoModelForCausalLM

from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward

# from transformers.models.nemotron_h.configuration_nemotron_h import NemotronHConfig

# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57
# NemotronHConfig.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise"


# Forked from:
# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
@@ -79,7 +85,7 @@ def _nemotron_h_block_forward(
elif self.block_type == "attention":
hidden_states = self.mixer(hidden_states, cache_position=cache_position)
hidden_states = hidden_states[0]
elif self.block_type == "mlp":
elif self.block_type in ["mlp", "moe"]:
hidden_states = self.mixer(hidden_states)
else:
raise ValueError(f"Invalid block_type: {self.block_type}")
@@ -88,6 +94,34 @@ def _nemotron_h_block_forward(
return hidden_states


# TODO: we assume experts have no bias for now
def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
"""
Uses the NemotronH router (returns indices, weights) and dispatches through
auto_deploy::torch_moe with act_fn='relu2' and mlp_style='mlp'.
Assumes experts have no bias for now (see TODO above).
"""

residuals = hidden_states
orig_shape = hidden_states.shape
topk_indices, topk_weights = self.gate(hidden_states)
x_flat = hidden_states.view(-1, hidden_states.shape[-1])

out_flat = torch.ops.auto_deploy.torch_moe(
x_flat,
topk_indices,
topk_weights,
w1_weight=[e.up_proj.weight for e in self.experts],
w2_weight=[e.down_proj.weight for e in self.experts],
w3_weight=[],
act_fn="relu2",
mlp_style="mlp",
)

out = out_flat.view(*orig_shape)
out = out + self.shared_experts(residuals)
return out

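For reference, below is a minimal eager-mode sketch of what the torch_moe dispatch above is expected to compute, assuming the router returns per-token top-k indices and weights of shape [num_tokens, k] and each expert is a bias-free up_proj/down_proj MLP with a squared-ReLU ("relu2") activation. The helper and its name are illustrative only; the shared-experts branch is added on top of this result, as in the forward above.

import torch

def _moe_relu2_reference(x_flat, topk_indices, topk_weights, experts):
    # x_flat: [T, H]; topk_indices / topk_weights: [T, k]
    out = torch.zeros_like(x_flat)
    for t in range(x_flat.shape[0]):
        for j in range(topk_indices.shape[1]):
            expert = experts[int(topk_indices[t, j])]
            # relu2: relu(x @ W_up.T) ** 2, then project back down
            h = torch.relu(x_flat[t] @ expert.up_proj.weight.t()) ** 2
            out[t] += topk_weights[t, j] * (h @ expert.down_proj.weight.t())
    return out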

_from_config_original = AutoModelForCausalLM.from_config

CUSTOM_MODULE_PATCHES: Dict[str, List[Tuple[str, Callable]]] = {
@@ -97,6 +131,7 @@ def _nemotron_h_block_forward(
("_update_mamba_mask", _nemotron_h_model_update_mamba_mask),
],
"NemotronHBlock": [("forward", _nemotron_h_block_forward)],
"NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
}


@@ -112,6 +147,19 @@ def get_model_from_config_patched(config, **kwargs):
return model


def _set_sharding_config_patched(self, *args, **kwargs):
self._sharding_config["head_dim"] = 128
self._sharding_config["tp_plan"] = {
"in_proj": "mamba",
"out_proj": "rowwise",
"up_proj": "colwise",
"down_proj": "rowwise",
"*": "gather",
}


AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched
Review comment (Member): will require clean-up


# TODO: figure out how this can be incorporated into the export patch system
AutoModelForCausalLM.from_config = get_model_from_config_patched

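For context, CUSTOM_MODULE_PATCHES maps a module class name to (method name, replacement) pairs. Below is a generic sketch of how such per-class patches are typically applied to an instantiated model; the helper is an assumed mechanism for illustration, not the factory's actual application code.

import torch.nn as nn

def _apply_custom_module_patches(model: nn.Module) -> None:
    # Walk all submodules and bind each replacement function as a method on
    # every module whose class name has an entry in CUSTOM_MODULE_PATCHES.
    for module in model.modules():
        for method_name, patched_fn in CUSTOM_MODULE_PATCHES.get(type(module).__name__, []):
            setattr(module, method_name, patched_fn.__get__(module, type(module)))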
@@ -294,7 +294,7 @@ def _find_final_hidden_state_node(
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
return None
index_node = mul_node.args[1]
index_add_node = bfs(
index_add_node, _ = bfs(
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
)
if not index_add_node:
@@ -360,7 +360,7 @@ def target(n: torch.fx.Node) -> bool:
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0

try:
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
graph.erase_node(node_to_remove)
return True
except RuntimeError:
@@ -430,7 +430,7 @@ def _apply(
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
if not common_ancessor2:
continue
selected_experts = bfs(
selected_experts, _ = bfs(
common_ancessor2,
lambda node: is_op(node, torch.ops.aten.one_hot),
attr_next="all_input_nodes",
78 changes: 54 additions & 24 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -16,6 +16,7 @@
happens automatically via the checkpoint loading hook added in step 2c.
"""

import ast
import operator
import re
from collections import defaultdict
@@ -38,6 +39,7 @@
from ...utils.sharding_utils import (
BMMShardingInfo,
EPShardingInfo,
LayerType,
ShardingConfig,
ShardingTransformInfo,
SplitDimension,
@@ -165,6 +167,7 @@ def _apply(
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
local_rank, world_size = shared_config.local_rank, shared_config.world_size
# world_size = 2

if world_size < 2:
ad_logger.info("Skipping sharding for single device")
@@ -173,58 +176,51 @@
)

assert isinstance(gm, GraphModule), "Expecting GraphModule"
shared_config.sharding_config.rank = local_rank
shared_config.sharding_config.world_size = world_size
shared_config.sharding_config.predefined_config = (
factory.get_sharding_config() if factory else {}
)
shared_config.sharding_config.factory_source = (
shared_config.sharding_config.predefined_config.get(
"source", ShardingConfigSource.UNKNOWN
)
sharding_config = shared_config.sharding_config
sharding_config.rank = local_rank
sharding_config.world_size = world_size
sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
sharding_config.factory_source = (
sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
if factory
else ShardingConfigSource.UNKNOWN
)
shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
shared_config.sharding_config.support_partial_config = self.config.support_partial_config
shared_config.sharding_config.sharding_dims = self.config.sharding_dims
sharding_config.simple_shard_only = self.config.simple_shard_only
sharding_config.support_partial_config = self.config.support_partial_config
sharding_config.sharding_dims = self.config.sharding_dims

shared_config.sharding_config.use_sharding_from_factory = (
self.config.use_sharding_from_factory
)
sharding_config.use_sharding_from_factory = self.config.use_sharding_from_factory

sharding_config = shared_config.sharding_config
sharding_config.validate_config()
# sharding_config.predefined_config = predefined_config

if (
shared_config.sharding_config.use_sharding_from_factory
and len(shared_config.sharding_config.get_predefined_config()) > 0
sharding_config.use_sharding_from_factory
and len(sharding_config.get_predefined_config()) > 0
):
ad_logger.info("Applying sharding from config")
factory_info = detect_sharding_from_factory_config(gm, sharding_config)
return gm, factory_info

ad_logger.info(
f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
)
ad_logger.info(f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}")
# run TP sharding across ranks
if "tp" in shared_config.sharding_config.sharding_dims:
if "tp" in sharding_config.sharding_dims:
tp_info = detect_column_row_shard(gm, sharding_config)
else:
tp_info = TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)

# run EP sharding across ranks
if "ep" in shared_config.sharding_config.sharding_dims:
if "ep" in sharding_config.sharding_dims:
ep_info = detect_ep_shard(gm, sharding_config)
else:
ep_info = TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)

# run BMM sharding across ranks
if "bmm" in shared_config.sharding_config.sharding_dims:
if "bmm" in sharding_config.sharding_dims:
dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
else:
dp_bmm_info = TransformInfo(
@@ -260,6 +256,7 @@ def detect_sharding_from_factory_config(
# 4. the allowed values are:
# - "colwise"
# - "rowwise"
# - "mamba"
# - "sequence_parallel"
# - "local_colwise"
# - "local_rowwise"
@@ -313,6 +310,24 @@ def detect_sharding_from_factory_config(
num_shards += 1
# we have a match. Get the config for this layer
config = tp_plan[key]
# check if config has parameters.
if "(" in config:
config, params_str = config.split("(", 1)
params_str = params_str.rsplit(")", 1)[0] # Remove trailing )

try:
# Convert "key" = value to "key": value format for dict parsing
params_str = params_str.replace(" = ", ": ")
# Wrap in braces to make it a dict and parse
config_params = ast.literal_eval("{" + params_str + "}")
except Exception as e:
ad_logger.warning(
f"Failed to parse config params: {params_str}, error: {e}. "
"Using empty config."
)
config_params = {}
else:
config_params = {}
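# Hypothetical example: a tp_plan value such as
#   'mamba("fused_weight_dims" = [4096, 4096, 128])'
# is split into config == "mamba" and parsed into
# config_params == {"fused_weight_dims": [4096, 4096, 128]}.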
if config == "colwise":
sharding_config.tp_transforms.append(
TPShardingInfo.from_node(
@@ -324,6 +339,7 @@
min_local_shape=min_local_shape,
)
)
num_row_col_shards += 1
elif config == "rowwise":
sharding_config.tp_transforms.append(
TPShardingInfo.from_node(
@@ -336,6 +352,20 @@
)
)
num_row_col_shards += 1
elif config == "mamba":
sharding_config.tp_transforms.append(
TPShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
rank=rank,
world_size=world_size,
dist_op=None,
min_local_shape=min_local_shape,
layer_type=LayerType.MAMBA,
fused_weight_dims=config_params.get("fused_weight_dims"),
)
)
Review comment (Member):
In my mind, a complete specification for col-row shard of a Mamba-like layer has 4 entries in the sharding_config that are passed to the sharding_executor:

  1. column-wise with appropriate fused_weight_dims for the in_proj
  2. column-wise with appropriate fused_weight_dims for torch_causal_conv1d
  3. column-wise for the weight on the gated rms norm
  4. row-wise for out_proj

More specifically, I don't think we should add special handling for mamba via something like the layer_type argument.

Now there are two ways we can get these four entries:

  1. Sharding heuristic that analyzes the model and correctly configures those four entries. Seems like you already have something like this here. This can be repurposed as detect_mamba_sharding to add the appropriate entries into the sharding_config
  2. Manual sharding config: this is pending your [TRTLLM-6342][feat] Support custom sharding config source #8153 PR, so it might not be a good alternative for now. Even with 2. in place, it wouldn't be ideal, since the manual sharding config would have to specify the fused_weight_dims. So, since we have 1. in place, we can go with that option for now.

In terms of getting this merged to main I would suggest the following split:

  1. Ability for the sharding executor to understand TPShardingInfo with fused_weight_dims. Moreover, we probably also need to ensure that causal conv and element-wise multiplication from the gated rms norm can be supported in the TPShardingInfo specification. + Unit tests
  2. Add a new sharding heuristic for the mamba layer that can auto-detect the correct sharding_config entry for mamba-like layers

In the meantime, the current branch can be used to test+benchmark Nemotron nano-v3. So please make sure it remains available for testing

num_row_col_shards += 1
elif "sequence" in config:
# TODO: Sequence parallelism is not supported yet.
ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
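Below is a rough sketch of the four sharding_config entries described in the review comment above, expressed with the TPShardingInfo fields that appear elsewhere in this diff. The node handles, fused-dimension lists, the ROW split member, and the "all_reduce" dist-op string are placeholders and assumptions, not the final API.

# Illustrative only: node handles, fused dims, and the row-wise dist op are placeholders.
mamba_tp_entries = [
    # 1. in_proj: column-wise, split per fused sub-weight (z, x, B, C, dt, ...)
    TPShardingInfo.from_node(
        in_proj_node, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size,
        dist_op=None, min_local_shape=min_local_shape, fused_weight_dims=in_proj_fused_dims,
    ),
    # 2. causal conv1d weight: column-wise, with its own fused dims
    TPShardingInfo.from_node(
        conv1d_node, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size,
        dist_op=None, min_local_shape=min_local_shape, fused_weight_dims=conv1d_fused_dims,
    ),
    # 3. gated RMSNorm weight: column-wise so the norm stays local to each shard
    TPShardingInfo.from_node(
        norm_weight_node, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size,
        dist_op=None, min_local_shape=min_local_shape,
    ),
    # 4. out_proj: row-wise, followed by a cross-rank reduction of the partial outputs
    TPShardingInfo.from_node(
        out_proj_node, split_dim=SplitDimension.ROW, rank=rank, world_size=world_size,
        dist_op="all_reduce", min_local_shape=min_local_shape,
    ),
]
sharding_config.tp_transforms.extend(mamba_tp_entries)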