Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,42 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
def fused_allreduce_residual_rmsnorm(
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fusing allreduce, residual (add), and hf_rms_norm together."""
all_reduce_params = AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
bias=None,
residual=residual,
norm_weight=norm_weight,
eps=eps,
)
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
"""Fusing allreduce, residual (add), and hf_rms_norm together.

When TRT-LLM ops are available (MPI mode), uses the fused kernel.
Otherwise, falls back to separate operations using torch distributed.
"""
# Only use TRT-LLM fused op when running with MPI
if is_trtllm_op_available():
all_reduce_params = AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
bias=None,
residual=residual,
norm_weight=norm_weight,
eps=eps,
)
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
else:
# Fallback: unfused implementation using torch distributed
# This is used in demollm mode without MPI
from .common import all_reduce as torch_all_reduce

# 1. All-reduce the tensor
tensor_reduced = tensor.clone()
torch_all_reduce(tensor_reduced, op=ReduceOp.SUM)

# 2. Add residual
tensor_with_residual = tensor_reduced + residual

# 3. Apply RMSNorm using PyTorch's built-in function
norm_out = torch.nn.functional.rms_norm(
tensor_with_residual,
normalized_shape=(tensor_with_residual.size(-1),),
weight=norm_weight,
eps=eps,
)

return norm_out, tensor_with_residual

@fused_allreduce_residual_rmsnorm.register_fake
def fused_allreduce_residual_rmsnorm_fake(
Expand Down