8 changes: 2 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -122,15 +122,11 @@ transforms:
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm:
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
stage: post_load_fusion
backend: triton
rmsnorm_backend: triton
gated_rmsnorm_backend: triton
requires_shape_prop: true
fuse_gated_rmsnorm:
stage: post_load_fusion

############################################################################################
# VISUALIZE GRAPH
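For reference, the consolidated fuse_rmsnorm entry after this change reads roughly as follows (reconstructed from the added and kept lines above; indentation under the transforms key is approximate) and replaces the separate fuse_gated_rmsnorm entry:

fuse_rmsnorm:
  stage: post_load_fusion
  rmsnorm_backend: triton
  gated_rmsnorm_backend: triton
  requires_shape_prop: true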
8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
@@ -83,8 +83,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return torch.empty_like(input)


@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
def torch_rmsnorm_gated(
@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
def triton_rmsnorm_gated(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
@@ -140,8 +140,8 @@ def torch_rmsnorm_gated(
return out2.reshape(x_shape)


@torch_rmsnorm_gated.register_fake
def _torch_rmsnorm_gated_meta(
@triton_rmsnorm_gated.register_fake
def _triton_rmsnorm_gated_meta(
x,
weight,
gate,
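A minimal call sketch of the renamed op, mirroring the signature used by _gated_rmsnorm_replacement and the unit test further below; shapes, eps, and group size are illustrative, and the module import path is inferred from this PR and may differ:

import torch
import tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm  # assumed: importing this module registers the op

x = torch.randn(2, 3, 4096, dtype=torch.float32)  # (batch, seq, hidden)
w = torch.randn(4096, dtype=torch.float32)        # per-channel weight
g = torch.randn(2, 3, 4096, dtype=torch.float32)  # gate (may be None)

# auto_deploy::triton_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate)
y = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, g, 1e-5, 512, False)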
117 changes: 52 additions & 65 deletions tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
@@ -90,23 +90,28 @@ def _rms_norm_replacement(
class FuseRMSNormConfig(TransformConfig):
"""Configuration for the RMSNorm fusion transform."""

backend: str = Field(
rmsnorm_backend: str = Field(
default="flashinfer",
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
)
gated_rmsnorm_backend: str = Field(
default="triton",
description="Backend to use for gated RMSNorm computation (currently only 'triton').",
)


@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm(BaseTransform):
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
"""Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.

This function sets up pattern matching to identify RMSNorm operations in the graph
This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.

Args:
gm: Input graph module to transform.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch").
gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").

Returns:
Transformed graph module with optimized RMSNorm operations.
@@ -125,15 +130,23 @@ def _apply(
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if self.config.backend.lower() not in _BACKEND_OPS:
# Validate rmsnorm_backend
if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
raise ValueError(
f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
)

# Validate gated_rmsnorm_backend (currently only triton is supported)
if self.config.gated_rmsnorm_backend.lower() != "triton":
raise ValueError(
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
got {self.config.gated_rmsnorm_backend}"""
)

graph = gm.graph
patterns = ADPatternMatcherPass()

# Create dummy tensors for pattern matching
# Pattern matching for regular RMSNorm
bs = 2
hidden_size = 512

@@ -160,13 +173,42 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float =
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=search_fn,
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)

# Pattern matching for gated RMSNorm
B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5

def make_dummy_args_gated(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]

op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}

# Register pattern for gated RMSNorm
register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=_gated_rmsnorm_replacement,
patterns=patterns,
dummy_args=make_dummy_args_gated(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)

cnt = patterns.apply(graph)

info = TransformInfo(
@@ -204,61 +246,6 @@ def _gated_rmsnorm_replacement(
eps: float,
group_size: int,
) -> torch.Tensor:
return torch.ops.auto_deploy.torch_rmsnorm_gated(
return torch.ops.auto_deploy.triton_rmsnorm_gated(
x, weight, gate, float(eps), int(group_size), False
)


@TransformRegistry.register("fuse_gated_rmsnorm")
class FuseGatedRMSNorm(BaseTransform):
"""
Fuse the NemotronH-style gated RMSNorm subgraph into a single custom op:
auto_deploy::torch_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate=False)
"""

def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
patterns = ADPatternMatcherPass()

B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5

def make_dummy_args(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]

op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}

register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=partial(_gated_rmsnorm_replacement),
patterns=patterns,
dummy_args=make_dummy_args(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)

num = patterns.apply(graph)

info = TransformInfo(
skipped=False,
num_matches=num,
is_clean=num == 0,
has_valid_shapes=num == 0,
)
return gm, info
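As a usage sketch of the merged transform's configuration surface: the import path is inferred from this file's location, and direct construction of the config is an assumption, in line with the Field-based pydantic style above.

from tensorrt_llm._torch.auto_deploy.transform.library.rms_norm import FuseRMSNormConfig

# Defaults: rmsnorm_backend="flashinfer", gated_rmsnorm_backend="triton".
cfg = FuseRMSNormConfig()

# Explicit selection; _apply() raises ValueError for an unknown rmsnorm_backend
# or for any gated_rmsnorm_backend other than "triton".
cfg = FuseRMSNormConfig(rmsnorm_backend="triton", gated_rmsnorm_backend="triton")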
@@ -24,7 +24,7 @@ def test_custom_op_matches_ref(B, T, H, group, use_gate, dtype):
)

# Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
y_op_fp32 = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op = y_op_fp32.to(x.dtype)

assert y_ref.dtype == x.dtype and y_op.dtype == x.dtype
@@ -66,7 +66,8 @@ def checker(gm):
{
"fuse_rmsnorm": {
"stage": "post_load_fusion",
"backend": variant,
"gated_rmsnorm_backend": "triton",
"rmsnorm_backend": variant,
},
},
)(None, gm)
@@ -102,4 +103,4 @@ def test_rmsnorm_fusion(eps, variant, op):
def test_rmsnorm_fusion_nemotron_h():
# Only the triton backend supports the nemotron h rmsnorm
model = TestModel(eps=1e-6, use_nemotron_h=True)
_run_test(model, torch.ops.auto_deploy.triton_rms_norm, "triton")
_run_test(model, torch.ops.auto_deploy.triton_rms_norm, variant="triton")