7 changes: 7 additions & 0 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -959,6 +959,13 @@ class AllreduceOp
// MIN_LATENCY.
if (mStrategy != AllReduceStrategyType::AUTO)
{
// Check TWOSHOT constraint: seq_len >= tp_size
if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size())
{
TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT",
seq_len, mGroup.size());
return AllReduceStrategyType::ONESHOT;
}
return mStrategy;
}
else
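For intuition, here is a minimal Python sketch of the fallback rule the hunk above enforces; the helper name and signature are illustrative only, not part of the C++ change. The idea is that the two-shot kernel splits work across the sequence dimension, so each rank needs at least one token; otherwise ONESHOT is used.

    def select_allreduce_strategy(requested: str, seq_len: int, tp_size: int) -> str:
        # Mirrors the C++ check: TWOSHOT needs seq_len >= tp_size,
        # otherwise fall back to ONESHOT.
        if requested == "TWOSHOT" and seq_len < tp_size:
            return "ONESHOT"
        return requested

    # Example: 2 tokens across a TP group of 4 ranks falls back to ONESHOT.
    assert select_allreduce_strategy("TWOSHOT", seq_len=2, tp_size=4) == "ONESHOT"
    assert select_allreduce_strategy("TWOSHOT", seq_len=8, tp_size=4) == "TWOSHOT"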
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -79,6 +79,7 @@ transforms:
sharding_source: ['factory','heuristic']
support_partial_config: true
sharding_dims: ['tp', 'ep', 'bmm']
allreduce_strategy: 'AUTO'
requires_shape_prop: true
sharding_transform_executor:
stage: sharding
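The string set here must name a member of the AllReduceStrategy enum; a quick hedged sanity check (the import path is inferred from the sharding.py hunk later in this diff):

    from tensorrt_llm.functional import AllReduceStrategy

    # 'AUTO' is the default above; other member names such as 'NCCL' or 'TWOSHOT' work too.
    assert AllReduceStrategy["AUTO"] is AllReduceStrategy.AUTO
    print([s.name for s in AllReduceStrategy])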
12 changes: 8 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py
@@ -26,19 +26,23 @@ def all_gather_fake(tensor, dim=0):


@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.
def all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM. Strategy is MANDATORY.

Args:
t: Tensor to reduce across ranks
strategy: AllReduce strategy - "AUTO", "NCCL", "ONESHOT", "TWOSHOT", "MIN_LATENCY", etc.

NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
efficient all_reduce ops one should write/replace it with a fused op.
"""
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM, strategy=strategy)
t_res = t.clone()
dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
return t_res


@all_reduce.register_fake
def all_reduce_fake(tensor):
def all_reduce_fake(tensor, strategy):
return torch.empty_like(tensor)
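A hedged usage sketch of the new signature (assumes an initialized multi-rank CUDA job; the tensor shape is arbitrary):

    import torch

    t = torch.ones(8, 16, device="cuda")
    # The strategy string is now required; "AUTO" defers kernel selection to the backend,
    # while names like "TWOSHOT" force a specific strategy.
    out = torch.ops.auto_deploy.torch_dist_all_reduce(t, "AUTO")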
8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py
@@ -34,16 +34,16 @@ def simple_fake(input, weight, bias):
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], strategy: str
) -> torch.Tensor:
"""Fused linear followed by all_reduce on the output."""
"""Fused linear followed by all_reduce on the output. Strategy is MANDATORY."""
output = torch.ops.aten.linear(input, weight, bias)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM, strategy=strategy)
dist.all_reduce(output, op=dist.ReduceOp.SUM)
return output


@fused_linear_all_reduce.register_fake
def fused_linear_all_reduce_fake(input, weight, bias):
def fused_linear_all_reduce_fake(input, weight, bias, strategy):
return torch.ops.aten.linear(input, weight, bias)
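An analogous hedged sketch for the fused op (shapes are illustrative, bias may be None, and a multi-rank CUDA run is needed for the reduction to actually span ranks):

    import torch

    x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    w = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
    # strategy is the new trailing argument
    y = torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce(x, w, None, "AUTO")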
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -245,6 +245,7 @@ def fp8_linear_fake(
def fused_fp8_linear_all_reduce(
input: torch.Tensor,
weight_fp8: torch.Tensor,
strategy: str,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
@@ -253,7 +254,7 @@ def fused_fp8_linear_all_reduce(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM, strategy=strategy)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
return out

@@ -262,6 +263,7 @@ def fused_fp8_linear_all_reduce(
def fused_fp8_linear_all_reduce_fake(
input: torch.Tensor,
weight_fp8: torch.Tensor,
strategy: str,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
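Note that the new strategy parameter lands before the optional bias/scale arguments, so existing positional call sites need updating. A tiny stand-in function (not the real op, purely to illustrate the ordering):

    from typing import Optional

    def fused_fp8_linear_all_reduce_signature(
        input, weight_fp8, strategy: str,
        bias: Optional[object] = None,
        input_scale: Optional[object] = None,
        weight_scale: Optional[object] = None,
    ):
        return strategy

    # Callers now pass the strategy third, ahead of the keyword-style optionals.
    assert fused_fp8_linear_all_reduce_signature("x", "w_fp8", "AUTO", bias=None) == "AUTO"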
37 changes: 28 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -18,17 +18,26 @@ def trtllm_allgather(tensor, dim, sizes=None):
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
return allgather(tensor, p_config, dim=dim, sizes=sizes)

def trtllm_allreduce(tensor, op, all_reduce_params=None):
def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

# Cache key includes rank, world_size, and dtype to handle different configurations
cache_key = (rank, world_size, tensor.dtype)
# Convert string strategy to enum
try:
strategy_enum = getattr(AllReduceStrategy, strategy)
except AttributeError:
raise ValueError(
f"Invalid allreduce strategy: {strategy}. "
f"Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
f"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC"
)

# Cache key includes rank, world_size, dtype, and strategy to handle different configurations
cache_key = (rank, world_size, tensor.dtype, strategy_enum)
if cache_key not in _allreduce_cache:
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
# Use Strategy.AUTO for optimal performance
_allreduce_cache[cache_key] = AllReduce(
mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype
)

torch_op = _allreduce_cache[cache_key]
@@ -38,7 +47,11 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
"dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
)
def fused_allreduce_residual_rmsnorm(
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
tensor: torch.Tensor,
residual: torch.Tensor,
norm_weight: torch.Tensor,
eps: float,
strategy: str = "AUTO",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fusing allreduce, residual (add), and hf_rms_norm together.

@@ -54,7 +67,9 @@ def fused_allreduce_residual_rmsnorm(
norm_weight=norm_weight,
eps=eps,
)
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
return trtllm_allreduce(
tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params
)
else:
# Fallback: unfused implementation using torch distributed
# This is used in demollm mode without MPI
@@ -79,7 +94,11 @@ def fused_allreduce_residual_rmsnorm(

@fused_allreduce_residual_rmsnorm.register_fake
def fused_allreduce_residual_rmsnorm_fake(
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
tensor: torch.Tensor,
residual: torch.Tensor,
norm_weight: torch.Tensor,
eps: float,
strategy: str = "AUTO",
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(tensor), torch.empty_like(tensor)

@@ -89,7 +108,7 @@ def fused_allreduce_residual_rmsnorm_fake(
def trtllm_allgather(tensor, dim, sizes=None):
raise ImportError("TRT-LLM is not available.")

def trtllm_allreduce(tensor, op):
def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
raise ImportError("TRT-LLM is not available.")

TRTLLM_OP_AVAILABLE = False
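Two things are worth noting in the trtllm.py changes above: the strategy string is resolved to the enum once per call, and the AllReduce communicator cache is now keyed on the resolved strategy as well, so mixing strategies at runtime yields separate cached instances per (rank, world_size, dtype, strategy) combination. A condensed, hedged sketch of the name-to-enum step only:

    from tensorrt_llm.functional import AllReduceStrategy

    def resolve_strategy(name: str) -> AllReduceStrategy:
        # Same lookup as trtllm_allreduce: attribute access on the enum,
        # with a readable error for typos such as "ONE_SHOT".
        try:
            return getattr(AllReduceStrategy, name)
        except AttributeError:
            raise ValueError(f"Invalid allreduce strategy: {name}")

    assert resolve_strategy("TWOSHOT") is AllReduceStrategy.TWOSHOT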
@@ -28,7 +28,7 @@ def _allreduce_residual_rmsnorm_pattern(
"""

input_dtype = x.dtype
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO")
add = residual + hidden_states

hidden_states = add.to(torch.float32)
@@ -52,7 +52,7 @@ def _allreduce_residual_rmsnorm_pattern2(
"""

input_dtype = x.dtype
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO")
add = hidden_states + residual

hidden_states = add.to(torch.float32)
63 changes: 39 additions & 24 deletions tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Expand Up @@ -22,9 +22,10 @@
from typing import DefaultDict, Dict, List, Set, Tuple, Type

import torch
from pydantic import Field
from pydantic import Field, field_validator
from torch.fx import GraphModule, Node

from .....functional import AllReduceStrategy
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
@@ -49,6 +50,7 @@
SplitDimension,
WeightShardingInfo,
get_all_weights_in_subgraph,
validate_allreduce_strategy,
)
from ..interface import (
BaseTransform,
@@ -152,6 +154,18 @@ class ShardingTransformConfig(TransformConfig):
sharding_dims: List[ShardingDim] = Field(
default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
)
allreduce_strategy: AllReduceStrategy = Field(
default=AllReduceStrategy.AUTO,
description="AllReduce strategy for distributed operations. "
"Options: AUTO (automatic selection), NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC",
)

@field_validator("allreduce_strategy", mode="before")
@classmethod
def _validate_allreduce_strategy(cls, v):
"""Convert string names like 'AUTO' to AllReduceStrategy enum."""
return validate_allreduce_strategy(v)
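A hedged sketch of what the validator enables at config-parse time. It assumes the inherited TransformConfig fields all have defaults; pass them explicitly otherwise, and note that validate_allreduce_strategy itself is defined outside this diff.

    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
        ShardingTransformConfig,
    )

    # YAML/JSON configs supply plain strings; mode="before" converts them to the enum.
    cfg = ShardingTransformConfig(allreduce_strategy="TWOSHOT")
    assert cfg.allreduce_strategy is AllReduceStrategy.TWOSHOT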


@TransformRegistry.register("detect_sharding")
@@ -199,6 +213,8 @@ def _apply(
sharding_config = shared_config.sharding_config
sharding_config.rank = local_rank
sharding_config.world_size = world_size
sharding_config.allreduce_strategy = self.config.allreduce_strategy
ad_logger.info(f"Using allreduce strategy: {sharding_config.allreduce_strategy.name}")
sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
sharding_config.factory_source = (
sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
@@ -573,7 +589,7 @@ def detect_sharding_from_factory_config(
# we have a match. Get the config for this layer
config = tp_plan[key]
if config == "colwise":
sharding_config.weight_sharding_transforms.append(
if sharding_config.add(
WeightShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
@@ -582,10 +598,10 @@
dist_op=None,
min_local_shape=min_local_shape,
)
)
num_row_col_shards += 1
):
num_row_col_shards += 1
elif config == "rowwise":
sharding_config.weight_sharding_transforms.append(
if sharding_config.add(
WeightShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.ROW,
@@ -594,10 +610,10 @@
dist_op="all_reduce",
min_local_shape=min_local_shape,
)
)
num_row_col_shards += 1
):
num_row_col_shards += 1
elif config == "mamba":
sharding_config.weight_sharding_transforms.append(
sharding_config.add(
WeightShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
@@ -618,7 +634,7 @@
if "shared" in module_name:
col_row_action = config.replace("local_", "")
if col_row_action == "colwise":
sharding_config.weight_sharding_transforms.append(
sharding_config.add(
WeightShardingInfo(
target_node=lin_node.name,
split_dim=SplitDimension.COLUMN,
@@ -629,7 +645,7 @@
)
)
elif col_row_action == "rowwise":
sharding_config.weight_sharding_transforms.append(
if sharding_config.add(
WeightShardingInfo(
target_node=lin_node.name,
split_dim=SplitDimension.ROW,
@@ -638,8 +654,8 @@
dist_op="all_reduce",
min_local_shape=min_local_shape,
)
)
num_row_col_shards += 1
):
num_row_col_shards += 1
else:
ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
else:
@@ -648,7 +664,7 @@

elif "gather" in config:
# Simple shard (row + all_gather)
sharding_config.weight_sharding_transforms.append(
if sharding_config.add(
WeightShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
@@ -657,13 +673,13 @@
dist_op="all_gather",
min_local_shape=1,
)
)
num_simple_shards += 1
):
num_simple_shards += 1
else:
ad_logger.warning(
f"Unsupported sharding action {config}. Fallback to simple shard"
)
sharding_config.weight_sharding_transforms.append(
sharding_config.add(
WeightShardingInfo.from_node(
lin_node,
split_dim=SplitDimension.COLUMN,
@@ -943,7 +959,7 @@ def detect_column_row_shard(
)

# shard single row node
sharding_config.weight_sharding_transforms.append(
if sharding_config.add(
WeightShardingInfo.from_node(
nodes_to_row_shard[0],
split_dim=SplitDimension.ROW,
@@ -952,9 +968,8 @@
dist_op="all_reduce",
min_local_shape=min_local_shape,
)
)

num_row_col_shards += 1
):
num_row_col_shards += 1

ad_logger.info(
f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})"
@@ -1020,7 +1035,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra
start_idx = remainder + rank * base_size
end_idx = start_idx + base_size

sharding_config.bmm_transforms.append(
sharding_config.add(
BMMShardingInfo(
target_node=node.name,
rank=rank,
@@ -1064,14 +1079,14 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo
),
):
continue
sharding_config.ep_transforms.append(
if sharding_config.add(
EPShardingInfo.from_node(
node,
rank=rank,
world_size=world_size,
)
)
num_moe_patterns += 1
):
num_moe_patterns += 1

ad_logger.info(f"Found {num_moe_patterns} MoE patterns")

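Throughout sharding.py the old pattern of appending to a transform list and then unconditionally incrementing a counter becomes `if sharding_config.add(...): counter += 1`, which implies that `add` reports whether the transform was actually accepted. The real implementation is not part of this diff; the following is a purely hypothetical sketch of the implied contract:

    from typing import Any, List

    class ShardingConfigSketch:
        """Illustrative only: models the add() contract implied by the call sites above."""

        def __init__(self) -> None:
            self._transforms: List[Any] = []
            self._seen: set = set()

        def add(self, transform: Any) -> bool:
            # Skip duplicates (or otherwise rejected entries) and report the outcome,
            # so callers only count transforms that were really recorded.
            key = repr(transform)
            if key in self._seen:
                return False
            self._seen.add(key)
            self._transforms.append(transform)
            return True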