[TRTLLM-8201][feat] TP sharding of Mamba layers #8548
base: main
Changes from all commits
1b312b7
533f709
a21642c
609dca8
ad7364b
3c93381
bffef4b
Diff of first changed file:
@@ -9,8 +9,14 @@
from einops import rearrange
from transformers import AutoModelForCausalLM

from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward

# from transformers.models.nemotron_h.configuration_nemotron_h import NemotronHConfig

# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57
# NemotronHConfig.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise"

# Forked from:
# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py

@@ -79,7 +85,7 @@ def _nemotron_h_block_forward(
    elif self.block_type == "attention":
        hidden_states = self.mixer(hidden_states, cache_position=cache_position)
        hidden_states = hidden_states[0]
    elif self.block_type == "mlp":
    elif self.block_type in ["mlp", "moe"]:
        hidden_states = self.mixer(hidden_states)
    else:
        raise ValueError(f"Invalid block_type: {self.block_type}")

@@ -88,6 +94,34 @@ def _nemotron_h_block_forward(
    return hidden_states


# TODO: we assume experts have no bias for now
def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
    """
    Uses NemotronH router (returns indices, weights) and dispatches through auto_deploy::torch_moe_nemo
    with act_fn='relu2'. Falls back to original forward if any expert has bias.
    """
    residuals = hidden_states
    orig_shape = hidden_states.shape
    topk_indices, topk_weights = self.gate(hidden_states)
    x_flat = hidden_states.view(-1, hidden_states.shape[-1])

    out_flat = torch.ops.auto_deploy.torch_moe(
        x_flat,
        topk_indices,
        topk_weights,
        w1_weight=[e.up_proj.weight for e in self.experts],
        w2_weight=[e.down_proj.weight for e in self.experts],
        w3_weight=[],
        act_fn="relu2",
        mlp_style="mlp",
    )

    out = out_flat.view(*orig_shape)
    out = out + self.shared_experts(residuals)
    return out

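For reference, below is a minimal per-token sketch of what the fused torch.ops.auto_deploy.torch_moe call above is assumed to compute for mlp_style="mlp" with act_fn="relu2" (taken here to mean squared ReLU). The helper name, loop form, and shapes are illustrative assumptions, not the op's actual implementation and not part of this diff.

import torch
import torch.nn.functional as F


def moe_reference(x_flat, topk_indices, topk_weights, up_weights, down_weights):
    """Naive dispatch: out[t] = sum_k topk_weights[t, k] * down_e(relu(up_e(x[t])) ** 2)."""
    out = torch.zeros_like(x_flat)
    for t in range(x_flat.shape[0]):
        for e, w in zip(topk_indices[t].tolist(), topk_weights[t].tolist()):
            h = F.linear(x_flat[t], up_weights[int(e)])  # expert up_proj (w1_weight[e])
            h = F.relu(h) ** 2  # assumed meaning of act_fn="relu2" (squared ReLU)
            out[t] += w * F.linear(h, down_weights[int(e)])  # expert down_proj (w2_weight[e])
    return out
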
_from_config_original = AutoModelForCausalLM.from_config

CUSTOM_MODULE_PATCHES: Dict[str, List[Tuple[str, Callable]]] = {

@@ -97,6 +131,7 @@ def _nemotron_h_block_forward(
        ("_update_mamba_mask", _nemotron_h_model_update_mamba_mask),
    ],
    "NemotronHBlock": [("forward", _nemotron_h_block_forward)],
    "NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
}

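A purely hypothetical sketch of how such (attribute, replacement) pairs could be bound onto the matching Hugging Face classes; the actual patching mechanism used by auto_deploy is not shown in this diff and may differ.

# Hypothetical helper (not part of the diff): bind each patched attribute onto
# the class whose name matches a key in CUSTOM_MODULE_PATCHES.
def apply_custom_module_patches(module_classes):
    for cls in module_classes:
        for attr_name, replacement in CUSTOM_MODULE_PATCHES.get(cls.__name__, []):
            setattr(cls, attr_name, replacement)
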
@@ -112,6 +147,19 @@ def get_model_from_config_patched(config, **kwargs):
    return model

def _set_sharding_config_patched(self, *args, **kwargs):
    self._sharding_config["head_dim"] = 128
    self._sharding_config["tp_plan"] = {
        "in_proj": "mamba",
        "out_proj": "rowwise",
        "up_proj": "colwise",
        "down_proj": "rowwise",
        "*": "gather",
    }


AutoModelForCausalLMFactory._set_sharding_config = _set_sharding_config_patched
Review comment: will require clean-up

# TODO: figure out how this can be incorporated into the export patch system
AutoModelForCausalLM.from_config = get_model_from_config_patched

Diff of second changed file:
@@ -16,6 +16,7 @@
happens automatically via the checkpoint loading hook added in step 2c.
"""

import ast
import operator
import re
from collections import defaultdict

@@ -38,6 +39,7 @@
from ...utils.sharding_utils import (
    BMMShardingInfo,
    EPShardingInfo,
    LayerType,
    ShardingConfig,
    ShardingTransformInfo,
    SplitDimension,

@@ -165,6 +167,7 @@ def _apply(
    shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
    local_rank, world_size = shared_config.local_rank, shared_config.world_size
    # world_size = 2

    if world_size < 2:
        ad_logger.info("Skipping sharding for single device")

@@ -173,58 +176,51 @@ def _apply(
    )

    assert isinstance(gm, GraphModule), "Expecting GraphModule"
    shared_config.sharding_config.rank = local_rank
    shared_config.sharding_config.world_size = world_size
    shared_config.sharding_config.predefined_config = (
        factory.get_sharding_config() if factory else {}
    )
    shared_config.sharding_config.factory_source = (
        shared_config.sharding_config.predefined_config.get(
            "source", ShardingConfigSource.UNKNOWN
        )
    sharding_config = shared_config.sharding_config
    sharding_config.rank = local_rank
    sharding_config.world_size = world_size
    sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
    sharding_config.factory_source = (
        sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
        if factory
        else ShardingConfigSource.UNKNOWN
    )
    shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
    shared_config.sharding_config.support_partial_config = self.config.support_partial_config
    shared_config.sharding_config.sharding_dims = self.config.sharding_dims
    sharding_config.simple_shard_only = self.config.simple_shard_only
    sharding_config.support_partial_config = self.config.support_partial_config
    sharding_config.sharding_dims = self.config.sharding_dims

    shared_config.sharding_config.use_sharding_from_factory = (
        self.config.use_sharding_from_factory
    )
    sharding_config.use_sharding_from_factory = self.config.use_sharding_from_factory

    sharding_config = shared_config.sharding_config
    sharding_config.validate_config()
    # sharding_config.predefined_config = predefined_config

    if (
        shared_config.sharding_config.use_sharding_from_factory
        and len(shared_config.sharding_config.get_predefined_config()) > 0
        sharding_config.use_sharding_from_factory
        and len(sharding_config.get_predefined_config()) > 0
    ):
        ad_logger.info("Applying sharding from config")
        factory_info = detect_sharding_from_factory_config(gm, sharding_config)
        return gm, factory_info

    ad_logger.info(
        f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
    )
    ad_logger.info(f"Running autodeploy sharding heuristics: {sharding_config.sharding_dims}")
    # run TP sharding across ranks
    if "tp" in shared_config.sharding_config.sharding_dims:
    if "tp" in sharding_config.sharding_dims:
        tp_info = detect_column_row_shard(gm, sharding_config)
    else:
        tp_info = TransformInfo(
            skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
        )

    # run EP sharding across ranks
    if "ep" in shared_config.sharding_config.sharding_dims:
    if "ep" in sharding_config.sharding_dims:
        ep_info = detect_ep_shard(gm, sharding_config)
    else:
        ep_info = TransformInfo(
            skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
        )

    # run BMM sharding across ranks
    if "bmm" in shared_config.sharding_config.sharding_dims:
    if "bmm" in sharding_config.sharding_dims:
        dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
    else:
        dp_bmm_info = TransformInfo(

@@ -260,6 +256,7 @@ def detect_sharding_from_factory_config(
    # 4. the allowed values are:
    # - "colwise"
    # - "rowwise"
    # - "mamba"
    # - "sequence_parallel"
    # - "local_colwise"
    # - "local_rowwise"

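For illustration, a factory-provided tp_plan using some of these values might look like the sketch below; the module-name keys are hypothetical and only meant to show how the new "mamba" entry sits next to the existing column/row values.

# Hypothetical tp_plan (not part of the diff); keys are illustrative module-name patterns.
example_tp_plan = {
    "layers.*.mixer.in_proj": "mamba",      # fused Mamba input projection, column-split
    "layers.*.mixer.out_proj": "rowwise",   # Mamba output projection, row-split
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
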
@@ -313,6 +310,24 @@ def detect_sharding_from_factory_config(
    num_shards += 1
    # we have a match. Get the config for this layer
    config = tp_plan[key]
    # check if config has parameters.
    if "(" in config:
        config, params_str = config.split("(", 1)
        params_str = params_str.rsplit(")", 1)[0]  # Remove trailing )

        try:
            # Convert "key" = value to "key": value format for dict parsing
            params_str = params_str.replace(" = ", ": ")
            # Wrap in braces to make it a dict and parse
            config_params = ast.literal_eval("{" + params_str + "}")
        except Exception as e:
            ad_logger.warning(
                f"Failed to parse config params: {params_str}, error: {e}. "
                "Using empty config."
            )
            config_params = {}
    else:
        config_params = {}

    if config == "colwise":
        sharding_config.tp_transforms.append(
            TPShardingInfo.from_node(

@@ -324,6 +339,7 @@ def detect_sharding_from_factory_config(
                min_local_shape=min_local_shape,
            )
        )
        num_row_col_shards += 1
    elif config == "rowwise":
        sharding_config.tp_transforms.append(
            TPShardingInfo.from_node(

@@ -336,6 +352,20 @@ def detect_sharding_from_factory_config(
            )
        )
        num_row_col_shards += 1
    elif config == "mamba":
        sharding_config.tp_transforms.append(
            TPShardingInfo.from_node(
                lin_node,
                split_dim=SplitDimension.COLUMN,
                rank=rank,
                world_size=world_size,
                dist_op=None,
                min_local_shape=min_local_shape,
                layer_type=LayerType.MAMBA,
                fused_weight_dims=config_params.get("fused_weight_dims"),
            )
        )

Review comment: In my mind, a complete specification for a col-row shard of a Mamba-like layer has 4 entries in the tp_plan. More specifically, I don't think we should add special handling for mamba via something like the [...]. Now there are two ways we can get these four entries: [...]
In terms of getting this merged to main: [...]
In the meantime, the current branch can be used to test+benchmark Nemotron nano-v3. So please make sure it remains available for testing.
        num_row_col_shards += 1
    elif "sequence" in config:
        # TODO: Sequence parallelism is not supported yet.
        ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")

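As a side note on the parameterized syntax handled by the new parsing code above, here is a standalone sketch of how an entry such as mamba("fused_weight_dims" = [...]) would be split into a config name and parameters. The dims are made up for the example, and the parameter key must be quoted so that ast.literal_eval accepts the braced string as a dict literal.

import ast

# Hypothetical tp_plan value; the dims are illustrative only.
entry = 'mamba("fused_weight_dims" = [4096, 4096, 128])'

config, params_str = entry.split("(", 1)
params_str = params_str.rsplit(")", 1)[0]  # drop the trailing ")"
config_params = ast.literal_eval("{" + params_str.replace(" = ", ": ") + "}")

assert config == "mamba"
assert config_params == {"fused_weight_dims": [4096, 4096, 128]}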