diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index b1d5aee28ac..21018e241da 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -959,6 +959,13 @@ class AllreduceOp // MIN_LATENCY. if (mStrategy != AllReduceStrategyType::AUTO) { + // Check TWOSHOT constraint: seq_len >= tp_size + if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size()) + { + TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT", + seq_len, mGroup.size()); + return AllReduceStrategyType::ONESHOT; + } return mStrategy; } else diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 55416141e73..cf2ac0abc3e 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -79,6 +79,7 @@ transforms: sharding_source: ['factory','heuristic'] support_partial_config: true sharding_dims: ['tp', 'ep', 'bmm'] + allreduce_strategy: 'AUTO' requires_shape_prop: true sharding_transform_executor: stage: sharding diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py index d6f13fbedd7..2ff656891d9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py @@ -26,19 +26,23 @@ def all_gather_fake(tensor, dim=0): @torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda") -def all_reduce(t: torch.Tensor) -> torch.Tensor: - """All_reduce across the ranks. Reduction op is SUM. +def all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor: + """All_reduce across the ranks. Reduction op is SUM. Strategy is MANDATORY. + + Args: + t: Tensor to reduce across ranks + strategy: AllReduce strategy - "AUTO", "NCCL", "ONESHOT", "TWOSHOT", "MIN_LATENCY", etc. NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For efficient all_reduce ops one should write/replace it with a fused op. """ if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM) + return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM, strategy=strategy) t_res = t.clone() dist.all_reduce(t_res, op=dist.ReduceOp.SUM) return t_res @all_reduce.register_fake -def all_reduce_fake(tensor): +def all_reduce_fake(tensor, strategy): return torch.empty_like(tensor) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py index fda48e4ba57..8ecd4215017 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py @@ -34,16 +34,16 @@ def simple_fake(input, weight, bias): "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda" ) def fused_linear_all_reduce( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], strategy: str ) -> torch.Tensor: - """Fused linear followed by all_reduce on the output.""" + """Fused linear followed by all_reduce on the output. 
Strategy is MANDATORY.""" output = torch.ops.aten.linear(input, weight, bias) if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM) + return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM, strategy=strategy) dist.all_reduce(output, op=dist.ReduceOp.SUM) return output @fused_linear_all_reduce.register_fake -def fused_linear_all_reduce_fake(input, weight, bias): +def fused_linear_all_reduce_fake(input, weight, bias, strategy): return torch.ops.aten.linear(input, weight, bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index 90ea04db862..d892cf6417b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -245,6 +245,7 @@ def fp8_linear_fake( def fused_fp8_linear_all_reduce( input: torch.Tensor, weight_fp8: torch.Tensor, + strategy: str, bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, @@ -253,7 +254,7 @@ def fused_fp8_linear_all_reduce( input, weight_fp8, bias, input_scale, weight_scale ) if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM) + return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM, strategy=strategy) dist.all_reduce(out, op=dist.ReduceOp.SUM) return out @@ -262,6 +263,7 @@ def fused_fp8_linear_all_reduce( def fused_fp8_linear_all_reduce_fake( input: torch.Tensor, weight_fp8: torch.Tensor, + strategy: str, bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index 434cc1693eb..386083f7cc6 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -18,17 +18,26 @@ def trtllm_allgather(tensor, dim, sizes=None): p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) return allgather(tensor, p_config, dim=dim, sizes=sizes) - def trtllm_allreduce(tensor, op, all_reduce_params=None): + def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None): rank, world_size = get_rank_world_size() assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." - # Cache key includes rank, world_size, and dtype to handle different configurations - cache_key = (rank, world_size, tensor.dtype) + # Convert string strategy to enum + try: + strategy_enum = getattr(AllReduceStrategy, strategy) + except AttributeError: + raise ValueError( + f"Invalid allreduce strategy: {strategy}. 
" + f"Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, " + f"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC" + ) + + # Cache key includes rank, world_size, dtype, and strategy to handle different configurations + cache_key = (rank, world_size, tensor.dtype, strategy_enum) if cache_key not in _allreduce_cache: p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) - # Use Strategy.AUTO for optimal performance _allreduce_cache[cache_key] = AllReduce( - mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype + mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype ) torch_op = _allreduce_cache[cache_key] @@ -38,7 +47,11 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): "dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" ) def fused_allreduce_residual_rmsnorm( - tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float + tensor: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + strategy: str = "AUTO", ) -> tuple[torch.Tensor, torch.Tensor]: """Fusing allreduce, residual (add), and hf_rms_norm together. @@ -54,7 +67,9 @@ def fused_allreduce_residual_rmsnorm( norm_weight=norm_weight, eps=eps, ) - return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) + return trtllm_allreduce( + tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params + ) else: # Fallback: unfused implementation using torch distributed # This is used in demollm mode without MPI @@ -79,7 +94,11 @@ def fused_allreduce_residual_rmsnorm( @fused_allreduce_residual_rmsnorm.register_fake def fused_allreduce_residual_rmsnorm_fake( - tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float + tensor: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + strategy: str = "AUTO", ) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(tensor), torch.empty_like(tensor) @@ -89,7 +108,7 @@ def fused_allreduce_residual_rmsnorm_fake( def trtllm_allgather(tensor, dim, sizes=None): raise ImportError("TRT-LLM is not available.") - def trtllm_allreduce(tensor, op): + def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None): raise ImportError("TRT-LLM is not available.") TRTLLM_OP_AVAILABLE = False diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py index d0ebcd0eec8..97d4e4bbbeb 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -28,7 +28,7 @@ def _allreduce_residual_rmsnorm_pattern( """ input_dtype = x.dtype - hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) + hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO") add = residual + hidden_states hidden_states = add.to(torch.float32) @@ -52,7 +52,7 @@ def _allreduce_residual_rmsnorm_pattern2( """ input_dtype = x.dtype - hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) + hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO") add = hidden_states + residual hidden_states = add.to(torch.float32) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 1bb99974ac1..ff1cf68a133 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ 
b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -22,9 +22,10 @@ from typing import DefaultDict, Dict, List, Set, Tuple, Type import torch -from pydantic import Field +from pydantic import Field, field_validator from torch.fx import GraphModule, Node +from .....functional import AllReduceStrategy from ...models.factory import ModelFactory, ShardingConfigSource from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger @@ -49,6 +50,7 @@ SplitDimension, WeightShardingInfo, get_all_weights_in_subgraph, + validate_allreduce_strategy, ) from ..interface import ( BaseTransform, @@ -152,6 +154,18 @@ class ShardingTransformConfig(TransformConfig): sharding_dims: List[ShardingDim] = Field( default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] ) + allreduce_strategy: AllReduceStrategy = Field( + default=AllReduceStrategy.AUTO, + description="AllReduce strategy for distributed operations. " + "Options: AUTO (automatic selection), NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, " + "LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC", + ) + + @field_validator("allreduce_strategy", mode="before") + @classmethod + def _validate_allreduce_strategy(cls, v): + """Convert string names like 'AUTO' to AllReduceStrategy enum.""" + return validate_allreduce_strategy(v) @TransformRegistry.register("detect_sharding") @@ -199,6 +213,8 @@ def _apply( sharding_config = shared_config.sharding_config sharding_config.rank = local_rank sharding_config.world_size = world_size + sharding_config.allreduce_strategy = self.config.allreduce_strategy + ad_logger.info(f"Using allreduce strategy: {sharding_config.allreduce_strategy.name}") sharding_config.predefined_config = factory.get_sharding_config() if factory else {} sharding_config.factory_source = ( sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN) @@ -573,7 +589,7 @@ def detect_sharding_from_factory_config( # we have a match. 
Get the config for this layer config = tp_plan[key] if config == "colwise": - sharding_config.weight_sharding_transforms.append( + if sharding_config.add( WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, @@ -582,10 +598,10 @@ def detect_sharding_from_factory_config( dist_op=None, min_local_shape=min_local_shape, ) - ) - num_row_col_shards += 1 + ): + num_row_col_shards += 1 elif config == "rowwise": - sharding_config.weight_sharding_transforms.append( + if sharding_config.add( WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.ROW, @@ -594,10 +610,10 @@ def detect_sharding_from_factory_config( dist_op="all_reduce", min_local_shape=min_local_shape, ) - ) - num_row_col_shards += 1 + ): + num_row_col_shards += 1 elif config == "mamba": - sharding_config.weight_sharding_transforms.append( + sharding_config.add( WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, @@ -618,7 +634,7 @@ def detect_sharding_from_factory_config( if "shared" in module_name: col_row_action = config.replace("local_", "") if col_row_action == "colwise": - sharding_config.weight_sharding_transforms.append( + sharding_config.add( WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.COLUMN, @@ -629,7 +645,7 @@ def detect_sharding_from_factory_config( ) ) elif col_row_action == "rowwise": - sharding_config.weight_sharding_transforms.append( + if sharding_config.add( WeightShardingInfo( target_node=lin_node.name, split_dim=SplitDimension.ROW, @@ -638,8 +654,8 @@ def detect_sharding_from_factory_config( dist_op="all_reduce", min_local_shape=min_local_shape, ) - ) - num_row_col_shards += 1 + ): + num_row_col_shards += 1 else: ad_logger.warning(f"Unsupported sharding action {config}. Skipping.") else: @@ -648,7 +664,7 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) - sharding_config.weight_sharding_transforms.append( + if sharding_config.add( WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, @@ -657,13 +673,13 @@ def detect_sharding_from_factory_config( dist_op="all_gather", min_local_shape=1, ) - ) - num_simple_shards += 1 + ): + num_simple_shards += 1 else: ad_logger.warning( f"Unsupported sharding action {config}. 
Fallback to simple shard" ) - sharding_config.weight_sharding_transforms.append( + sharding_config.add( WeightShardingInfo.from_node( lin_node, split_dim=SplitDimension.COLUMN, @@ -943,7 +959,7 @@ def detect_column_row_shard( ) # shard single row node - sharding_config.weight_sharding_transforms.append( + if sharding_config.add( WeightShardingInfo.from_node( nodes_to_row_shard[0], split_dim=SplitDimension.ROW, @@ -952,9 +968,8 @@ def detect_column_row_shard( dist_op="all_reduce", min_local_shape=min_local_shape, ) - ) - - num_row_col_shards += 1 + ): + num_row_col_shards += 1 ad_logger.info( f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})" @@ -1020,7 +1035,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra start_idx = remainder + rank * base_size end_idx = start_idx + base_size - sharding_config.bmm_transforms.append( + sharding_config.add( BMMShardingInfo( target_node=node.name, rank=rank, @@ -1064,14 +1079,14 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo ), ): continue - sharding_config.ep_transforms.append( + if sharding_config.add( EPShardingInfo.from_node( node, rank=rank, world_size=world_size, ) - ) - num_moe_patterns += 1 + ): + num_moe_patterns += 1 ad_logger.info(f"Found {num_moe_patterns} MoE patterns") diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 736318d355a..c4f9e6a0635 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -10,9 +10,10 @@ import torch import torch.nn as nn -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from torch.fx import GraphModule, Node +from ....functional import AllReduceStrategy from ..models.factory import ShardingConfigSource from ..utils.logger import ad_logger from .node_utils import ( @@ -29,6 +30,36 @@ ) +def validate_allreduce_strategy(v): + """Convert string names like 'AUTO' to AllReduceStrategy enum. + + This is a shared validator for allreduce_strategy fields across all config classes. + + Args: + v: Value to validate - can be AllReduceStrategy enum, string name, or integer value + + Returns: + AllReduceStrategy enum value + + Raises: + ValueError: If the input is an invalid strategy string + """ + if isinstance(v, AllReduceStrategy): + return v + if isinstance(v, str): + # Try to get enum by name + try: + return AllReduceStrategy[v] + except KeyError: + raise ValueError( + f"Invalid allreduce strategy: {v}. " + f"Valid options: {', '.join(s.name for s in AllReduceStrategy)}" + ) + if isinstance(v, int): + return AllReduceStrategy(v) + return v # Let Pydantic handle other types + + def _load_hook( state_dict, prefix, @@ -230,6 +261,7 @@ def _insert_sharded_mamba( dim: int, rank: int, world_size: int, + allreduce_strategy: AllReduceStrategy, add_dist: bool = False, min_local_shape: int = 1, weights_to_shard: Optional[list[str]] = None, @@ -241,6 +273,8 @@ def _insert_sharded_mamba( ) -> bool: """ To shard Mamba layer, first column-shard the first linear layer: entry_node, + + NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. then shard all remaining weight tensors found in the subgraph defined between entry_node and the next successor linear node. 
First, validate if this is indeed a mamba module: within the subgraph, @@ -259,6 +293,10 @@ def _insert_sharded_mamba( fused_weight_dims: Optional dict mapping weight keys to their fused dimension lists quantization_cb: Optional quantization callback """ + if allreduce_strategy is None: + raise ValueError( + f"allreduce_strategy must be set for Mamba sharding on node {entry_node.name}" + ) # Find next linear node to define subgraph boundary try: next_lin_node, depth = bfs(entry_node, is_any_lin_op, include_root=False) @@ -342,6 +380,7 @@ def _insert_sharded_mamba( min_local_shape=min_local_shape, fused_weight_dims=entry_fused_dims, quantization_cb=quantization_cb, + allreduce_strategy=allreduce_strategy, ) # Get all weight nodes in the subgraph except for out_proj @@ -401,6 +440,7 @@ def _shard_parameter_node( dim: int, rank: int, world_size: int, + allreduce_strategy: AllReduceStrategy, add_dist: bool = False, min_local_shape: int = 1, fused_weight_dims: Optional[list] = None, @@ -410,8 +450,14 @@ def _shard_parameter_node( ) -> None: """Replace the node with parametrized weight tensor with a new node that accepts sharded weights. + NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. + The state_dict is also updated to contain the sharded weights. """ + if allreduce_strategy is None: + raise ValueError( + f"allreduce_strategy must be set for parameter sharding on node {node.name}" + ) assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." @@ -486,15 +532,18 @@ def _shard_parameter_node( return # figure out the right dist op - dist_lookup = { - 0: (torch.ops.auto_deploy.torch_dist_all_gather.default, -1), - 1: (torch.ops.auto_deploy.torch_dist_all_reduce.default,), - } - fn_dist, *dist_args = dist_lookup[dim] + if dim == 0: + # Column split -> all_gather + fn_dist = torch.ops.auto_deploy.torch_dist_all_gather.default + dist_args = (node, -1) + else: + # Row split -> all_reduce with strategy + fn_dist = torch.ops.auto_deploy.torch_dist_all_reduce.default + dist_args = (node, allreduce_strategy.name) # add reduction node with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function(fn_dist, args=(node, *dist_args)) + dist_node = gm.graph.call_function(fn_dist, args=dist_args) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) @@ -565,7 +614,11 @@ class LayerType(Enum): class WeightShardingInfo(ShardingTransformInfo): - """Configuration for TP sharding transformations.""" + """Configuration for TP sharding transformations. + + NOTE: allreduce_strategy will be automatically injected by ShardingConfig.add() + if not provided at creation time. The strategy comes from the parent ShardingConfig. 
+ """ split_dim: SplitDimension dist_op: Optional[Literal["all_reduce", "all_gather"]] = None @@ -573,6 +626,7 @@ class WeightShardingInfo(ShardingTransformInfo): layer_type: LayerType = LayerType.MLP # used for TP sharding of fused weights fused_weight_dims: Optional[list] = None + allreduce_strategy: Optional[AllReduceStrategy] = None # Set by ShardingConfig.add() if None def quantization_cb( self, @@ -628,6 +682,7 @@ def apply(self, gm: GraphModule, node: Node) -> None: if isinstance(self.fused_weight_dims, dict) else None, quantization_cb=self.quantization_cb, + allreduce_strategy=self.allreduce_strategy, ) else: _shard_parameter_node( @@ -640,15 +695,13 @@ def apply(self, gm: GraphModule, node: Node) -> None: min_local_shape=self.min_local_shape, fused_weight_dims=self.fused_weight_dims, quantization_cb=self.quantization_cb, + allreduce_strategy=self.allreduce_strategy, ) class ParameterUpdateInfo(ShardingTransformInfo): """Configuration for node args sharding transformations.""" - target_node: str - rank: int - world_size: int args: tuple def validate(self, gm: GraphModule = None, node: Node = None) -> bool: @@ -938,12 +991,17 @@ def _insert_sharded_moe( node: Node, rank: int, world_size: int, + allreduce_strategy: AllReduceStrategy, scale_names: Sequence[str] = (), ): """Update the torch_moe node with sharded weight lists, sharded `selected_experts` and `final_scales(router_logics)`. Add an all_reduce node after the moe node. + + NOTE: allreduce_strategy is MANDATORY. """ + if allreduce_strategy is None: + raise ValueError(f"allreduce_strategy must be set for MoE sharding on node {node.name}") scale_names = list(scale_names) num_experts = len(node.args[3]) @@ -1014,7 +1072,8 @@ def get_partition(lst, world_size, rank): # -- add an all_reduce node -- with gm.graph.inserting_after(node): dist_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce.default, args=(node,) + torch.ops.auto_deploy.torch_dist_all_reduce.default, + args=(node, allreduce_strategy.name), ) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) @@ -1043,13 +1102,15 @@ def _insert_sharded_mxfp4_mlp_ep( node: Node, rank: int, world_size: int, + allreduce_strategy: AllReduceStrategy, ): - """ - Transform a call to auto_deploy::triton_mxfp4_moe into: - - sharded expert parameters along dim 0 (this rank's slice), + """Transform a call to auto_deploy::triton_mxfp4_moe into: + - sharded expert parameters along dim 0 (this rank slice), - call to auto_deploy::triton_mxfp4_moe_ep(..., local_lo, local_hi), - followed by torch_dist_all_reduce. + NOTE: allreduce_strategy is MANDATORY and must be explicitly provided. 
+ Expects the original op signature: (hidden_states, router_weight, router_bias, top_k, @@ -1057,6 +1118,10 @@ def _insert_sharded_mxfp4_mlp_ep( alpha, limit, down_blocks, down_bias, down_scales) """ + if allreduce_strategy is None: + raise ValueError( + f"allreduce_strategy must be set for MXFP4 MLP EP sharding on node {node.name}" + ) IDX_GATE_UP_BLOCKS = 4 IDX_GATE_UP_BIAS = 5 @@ -1085,17 +1150,22 @@ def _insert_sharded_mxfp4_mlp_ep( # Add a dist all-reduce after the op (sum partial results across EP ranks) with gm.graph.inserting_after(node): - red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)) + red = gm.graph.call_function( + torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, allreduce_strategy.name) + ) node.replace_all_uses_with(red) # keep dataflow: red(input=node) red.replace_input_with(red, node) class EPShardingInfo(ShardingTransformInfo): - """Configuration for EP sharding transformations.""" + """Configuration for EP sharding transformations. - rank: int - world_size: int + NOTE: allreduce_strategy will be automatically injected by ShardingConfig.add() + if not provided at creation time. The strategy comes from the parent ShardingConfig. + """ + + allreduce_strategy: Optional[AllReduceStrategy] = None # Set by ShardingConfig.add() if None @classmethod def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo": @@ -1114,7 +1184,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply EP sharding transformation to the graph module.""" - _insert_sharded_moe(gm, node, self.rank, self.world_size, []) + _insert_sharded_moe(gm, node, self.rank, self.world_size, self.allreduce_strategy, []) class MXFP4EPShardingInfo(EPShardingInfo): @@ -1128,7 +1198,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: return True def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size) + _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size, self.allreduce_strategy) class FP8EPShardingInfo(EPShardingInfo, QuantizationShardingMixin): @@ -1144,7 +1214,9 @@ def scale_names(self) -> List[str]: return ["input_scale", "weight_scale"] def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names()) + _insert_sharded_moe( + gm, node, self.rank, self.world_size, self.allreduce_strategy, self.scale_names() + ) class NVFP4EPShardingInfo(EPShardingInfo, QuantizationShardingMixin): @@ -1160,7 +1232,9 @@ def scale_names(self) -> List[str]: return ["input_scale", "weight_scale", "alpha"] def apply(self, gm: GraphModule, node: Node) -> None: - _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names()) + _insert_sharded_moe( + gm, node, self.rank, self.world_size, self.allreduce_strategy, self.scale_names() + ) EP_SHARDING_RULES = [ @@ -1213,11 +1287,22 @@ class ShardingConfig(BaseModel): sharding_dims: List[ShardingDim] = Field( default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] ) + allreduce_strategy: AllReduceStrategy = Field( + default=AllReduceStrategy.AUTO, + description="AllReduce strategy for distributed operations. 
" + "Options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC, SYMM_MEM", + ) weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list) parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) ep_transforms: List[EPShardingInfo] = Field(default_factory=list) + @field_validator("allreduce_strategy", mode="before") + @classmethod + def _validate_allreduce_strategy(cls, v): + """Convert string names like 'AUTO' to AllReduceStrategy enum.""" + return validate_allreduce_strategy(v) + def __init__(self, **kwargs): super().__init__(**kwargs) self._transform_list_dict = { @@ -1240,7 +1325,18 @@ def _validate_and_normalize(self): def add(self, transform: ShardingTransformInfo) -> bool: """Append a transform only if that node was not sharded before. Do not overwrite existing transforms. + + Automatically propagates allreduce_strategy from this config to the transform + if the transform doesn't already have one set. """ + # Inject allreduce_strategy from config into transform if it has the attribute and it's None + # This creates a new transform instance with the strategy set + if hasattr(transform, "allreduce_strategy") and transform.allreduce_strategy is None: + # Create a new transform with the strategy injected + transform_dict = transform.model_dump() + transform_dict["allreduce_strategy"] = self.allreduce_strategy + transform = type(transform)(**transform_dict) + # Find the appropriate list by checking inheritance transform_list = None for base_class, transform_list_candidate in self._transform_list_dict.items(): diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py index 5010be4ea9d..60e12e98207 100644 --- a/tensorrt_llm/plugin/plugin.py +++ b/tensorrt_llm/plugin/plugin.py @@ -581,13 +581,44 @@ def set_workspace_tensor(self, @staticmethod def max_workspace_size_auto(tp_size: int, support_deterministic=True) -> int: + """Calculate workspace size for allreduce fusion kernel. + + The workspace is used for lamport buffers in the fusion kernel. + Required size calculation: + - Each GPU needs 3 sub-buffers (for triple buffering) + - Each sub-buffer stores: max_num_tokens * hidden_size * dtype_size (bf16=2) + - The lamport allocation multiplies by tp_size, so: + lamport_size = 3 * size * tp_size (per GPU) + + Example: Llama 8B (hidden=4096), max_tokens=8192, bf16, TP=4 + - Data per sub-buffer: 8192 * 4096 * 2 = 64 MiB + - Total lamport: 3 * 64MB * 4 = 768 MiB per GPU + - Required 'size' parameter: 64 MiB (gets multiplied by tp_size in allocation) + + Default (67,108,864 = 64 MiB) supports: + - Models up to hidden_size=4096 with max_num_tokens=8192 + - Or hidden_size=8192 with max_num_tokens=4096 + + Override with TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE env var if needed for larger models. 
+ """ if force_all_reduce_deterministic() and support_deterministic: workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE", "1000000000") return int(workspace_size) - if tp_size <= 2: - return 16_000_000 - return 8_000_000 + + # Allow override via environment variable for edge cases + workspace_size_env = os.getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE") + if workspace_size_env: + size = int(workspace_size_env) + logger.info( + f"Using custom allreduce fusion workspace size: {size} bytes ({size / (1024**2):.1f} MiB)" + ) + return size + + # Default: 64 MiB - supports most common model configurations + # Increase via env var if you see CUDA illegal memory access errors with large models + default_size = 67_108_864 # Exactly 64 MiB + return default_size @staticmethod def max_workspace_size_lowprecision(tp_size: int) -> int: diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py index d4c8091158a..6a72868e3b5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py @@ -9,7 +9,7 @@ def _run_all_reduce_test(rank, world_size): x = torch.ones(10, 10).to("cuda") - y = torch.ops.auto_deploy.torch_dist_all_reduce(x) + y = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO") assert torch.equal(x * world_size, y) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py new file mode 100644 index 00000000000..ad5ea287218 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py @@ -0,0 +1,328 @@ +import signal +import subprocess +import tempfile +from contextlib import contextmanager +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +import yaml +from _model_test_utils import get_small_model_config +from click.testing import CliRunner +from utils.cpp_paths import llm_root # noqa: F401 + +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op +from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import ( + ShardingConfig, + SplitDimension, + WeightShardingInfo, +) +from tensorrt_llm.commands.bench import main +from tensorrt_llm.functional import AllReduceStrategy + + +class TimeoutError(Exception): + """Exception raised when a test times out.""" + + pass + + +@contextmanager +def timeout(seconds): + """Context manager that raises TimeoutError if code block exceeds time limit. 
+ + Args: + seconds: Maximum time in seconds to allow the code block to run + + Raises: + TimeoutError: If the code block execution exceeds the time limit + """ + + def timeout_handler(signum, frame): + raise TimeoutError(f"Test execution exceeded {seconds} seconds timeout") + + # Set the signal handler and alarm + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(seconds) + try: + yield + finally: + # Restore the old signal handler and cancel the alarm + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +@pytest.fixture(scope="module") +def shared_dataset(llm_root): # noqa: F811 + """Prepare dataset once for all tests in this module.""" + model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" + config = get_small_model_config(model_name) + with tempfile.TemporaryDirectory() as temp_dir: + dataset_path = _prepare_dataset( + llm_root, temp_dir, config["args"]["model"], num_requests=10 + ) + # Read dataset content to return it (temp_dir will be deleted) + with open(dataset_path, "r") as f: + dataset_content = f.read() + yield dataset_content + + +def _prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str, num_requests: int = 10): + """Prepare a synthetic dataset for benchmarking.""" + _DATASET_NAME = "synthetic_128_128.txt" + dataset_path = Path(temp_dir, _DATASET_NAME) + dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py") + script_dir = Path(root_dir, "benchmarks", "cpp") + + # Generate a small dataset to run a test - matching workload configuration + command = [ + "python3", + f"{dataset_tool}", + "--stdout", + "--tokenizer", + model_path_or_name, + "token-norm-dist", + "--input-mean", + "128", + "--output-mean", + "128", + "--input-stdev", + "0", + "--output-stdev", + "0", + "--num-requests", + str(num_requests), + ] + print(f"Running command: {' '.join(command)}") + result = subprocess.run( + command, cwd=str(script_dir), capture_output=True, text=True, timeout=300 + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to prepare dataset: {result.stderr}") + # Grab the stdout and write it to a dataset file for passing to suite. + with open(dataset_path, "w") as dataset: + dataset.write(result.stdout) + return dataset_path + + +@pytest.mark.parametrize( + "allreduce_strategy", + [ + "AUTO", + "ONESHOT", + "TWOSHOT", + "MIN_LATENCY", + "NCCL", + ], +) +def test_allreduce_strategies(llm_root, shared_dataset, allreduce_strategy): # noqa: F811 + """Test different allreduce strategies with multi-GPU configuration making sure that there are no crashes or hangs. + + Configuration: + The allreduce_strategy is set in the transforms config: + ```yaml + transforms: + detect_sharding: + allreduce_strategy: "ONESHOT" # or AUTO, NCCL, TWOSHOT, etc. 
+ ``` + + Test configuration: + - Model: Llama-3.1-8B with TP=2 + - Dataset: 10 synthetic requests (128 input, 128 output tokens) + - Timeout: 300 seconds to catch hangs + - Skipped if fewer than 2 GPUs available + + Args: + llm_root: Root directory fixture + shared_dataset: Shared dataset fixture (prepared once for all test runs) + allreduce_strategy: Strategy to test (AUTO, ONESHOT, TWOSHOT, MIN_LATENCY, NCCL) + """ + # Fixed timeout for all strategies (5 minutes should be enough) + TEST_TIMEOUT_SECONDS = 300 + + model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct" + config = get_small_model_config(model_name) + tp_size = 2 + max_batch_size = 256 + max_num_tokens = 8192 + + if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size: + pytest.skip(f"Allreduce strategy test requires at least {tp_size} GPUs, skipping") + + with tempfile.TemporaryDirectory() as temp_dir: + # Write shared dataset to temp location + dataset_path = Path(temp_dir, "synthetic_128_128.txt") + with open(dataset_path, "w") as f: + f.write(shared_dataset) + + # Create configuration with specified allreduce strategy in transforms + extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml" + with open(extra_llm_api_options_path, "w") as f: + yaml.dump( + { + **config["args"], + "max_batch_size": max_batch_size, + "max_num_tokens": max_num_tokens, + "max_seq_len": 256, + "transforms": { + "detect_sharding": { + "stage": "sharding", + "allreduce_strategy": allreduce_strategy, + }, + "compile_model": { + "stage": "compile", + "backend": "torch-cudagraph", + "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256], + }, + }, + }, + f, + ) + + # Run benchmark with specified allreduce strategy with timeout protection + runner = CliRunner() + args = [ + "--model", + model_name, + ] + + # Only pass --model_path if it's a local filesystem path + # Note: --model_path must come BEFORE the subcommand (throughput) + if str(config["args"]["model"]).startswith("/"): + args.extend(["--model_path", str(config["args"]["model"])]) + + # Add the subcommand and its options + args.extend( + [ + "throughput", + "--backend", + "_autodeploy", + "--dataset", + str(dataset_path), + "--extra_llm_api_options", + extra_llm_api_options_path, + "--tp", + str(tp_size), + "--max_batch_size", + str(max_batch_size), + "--max_num_tokens", + str(max_num_tokens), + ] + ) + + try: + with timeout(TEST_TIMEOUT_SECONDS): + result = runner.invoke(main, args, catch_exceptions=False) + assert result.exit_code == 0, f"Benchmark failed with output: {result.output}" + except TimeoutError as e: + pytest.fail( + f"Test timed out after {TEST_TIMEOUT_SECONDS}s for strategy {allreduce_strategy}. " + f"This might indicate a hang (e.g., TWOSHOT without C++ fix). Error: {e}" + ) + + +@pytest.mark.parametrize( + "strategy", + [ + "AUTO", + "NCCL", + "TWOSHOT", + "MIN_LATENCY", + ], +) +def test_allreduce_strategy_propagation(strategy): + """Test that allreduce_strategy is correctly propagated to graph nodes. + + This test verifies that when we set an allreduce_strategy on the ShardingConfig, + it gets properly injected into the transforms and passed to the torch_dist_all_reduce + nodes in the compiled graph. 
+ """ + + # Create a simple MLP model + class SimpleMLP(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(128, 256, bias=False) + self.linear2 = nn.Linear(256, 128, bias=False) + + def forward(self, x): + return self.linear2(torch.relu(self.linear1(x))) + + model = SimpleMLP() + dummy_input = torch.randn(2, 128) + + # Export to graph + gm = torch_export_to_gm(model, (dummy_input,)) + + # Find linear nodes in the graph + from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op + + linear_nodes = [node for node in gm.graph.nodes if is_linear_op(node)] + assert len(linear_nodes) == 2, f"Expected 2 linear nodes, found {len(linear_nodes)}" + + linear1_node, linear2_node = linear_nodes[0], linear_nodes[1] + + # Create sharding config with specified strategy + rank, world_size = 0, 4 + sharding_config = ShardingConfig( + rank=rank, world_size=world_size, allreduce_strategy=AllReduceStrategy[strategy] + ) + + # Add transforms: column shard linear1, row shard linear2 (triggers allreduce) + sharding_config.add( + WeightShardingInfo( + target_node=linear1_node.name, + rank=rank, + world_size=world_size, + split_dim=SplitDimension.COLUMN, + dist_op=None, + ) + ) + sharding_config.add( + WeightShardingInfo( + target_node=linear2_node.name, + rank=rank, + world_size=world_size, + split_dim=SplitDimension.ROW, + dist_op="all_reduce", + ) + ) + + # Verify transforms have the strategy injected + assert len(sharding_config.weight_sharding_transforms) == 2 + for transform in sharding_config.weight_sharding_transforms: + assert transform.allreduce_strategy == AllReduceStrategy[strategy], ( + f"Transform {transform.target_node} should have strategy {strategy}, got {transform.allreduce_strategy}" + ) + + # Apply transforms + for transform in sharding_config.weight_sharding_transforms: + node = next((n for n in gm.graph.nodes if n.name == transform.target_node), None) + if node: + transform.check_and_apply(gm, node) + + gm.recompile() + + # Verify the graph contains torch_dist_all_reduce nodes with correct strategy + allreduce_nodes = [ + node for node in gm.graph.nodes if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce) + ] + + # Should have exactly one allreduce node (from linear2 row sharding) + assert len(allreduce_nodes) == 1, f"Expected 1 allreduce node, found {len(allreduce_nodes)}" + + # Verify the allreduce node has the correct strategy argument + allreduce_node = allreduce_nodes[0] + # torch_dist_all_reduce signature: (input, strategy_name) + assert len(allreduce_node.args) == 2, ( + f"Expected 2 args for allreduce node, got {len(allreduce_node.args)}" + ) + + strategy_arg = allreduce_node.args[1] + assert strategy_arg == strategy, ( + f"Expected allreduce strategy '{strategy}', got '{strategy_arg}'" + ) + + print(f"✓ Test passed: allreduce_strategy '{strategy}' correctly propagated to graph node") diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index 797e9f94cec..83bccc33865 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -37,7 +37,7 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x 
= torch.ops.auto_deploy.torch_dist_all_reduce.default(x) + x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x, "AUTO") y = residual + x normed = self.norm(y) return normed, y @@ -51,7 +51,7 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x) + x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x, "AUTO") y = x + residual normed = self.norm(y) return normed, y diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 94e236cd4e4..cec4d778eda 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -17,6 +17,7 @@ FP8EPShardingInfo, NVFP4EPShardingInfo, ) +from tensorrt_llm.functional import AllReduceStrategy def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None: @@ -93,6 +94,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> target_node=node.name, rank=rank, world_size=world_size, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) elif is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe): @@ -101,6 +103,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> target_node=node.name, rank=rank, world_size=world_size, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) elif is_op(node, torch.ops.auto_deploy.torch_quant_nvfp4_moe): @@ -109,6 +112,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> target_node=node.name, rank=rank, world_size=world_size, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 58855fb0318..0428ff08a60 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -20,6 +20,7 @@ from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import FP8TPShardingInfo +from tensorrt_llm.functional import AllReduceStrategy base_model_tp_plan = { "q_proj": "colwise", @@ -279,6 +280,7 @@ def _run_pattern_detection_job( world_size=world_size, dist_op=dist_op, min_local_shape=min_local_shape, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) elif model_cls == MLP: @@ -300,6 +302,7 @@ def _run_pattern_detection_job( world_size=world_size, dist_op=dist_op, min_local_shape=1, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) elif model_cls == nn.Linear: @@ -314,6 +317,7 @@ def _run_pattern_detection_job( world_size=world_size, dist_op="all_gather", min_local_shape=1, + allreduce_strategy=AllReduceStrategy.AUTO, ) ) elif model_cls == FP8MLP: @@ -335,6 +339,7 @@ def _run_pattern_detection_job( world_size=world_size, dist_op=dist_op, min_local_shape=1, + allreduce_strategy=AllReduceStrategy.AUTO, ) )
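
For reference, a minimal usage sketch (condensed from the new test_allreduce_strategy_propagation unit test above, not part of the patch): the allreduce_strategy set on ShardingConfig, or via transforms.detect_sharding.allreduce_strategy in the YAML config, is injected by ShardingConfig.add() into any transform created without an explicit strategy and later emitted as the string argument of torch.ops.auto_deploy.torch_dist_all_reduce. The node name used here is hypothetical.

# Sketch only; mirrors the assertions in test_allreduce_strategy_propagation.
from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import (
    ShardingConfig,
    SplitDimension,
    WeightShardingInfo,
)
from tensorrt_llm.functional import AllReduceStrategy

cfg = ShardingConfig(rank=0, world_size=2, allreduce_strategy=AllReduceStrategy.TWOSHOT)
cfg.add(
    WeightShardingInfo(
        target_node="linear2",  # hypothetical node name from a traced graph
        rank=0,
        world_size=2,
        split_dim=SplitDimension.ROW,
        dist_op="all_reduce",  # row sharding requires an all_reduce after the matmul
    )
)
# The transform inherits the config-level strategy because none was set explicitly.
assert cfg.weight_sharding_transforms[0].allreduce_strategy == AllReduceStrategy.TWOSHOT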