diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 08cab0587e4..7c4a80044e4 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -94,11 +95,6 @@ BUILD_DOCS=1 # Copy requirements-lintrunner.txt from root to here cp ../../requirements-lintrunner.txt ./ -# Copy arm setup script from root to here -# TODO(huydhn): Figure out a way to rebuild the Docker image automatically -# with a new image hash when the content here is updated -cp -r ../../examples/arm/ ./arm - docker build \ --no-cache \ --progress=plain \ diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h index c1499a83f39..9a99aecf992 100644 --- a/backends/aoti/slim/c10/core/ScalarType.h +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -28,7 +28,7 @@ enum class ScalarType : int8_t { Short = 2, // int16_t Int = 3, // int32_t Long = 4, // int64_t - // Half = 5, // float16 - not currently needed + Half = 5, // float16 Float = 6, // float // Double = 7, // double - not currently needed // ComplexHalf = 8, @@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char; constexpr ScalarType kShort = ScalarType::Short; constexpr ScalarType kInt = ScalarType::Int; constexpr ScalarType kLong = ScalarType::Long; +constexpr ScalarType kHalf = ScalarType::Half; constexpr ScalarType kFloat = ScalarType::Float; constexpr ScalarType kBool = ScalarType::Bool; constexpr ScalarType kBFloat16 = ScalarType::BFloat16; @@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) { return sizeof(int32_t); case ScalarType::Long: return sizeof(int64_t); + case ScalarType::Half: + return 2; // sizeof(__half) = 2 bytes case ScalarType::Float: return sizeof(float); case ScalarType::Bool: @@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) { return "Int"; case ScalarType::Long: return "Long"; + case ScalarType::Half: + return "Half"; case ScalarType::Float: return "Float"; case ScalarType::Bool: @@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) { /// @param t The scalar type to check. /// @return true if the scalar type is floating point, false otherwise. inline bool isFloatingType(ScalarType t) { - return t == ScalarType::Float || t == ScalarType::BFloat16; + return t == ScalarType::Half || t == ScalarType::Float || + t == ScalarType::BFloat16; } /// Checks if the scalar type is an integral type (including bool optionally). 
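Note on the Half additions above: enum slot 5 and the hard-coded 2-byte element size match PyTorch's own float16 definition, and `isFloatingType` now reports Half as floating point. A quick Python-side cross-check of those constants (illustrative only, not part of the patch):

```python
import torch

# torch.float16 occupies 2 bytes, matching the value returned by
# elementSize(ScalarType::Half) in the hunk above.
assert torch.tensor([], dtype=torch.float16).element_size() == 2

# float16 is a floating-point dtype, mirroring the isFloatingType() change.
assert torch.float16.is_floating_point
```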
@@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) { case ScalarType::Short: case ScalarType::Int: case ScalarType::Long: + case ScalarType::Half: case ScalarType::Float: case ScalarType::Bool: case ScalarType::BFloat16: diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp index 332f5d7d264..4c06f7ef101 100644 --- a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -36,6 +36,7 @@ const std::vector kAllScalarTypes = { {ScalarType::Short, 2, 2, "Short", false, true, true, false}, {ScalarType::Int, 3, 4, "Int", false, true, true, false}, {ScalarType::Long, 4, 8, "Long", false, true, true, false}, + {ScalarType::Half, 5, 2, "Half", true, false, false, false}, {ScalarType::Float, 6, 4, "Float", true, false, false, false}, {ScalarType::Bool, 11, 1, "Bool", false, false, true, true}, {ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false}, @@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) { EXPECT_EQ(kLong, ScalarType::Long); } +TEST_F(ScalarTypeConstantsTest, KHalfConstant) { + EXPECT_EQ(kHalf, ScalarType::Half); +} + TEST_F(ScalarTypeConstantsTest, KFloatConstant) { EXPECT_EQ(kFloat, ScalarType::Float); } @@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) { EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t)); } +TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) { + EXPECT_EQ(elementSize(ScalarType::Half), 2); +} + TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) { EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); } @@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) { TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) { EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16)); } + +// ============================================================================= +// isValidScalarType Tests +// ============================================================================= + +class IsValidScalarTypeTest : public ::testing::Test {}; + +TEST_F(IsValidScalarTypeTest, HalfIsValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); +} + +TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Byte)); + EXPECT_TRUE(isValidScalarType(ScalarType::Char)); + EXPECT_TRUE(isValidScalarType(ScalarType::Short)); + EXPECT_TRUE(isValidScalarType(ScalarType::Int)); + EXPECT_TRUE(isValidScalarType(ScalarType::Long)); + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); + EXPECT_TRUE(isValidScalarType(ScalarType::Float)); + EXPECT_TRUE(isValidScalarType(ScalarType::Bool)); + EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16)); +} + +TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) { + EXPECT_FALSE(isValidScalarType(ScalarType::Undefined)); +} diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 6fd9b145988..f54ed851240 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. 
import logging -from typing import Set, Type +from collections.abc import Mapping +from typing import Sequence, Set, Type import torch._export.utils import torch.fx @@ -18,6 +19,7 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) +from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, @@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.exported_program = exported_program + @staticmethod + def _is_tosa_dialect_op(target) -> bool: + target_str = str(target) + return ( + "executorch.exir.dialects.backend._ops.tosa." in target_str + or " bool: + if isinstance(arg, torch.fx.Node): + if meta_has_shape_mark(arg.meta): + return True + return FuseConstantArgsPass._arg_contains_symbolic_shape( + arg.meta.get("val") + ) + if isinstance(arg, torch.SymInt): + return True + if isinstance(arg, Mapping): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(k) + or FuseConstantArgsPass._arg_contains_symbolic_shape(v) + for k, v in arg.items() + ) + if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg + ) + return False + def _propagate_special_dtype(self, from_nodes, to_node, data): """Propagate special dtype meta if it exists.""" special_dtypes = set() @@ -83,21 +115,24 @@ def _fuse_nodes(self, node) -> bool: input_nodes = list(node.all_input_nodes) qparams = node.meta.get("input_qparams", None) - def resolve_arg(arg): + def resolve_arg(arg, arg_index=None): + qparam = ( + qparams.get(arg_index) if qparams and arg_index is not None else None + ) if isinstance(arg, torch.fx.Node) and arg in input_nodes: - idx = input_nodes.index(arg) t = get_param_tensor(self.exported_program, arg) - # Check if qparams exist for this arg - if qparams and idx in qparams.keys(): - t = qparams[idx].dequantize_value(t) + if qparam is not None: + t = qparam.dequantize_value(t) return t if isinstance(arg, tuple): - return tuple(resolve_arg(x) for x in arg) + return tuple(resolve_arg(x, arg_index) for x in arg) if isinstance(arg, list): - return [resolve_arg(x) for x in arg] + return [resolve_arg(x, arg_index) for x in arg] return arg - new_args = tuple(resolve_arg(a) for a in node.args) + new_args = tuple( + resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args) + ) new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()} data = node.target(*new_args, **new_kwargs) @@ -139,13 +174,13 @@ def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": continue - if node.target in [ - exir_ops.backend.tosa.MATMUL.default, - exir_ops.backend.tosa.RESCALE.default, - exir_ops.backend.tosa.RESIZE.default, - exir_ops.backend.tosa.TABLE.default, - exir_ops.backend.tosa.TRANSPOSE.default, - ]: + # Don't fuse TOSA dialect ops as they do not have eager forward functions. + # Also don't fuse ops whose explicit args/kwargs include symbolic shape values. 
+ if ( + self._is_tosa_dialect_op(node.target) + or self._arg_contains_symbolic_shape(node.args) + or self._arg_contains_symbolic_shape(node.kwargs) + ): continue input_nodes = node.all_input_nodes @@ -161,7 +196,6 @@ def call(self, graph_module): ) if not all(input_nodes_constant): continue - try: did_fuse = self._fuse_nodes(node) if did_fuse: diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index e4be0b5dc25..8244dc2558b 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -21,6 +21,9 @@ get_input_qparams, get_output_qparams, ) +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import get_context_shape_env @@ -83,8 +86,14 @@ def _adjust_pad_if_needed( if isinstance(mod_remainder, torch.SymInt): shape_env = get_context_shape_env() - value_ranges = shape_env.bound_sympy(mod_remainder.node.expr) - mod_remainder_upper = int(value_ranges.upper) + exact_values = evaluate_symbolic_expr_values( + mod_remainder.node.expr, shape_env + ) + if exact_values is not None: + mod_remainder_upper = max(exact_values) + else: + value_ranges = shape_env.bound_sympy(mod_remainder.node.expr) + mod_remainder_upper = int(value_ranges.upper) if mod_remainder_upper == 0: mod_remainder = 0 else: @@ -92,7 +101,7 @@ def _adjust_pad_if_needed( if mod_remainder_upper > pad: raise RuntimeError( - "This case should be handled by the SizeAdjustInputPass, is it enabled?" + "This case should be handled by the SizeAdjustInputPass, is it enabled?\n" ) return pad - mod_remainder diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 233c93340b8..bf50306f5d6 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import cast, Sequence, Set, Type, TypeAlias +import torch import torch.fx from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,6 +13,9 @@ ) from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) from executorch.backends.arm.tosa.specification import get_context_shape_env from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool: """Returns whether an int or SymInt is greater than another value.""" if isinstance(input, torch.SymInt): shape_env = get_context_shape_env() + exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env) + if exact_values is not None: + return max(exact_values) > other value_ranges = shape_env.bound_sympy(input.node.expr) return value_ranges.upper > other else: diff --git a/backends/arm/_passes/symbolic_value_range.py b/backends/arm/_passes/symbolic_value_range.py new file mode 100644 index 00000000000..0753fefa270 --- /dev/null +++ b/backends/arm/_passes/symbolic_value_range.py @@ -0,0 +1,138 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import sympy # type: ignore[import-untyped] +import torch +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._sympy.interp import sympy_interp + +_MAX_SET_SIZE = 256 +_ExactValues = Optional[frozenset[sympy.Basic]] + + +def _expr_to_int(sym_expr: sympy.Basic) -> Optional[int]: + if isinstance(sym_expr, int): + return sym_expr + if isinstance(sym_expr, sympy.Integer): + return int(sym_expr) + if getattr(sym_expr, "is_integer", False) and sym_expr.is_number: + return int(sym_expr) + return None + + +def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues: + value_range = shape_env.var_to_range.get(symbol) + if value_range is None or not value_range.is_int: + return None + + lower = _expr_to_int(value_range.lower) + upper = _expr_to_int(value_range.upper) + if lower is None or upper is None or upper < lower: + return None + if upper - lower + 1 > _MAX_SET_SIZE: + return None + + return frozenset(sympy.Integer(value) for value in range(lower, upper + 1)) + + +def _map_values(values: _ExactValues, fn) -> _ExactValues: + if values is None: + return None + + result = {sympy.simplify(fn(value)) for value in values} + if len(result) > _MAX_SET_SIZE: + return None + return frozenset(result) + + +def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues: + if lhs is None or rhs is None: + return None + if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE: + return None + + result = {sympy.simplify(fn(a, b)) for a in lhs for b in rhs} + if len(result) > _MAX_SET_SIZE: + return None + return frozenset(result) + + +class _ExactValueAnalysis: + @staticmethod + def constant(value, dtype) -> frozenset[sympy.Basic]: + return frozenset({sympy.sympify(value)}) + + @staticmethod + def add(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a + b) + + @staticmethod + def mul(lhs: _ExactValues, rhs: 
_ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a * b) + + @staticmethod + def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + if rhs is None or any(value == 0 for value in rhs): + return None + return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b)) + + @staticmethod + def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a**b) + + @staticmethod + def floor_to_int(values: _ExactValues, dtype) -> _ExactValues: + return _map_values(values, sympy.floor) + + @staticmethod + def sym_sum(args: list[_ExactValues]) -> _ExactValues: + acc: _ExactValues = frozenset({sympy.Integer(0)}) + for arg in args: + acc = _ExactValueAnalysis.add(acc, arg) + if acc is None: + return None + return acc + + +def evaluate_symbolic_expr_values( + expr: sympy.Basic | torch.SymInt, + shape_env: ShapeEnv, +) -> Optional[set[int]]: + """Return a best-effort finite set of possible integer values. + + The helper first relies on ``bound_sympy`` for cheap singleton detection. + When interval bounds are not precise enough, it falls back to a small + exact-set analysis over bounded symbols using ``sympy_interp``. + + """ + root_expr = sympy.simplify( + expr.node.expr if isinstance(expr, torch.SymInt) else expr + ) + value_range = shape_env.bound_sympy(root_expr) + if value_range.is_int and value_range.is_singleton(): + singleton = _expr_to_int(value_range.lower) + return {singleton} if singleton is not None else None + + exact_values = sympy_interp( + _ExactValueAnalysis, + { + symbol: _symbol_values(symbol, shape_env) + for symbol in root_expr.free_symbols + }, + root_expr, + missing_handler=lambda symbol: _symbol_values(symbol, shape_env), + ) + if exact_values is None: + return None + + result: set[int] = set() + for value in exact_values: + integer_value = _expr_to_int(sympy.simplify(value)) + if integer_value is None: + return None + result.add(integer_value) + return result diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 0f281dba24b..d915b4ecba0 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -6,13 +6,25 @@ import operator from typing import cast, ClassVar, Dict, Protocol, Tuple +import executorch.backends.arm.tosa.dialect # noqa: F401 import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.backends.test.harness.stages import StageType +from executorch.backends.test.program_builder import ProgramBuilder +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind input_t = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] @@ -116,6 +128,52 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.cat((a, b), dim=0) +class QuantizedCatConstantBuffers(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + 
self.register_buffer( + "horizontal_ramp", + torch.tensor( + [ + [ + [ + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + ] + ] + ], + dtype=torch.int8, + ), + ) + self.register_buffer( + "vertical_ramp", + torch.tensor( + [ + [ + [ + [-95, -95, -95, -95, -95], + [-32, -32, -32, -32, -32], + [32, 32, 32, 32, 32], + [95, 95, 95, 95, 95], + ] + ] + ], + dtype=torch.int8, + ), + ) + + def forward(self) -> torch.Tensor: + return torch.cat( + ( + cast(torch.Tensor, self.horizontal_ramp), + cast(torch.Tensor, self.vertical_ramp), + ), + dim=1, + ) + + modules: Dict[str, ModuleWithFuseAttrs] = { "fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()), "fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()), @@ -174,3 +232,116 @@ def test_fuse_constant_args_tosa_INT_cat(module: ModuleWithFuseAttrs) -> None: ], ) pipeline.run() + + +def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: + qargs = QuantArgs( + scale=1.0 / 127.0, + zp=0, + qmin=-127, + qmax=127, + dtype=torch.int8, + ) + module = QuantizedCatConstantBuffers() + compile_spec = common.get_tosa_compile_spec( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ) + tester = ArmTester(module, example_inputs=(), compile_spec=compile_spec) + tester.export().to_edge() + exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program() + + cat_node = next( + node + for node in exported_program.graph_module.graph.nodes + if node.op == "call_function" + ) + cat_node.meta["input_qparams"] = {0: qargs} + cat_node.meta["output_qparams"] = {0: qargs} + + pass_result = FuseConstantArgsPass(exported_program).call( + exported_program.graph_module + ) + + assert list(exported_program.state_dict) == ["aten_cat_default_fused_const"] + torch.testing.assert_close( + exported_program.state_dict["aten_cat_default_fused_const"], + torch.cat( + ( + cast(torch.Tensor, module.horizontal_ramp), + cast(torch.Tensor, module.vertical_ramp), + ), + dim=1, + ), + ) + assert [ + node.name + for node in pass_result.graph_module.graph.nodes + if node.op == "placeholder" + ] == ["aten_cat_default_fused_const"] + + +def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: + class FakeTosaTarget: + def __str__(self) -> str: + return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default" + + assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget()) + assert FuseConstantArgsPass._is_tosa_dialect_op( + exir_ops.backend.tosa.GATHER.default + ) + assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) + + +def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: + graph = torch.fx.Graph() + shape_node = graph.placeholder("shape") + shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + + assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2])) + assert not FuseConstantArgsPass._arg_contains_symbolic_shape( + ([1, 2], {"pad": (0, 0)}) + ) + + +def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): + builder = ProgramBuilder() + values = builder.placeholder( + "values", + torch.randn(1, 4, 3), + input_kind=InputKind.CONSTANT_TENSOR, + ) + indices = builder.placeholder( + "indices", + torch.tensor([[0, 2]], dtype=torch.int32), + input_kind=InputKind.CONSTANT_TENSOR, + ) + gather = builder.call_operator( + exir_ops.backend.tosa.GATHER.default, + (values, indices), + ) + 
builder.output([gather]) + + exported_program = builder.get_program() + graph_module = exported_program.graph_module + + with caplog.at_level("WARNING"): + FuseConstantArgsPass(exported_program)(graph_module) + + warning_messages = [ + record.getMessage() + for record in caplog.records + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" + ] + assert not any( + "Failed to fuse constant op" in message and "GATHER" in message + for message in warning_messages + ) + assert ( + sum( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.GATHER.default + for node in graph_module.graph.nodes + ) + == 1 + ) diff --git a/backends/arm/test/passes/test_symbolic_value_range.py b/backends/arm/test/passes/test_symbolic_value_range.py new file mode 100644 index 00000000000..8d3c970f0ab --- /dev/null +++ b/backends/arm/test/passes/test_symbolic_value_range.py @@ -0,0 +1,69 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _make_shape_env( + *, + symbol_name: str = "s89", + hint: int = 2, + compiler_min: int = 1, + compiler_max: int = 2, +) -> tuple[ShapeEnv, torch.SymInt]: + shape_env = ShapeEnv() + symint = shape_env.create_symintnode(sympy.Symbol(symbol_name), hint=hint) + shape_env.constrain_symbol_range( + symint.node.expr, + compiler_min=compiler_min, + compiler_max=compiler_max, + ) + return shape_env, symint + + +def test_evaluate_symbolic_expr_values_returns_singleton_for_constant_expr() -> None: + shape_env, symint = _make_shape_env() + + assert evaluate_symbolic_expr_values( + symint.node.expr - symint.node.expr, shape_env + ) == {0} + assert evaluate_symbolic_expr_values( + sympy.floor(symint.node.expr / symint.node.expr), shape_env + ) == {1} + + +def test_evaluate_symbolic_expr_values_returns_singleton_for_singleton_symint() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=3, compiler_max=3) + + assert evaluate_symbolic_expr_values(symint, shape_env) == {3} + assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {3} + + +def test_evaluate_symbolic_expr_values_enumerates_non_singleton_symint() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6) + + assert evaluate_symbolic_expr_values(symint, shape_env) == {2, 3, 4, 5, 6} + assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {2, 3, 4, 5, 6} + + +def test_evaluate_symbolic_expr_values_tracks_exact_modulo_residue() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6) + expr = sympy.Mod(16 * symint.node.expr - 7, 4) + + value_range = shape_env.bound_sympy(expr) + assert value_range.lower == 0 + assert value_range.upper == 3 + assert evaluate_symbolic_expr_values(expr, shape_env) == {1} + + +def test_evaluate_symbolic_expr_values_bails_out_for_large_symbol_ranges() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=1, compiler_max=400) + + assert evaluate_symbolic_expr_values(symint, shape_env) is None diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 6b7990c0f2c..5375367b929 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ 
b/backends/cadence/aot/quantizer/fusion_pass.py @@ -15,6 +15,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -63,6 +65,7 @@ Conv2dReluPattern0, Conv2dReluPattern1, ) +AddReluPatterns = (AddReluPattern0, AddReluPattern1) def get_args_and_kwargs_add( @@ -616,7 +619,20 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 inputs_inputs + weights_inputs + other_inputs + bias_inputs ) kwargs = {} - if isinstance(pattern, AddPattern): + if isinstance(pattern, AddReluPatterns): + # For AddReLU, we are fusing Add+ReLU. + # The quantized_add op performs requantization, + # so the relu is implicit in the output quant params. + check_out_zero_point_is_min_range( + quant_node.args[2], quant_node.args[5] + ) + args, kwargs = get_args_and_kwargs_add( + graph_module, + inputs_inputs, + dequants_inputs, + quant_node, + ) + elif isinstance(pattern, AddPattern): args, kwargs = get_args_and_kwargs_add( graph_module, inputs_inputs, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 2ce50871fc0..07aad18e36a 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -153,6 +153,61 @@ def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor +# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops +class AddReluBasePattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> List[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # The first node should be add, the second should be relu + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + add_node = fused_partition[0].nodes[-1] + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + relu_node = fused_partition[1].nodes[-1] + + # Bail if: + # - the add node is not a tensor add + # - the add node has kwargs (e.g. 
alpha) + is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance( + add_node.args[1], fx.Node + ) + if not is_tensor_add or len(add_node.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + add_node, + ) + + return ( + PartitionAnchors( + inputs=[(add_node, 0), (add_node, 1)], + weights=[], + biases=[], + output=[(relu_node,)], # Output is from the relu node + ), + relu_node, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_add.per_tensor + + +# Add + regular relu op fusion +class AddReluPattern0(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default] + + +# Add + alternate relu op fusion +class AddReluPattern1(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default] + + class BmmPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.bmm.default] diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4edcd96e132..d521b9f83cf 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -13,6 +13,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -398,6 +400,8 @@ def __init__( quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern0(), a8w8)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern1(), a8w8)) quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat) quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8)) quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8)) diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index 06e2c08f4f4..dde26f06b7b 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -215,6 +215,15 @@ [qconfig_A8W8.input_activation], ), # CadenceFusedConvReluQuantizer test cases + ( + "fused_add_relu_A8W8", + lambda self: self._build_add_relu_graph(), + CadenceFusedConvReluQuantizer(), + torch.ops.aten.relu.default, + qconfig_A8W8.output_activation, + # For fused add+relu: both inputs are activations from add node + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], + ), ( "fused_conv1d_relu_A8W8sym", lambda self: self._build_conv1d_relu_graph(), @@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: ) return gm, max_pool_nodes[0] + def _build_add_relu_graph( + self, + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: + """Build a graph with an add followed by relu (fused pattern). + + Returns: + A tuple of (graph_module, relu_node, add_node). + The relu_node is the target node where the annotation is placed. + The add_node is the input source node whose args contain the quantized inputs. 
+ """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 10)) + y = builder.placeholder("y", torch.randn(1, 10)) + add = builder.call_operator( + op=torch.ops.aten.add.Tensor, + args=(x, y), + meta=NodeMetadata( + {"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]} + ), + ) + relu = builder.call_operator( + op=torch.ops.aten.relu.default, + args=(add,), + meta=NodeMetadata( + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} + ), + ) + builder.output([relu]) + gm = builder.get_graph_module() + + relu_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.relu.default, + ) + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") + + add_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.add.Tensor, + ) + self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") + + return gm, relu_nodes[0], add_nodes[0] + def _build_conv2d_relu_graph( self, ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 3ef5fc02adb..19665f37083 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -8,6 +8,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .decompose_hardswish_pass import DecomposeHardswishPass # noqa from .decompose_mean_pass import DecomposeMeanPass # noqa +from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py index a53c065aaa4..ff61f3493dd 100644 --- a/backends/cortex_m/passes/activation_fusion_pass.py +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +8,10 @@ import executorch.backends.cortex_m.ops.operators # noqa: F401 from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.cortex_m.passes.passes_utils import quantize_val +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass): """Fuse activations into preceding Cortex-M quantized operators. 
Supported activation patterns: - q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + q-> [conv2d, linear, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq Fusing works by clamping the quantized output range (and zero-point when required) of the preceding Cortex-M operator, then removing the activation @@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass): exir_ops.edge.aten.clamp.default, } + MAX_POOL_OPS = { + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, + } + FUSE_OPS = { exir_ops.edge.aten.linear.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, } def _get_validated_qparams(self, node, input_node): @@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node): ) return None - match node.target: - case exir_ops.edge.aten.relu.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = qmax - case exir_ops.edge.aten.hardtanh.default: - quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax) - quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax) - case exir_ops.edge.aten.hardsigmoid.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = quantize_val(1, scale, zp, qmin, qmax) - case exir_ops.edge.aten.clamp.default: - quantized_min_val = ( - quantize_val(node.args[1], scale, zp, qmin, qmax) - if node.args[1] is not None - else qmin - ) - # Last arg is removed if none, so check length of args here - quantized_max_val = ( - quantize_val(node.args[2], scale, zp, qmin, qmax) - if len(node.args) == 3 - else qmax + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot fuse activation %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min_val = ( + quantize_val(min_val, scale, zp, qmin, qmax) + if min_val is not None + else qmin + ) + quantized_max_val = ( + quantize_val(max_val, scale, zp, qmin, qmax) + if max_val is not None + else qmax + ) + + if input_node.target in self.MAX_POOL_OPS: + if node.target == exir_ops.edge.aten.hardsigmoid.default: + logger.warning( + "Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams.", + node.name, ) - case _: - raise RuntimeError(f"Unexpected target {node.target}.") + return None + # Max-pool keeps scale and zero-point unchanged and lowers fused + # activation bounds separately, so only qmin/qmax need updating here. + qparams_dict["qmin"] = int(quantized_min_val) + qparams_dict["qmax"] = int(quantized_max_val) + return qparams_dict # If the minimal quantized value is larger than the qmin, it means that the quantized range contains # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. 
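To make the max_pool2d branch above concrete: fusing a clamp-style activation after a pooling op leaves scale and zero-point untouched and only narrows the quantized output range. A minimal sketch using `quantize_val` exactly as defined in `passes_utils.py`; the scale/zero-point/qmin/qmax values are made-up examples:

```python
import torch

# quantize_val as defined in backends/cortex_m/passes/passes_utils.py.
def quantize_val(val, scale, zp, qmin, qmax):
    return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax))

# Example per-tensor qparams (assumed values, illustration only).
scale, zp, qmin, qmax = 0.05, -10, -128, 127

# Fusing hardtanh(-1.0, 1.0) after max_pool2d: the float bounds are quantized
# and written into qmin/qmax, while scale and zero-point stay unchanged.
new_qmin = int(quantize_val(-1.0, scale, zp, qmin, qmax))  # -> -30
new_qmax = int(quantize_val(1.0, scale, zp, qmin, qmax))   # -> 10
assert (new_qmin, new_qmax) == (-30, 10)
```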
diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 9fef167ef09..074eb6118d0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -28,6 +28,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass from .decompose_hardswish_pass import DecomposeHardswishPass from .decompose_mean_pass import DecomposeMeanPass +from .quantized_clamp_activation_pass import QuantizedClampActivationPass from .quantized_op_fusion_pass import QuantizedOpFusionPass from .replace_quant_nodes_pass import ReplaceQuantNodesPass @@ -42,6 +43,7 @@ class CortexMPassManager(PassManager): ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, ActivationFusionPass, + QuantizedClampActivationPass, DecomposeHardswishPass, QuantizedOpFusionPass, ConvertToCortexMPass, diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index a6f68022430..fcbfa301b06 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Any import torch @@ -21,6 +22,56 @@ def quantize_val(val, scale, zp, qmin, qmax): return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax)) +def extract_constant_scalar(arg: Any) -> float | None: + if arg is None: + return None + if isinstance(arg, (int, float)): + return float(arg) + if isinstance(arg, Node): + if arg.op == "call_function" and arg.target in { + exir_ops.edge.aten.full_like.default, + exir_ops.edge.aten.full.default, + torch.ops.aten.full_like.default, + torch.ops.aten.full.default, + }: + fill_arg = arg.args[1] if len(arg.args) > 1 else None + return extract_constant_scalar(fill_arg) + val = arg.meta.get("val") + if val is None: + return None + return extract_constant_scalar(val) + return None + + +def get_activation_bounds(node: Node) -> tuple[float | None, float | None] | None: + bounds: tuple[float | None, float | None] + match node.target: + case exir_ops.edge.aten.relu.default | exir_ops.edge.aten.relu_.default: + bounds = (0.0, None) + case exir_ops.edge.aten.hardsigmoid.default: + bounds = (0.0, 1.0) + case exir_ops.edge.aten.hardtanh.default | exir_ops.edge.aten.hardtanh_.default: + bounds = ( + extract_constant_scalar(node.args[1]), + extract_constant_scalar(node.args[2]), + ) + case exir_ops.edge.aten.clamp.default | exir_ops.edge.aten.clamp.Tensor: + bounds = ( + extract_constant_scalar(node.args[1]) if len(node.args) > 1 else None, + extract_constant_scalar(node.args[2]) if len(node.args) > 2 else None, + ) + case _: + return None + + min_val, max_val = bounds + if len(node.args) > 1 and min_val is None and node.args[1] is not None: + return None + if len(node.args) > 2 and max_val is None and node.args[2] is not None: + return None + + return bounds + + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int ) -> torch.Tensor: diff --git a/backends/cortex_m/passes/quantized_clamp_activation_pass.py b/backends/cortex_m/passes/quantized_clamp_activation_pass.py new file mode 100644 index 00000000000..2ba003dbc01 --- /dev/null +++ b/backends/cortex_m/passes/quantized_clamp_activation_pass.py @@ -0,0 +1,129 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from typing import Any + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class QuantizedClampActivationPass(ExportPass): + """Canonicalize remaining clamp-like activations on quantized tensors. + + This pass runs after activation fusion, so any remaining relu/hardtanh/clamp + still needs to execute in the quantized domain. It rewrites relu and + hardtanh variants to `aten.clamp.default` and quantizes the clamp bounds so + the portable kernel consumes and produces int8 tensors. + """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.relu_.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardtanh_.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + } + + def _get_quantized_bounds( + self, node: Node, qparams_dict: dict[str, Any] + ) -> tuple[int | None, int | None] | None: + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot rewrite %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min = ( + int(quantize_val(min_val, scale, zp, qmin, qmax)) + if min_val is not None + else None + ) + quantized_max = ( + int(quantize_val(max_val, scale, zp, qmin, qmax)) + if max_val is not None + else None + ) + return quantized_min, quantized_max + + def _is_quantized_int8_activation(self, node: Node) -> bool: + input_node = node.args[0] if len(node.args) > 0 else None + if not isinstance(input_node, Node): + return False + try: + tensor = get_first_fake_tensor(input_node) + except Exception: + return False + if tensor is None or tensor.dtype != torch.int8: + return False + + try: + qparams_dict = get_output_qparams(node)[0]._asdict() + except (ValueError, KeyError): + logger.warning( + "Cannot quantize clamp bounds for %s without output qparams.", + node.name, + ) + return False + + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + "Cannot quantize clamp bounds for %s with non per-tensor qparams.", + node.name, + ) + return False + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + if not self._is_quantized_int8_activation(node): + continue + + qparams_dict = get_output_qparams(node)[0]._asdict() + + quantized_bounds = self._get_quantized_bounds(node, qparams_dict) + if quantized_bounds is None: + continue + + quantized_min, quantized_max = quantized_bounds + node.target = exir_ops.edge.aten.clamp.default + node.args = (node.args[0], quantized_min, quantized_max) + modified = True + + if modified: + graph_module = super().call(graph_module).graph_module + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return 
PassResult(graph_module, modified) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 9bc13c05e9d..0f10bd6afef 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import Any, Callable import torch @@ -86,10 +87,45 @@ torch.ops.aten.max_pool2d_with_indices.default, } +POOL_FUSED_ACTIVATION_TARGETS = { + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_.default, +} + class CortexMQuantizationConfig(QuantizationConfig): """Configures quantization, while enforcing cortex-m specific constraints.""" + @staticmethod + def _get_shared_pool_input(node: Node | None) -> Node | None: + if node is None or len(node.args) == 0: + return None + + input_node = node.args[0] + if not isinstance(input_node, Node): + return None + + if input_node.target in POOL_SHARE_OUTPUT_TARGETS: + if len(input_node.args) > 0 and isinstance(input_node.args[0], Node): + return input_node.args[0] + return None + + if input_node.target == operator.getitem and len(input_node.args) > 0: + pool_node = input_node.args[0] + if ( + isinstance(pool_node, Node) + and pool_node.target in POOL_SHARE_OUTPUT_TARGETS + and len(pool_node.args) > 0 + and isinstance(pool_node.args[0], Node) + ): + return pool_node.args[0] + + return None + def get_input_act_qspec( self, node: Node | None = None, input_node: Node | None = None ) -> QuantizationSpecBase | None: @@ -117,6 +153,10 @@ def get_output_act_qspec( if isinstance(input_node, Node): return SharedQuantizationSpec((input_node, node)) return super().get_output_act_qspec() + if node is not None and node.target in POOL_FUSED_ACTIVATION_TARGETS: + shared_pool_input = self._get_shared_pool_input(node) + if shared_pool_input is not None: + return SharedQuantizationSpec(shared_pool_input) return super().get_output_act_qspec() def get_weight_qspec(self, node: Node | None = None) -> QuantizationSpecBase | None: diff --git a/backends/cortex_m/quantizer/quantizer_support.py b/backends/cortex_m/quantizer/quantizer_support.py index 2cf0483f74b..3dfbb67638a 100644 --- a/backends/cortex_m/quantizer/quantizer_support.py +++ b/backends/cortex_m/quantizer/quantizer_support.py @@ -122,7 +122,31 @@ POOL_OP_PATTERNS = { (torch.ops.aten.avg_pool2d.default,): CortexMAvgPool2DCheck, (torch.ops.aten.max_pool2d.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, (torch.ops.aten.max_pool2d_with_indices.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, } BMM_OP_PATTERNS = { diff --git a/backends/cortex_m/test/build_test_runner.sh 
b/backends/cortex_m/test/build_test_runner.sh index 6ac9aa55e73..2505f83c9da 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -21,6 +21,7 @@ build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_cor select_ops_list="\ aten::add.out,\ +aten::clamp.out,\ aten::mul.out,\ aten::convolution.out,\ dim_order_ops::_clone_dim_order.out,\ diff --git a/backends/cortex_m/test/misc/test_portable_int8.py b/backends/cortex_m/test/misc/test_portable_int8.py index 82b719230eb..4e3b5f41561 100644 --- a/backends/cortex_m/test/misc/test_portable_int8.py +++ b/backends/cortex_m/test/misc/test_portable_int8.py @@ -662,12 +662,6 @@ def _quantize_and_export( xfails: dict[str, xfail_type] = { "contiguous": "MLETORCH-1863: Contiguos no-op is removed in to-edge, leading to unnecessary Q-DQ-Q-DQ chain.", - "clamp": "MLETORCH-1864: Support non-fused clamp-type activations.", - "clamp_tensor": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh_": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu_": "MLETORCH-1864: Support non-fused clamp-type activations.", "eq_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ne_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ge_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", diff --git a/backends/cortex_m/test/models/test_nn_modules.py b/backends/cortex_m/test/models/test_nn_modules.py index 4a92fd578ff..303b481d4bc 100644 --- a/backends/cortex_m/test/models/test_nn_modules.py +++ b/backends/cortex_m/test/models/test_nn_modules.py @@ -1,6 +1,6 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py index 8886a05a84b..0934386d67c 100644 --- a/backends/cortex_m/test/ops/test_activation.py +++ b/backends/cortex_m/test/ops/test_activation.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
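The `standalone_relu`/`standalone_clamp` cases added in the next hunk exercise `QuantizedClampActivationPass` from earlier in this diff, which rewrites a leftover activation into `aten.clamp` with bounds quantized into the int8 domain. The rewrite leans on the identity sketched below, using assumed qparams; it is exact when the activation's input and output share the same scale and zero-point:

```python
import torch

# Assumed per-tensor qparams shared by the activation's input and output.
scale, zp = 0.1, 0
x_int8 = torch.tensor([-50, -5, 0, 7, 90], dtype=torch.int8)

# Reference: dequantize, then apply relu in float.
ref = torch.clamp((x_int8.int() - zp).float() * scale, min=0.0)

# Rewritten graph: clamp the int8 values at the quantized bound (relu's 0.0
# quantizes to the zero_point), then dequantize.
rewritten = (torch.clamp(x_int8, min=zp).int() - zp).float() * scale

assert torch.equal(ref, rewritten)
```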
@@ -398,6 +398,154 @@ def forward(self, x): return torch.clamp(self.linear(x), min=None, max=6.0) +class CortexMStandaloneReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.relu(x) + + +class CortexMStandaloneHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.nn.functional.hardtanh(x, -1.0, 1.0) + + +class CortexMStandaloneClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.clamp(x, -1.0, 1.0) + + +class CortexMStandaloneClampTensor(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_full_like_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.ops.aten.clamp.Tensor( + x, torch.full_like(x, -1.0), torch.full_like(x, 1.0) + ) + + +class CortexMMaxPool2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = 
["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.pool(x)) + + +class CortexMMaxPool2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.nn.functional.hardtanh(self.pool(x), self.min_val, self.max_val) + + +class CortexMMaxPool2DClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_clamp_default"] + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.clamp(self.pool(x), self.min_val, self.max_val) + + test_cases = { # Linear + activation tests with various data ranges "linear_relu_small_range": McuTestCase( @@ -509,6 +657,40 @@ def forward(self, x): model=CortexMLinearClamp(in_features=4, out_features=3), example_inputs=(ramp_tensor(-10, 10, (1, 4)),), ), + "standalone_relu": McuTestCase( + model=CortexMStandaloneReLU(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_hardtanh": McuTestCase( + model=CortexMStandaloneHardtanh(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp": McuTestCase( + model=CortexMStandaloneClamp(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp_tensor": McuTestCase( + model=CortexMStandaloneClampTensor(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "maxpool_relu": McuTestCase( + model=CortexMMaxPool2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_hardtanh": McuTestCase( + model=CortexMMaxPool2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_clamp": McuTestCase( + 
model=CortexMMaxPool2DClamp(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), } @@ -520,6 +702,8 @@ def test_dialect_activation(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) @parametrize("test_case", test_cases) diff --git a/backends/cortex_m/test/ops/test_conv_transpose.py b/backends/cortex_m/test/ops/test_conv_transpose.py index 7a91c5e1b6b..8202e3dc999 100644 --- a/backends/cortex_m/test/ops/test_conv_transpose.py +++ b/backends/cortex_m/test/ops/test_conv_transpose.py @@ -60,6 +60,61 @@ def forward(self, x): return self.conv_transpose(x) +class CortexMConvTranspose2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv_transpose(x)) + + +class CortexMConvTranspose2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + + def forward(self, x): + return torch.nn.functional.hardtanh(self.conv_transpose(x), -0.5, 0.5) + + # Test cases covering various configurations test_cases = { # Basic test case @@ -123,6 +178,18 @@ def forward(self, x): ramp_tensor(0, 50, (1, 5, 4, 4)).to(memory_format=torch.channels_last), ), ), + "conv_transpose2d_relu": McuTestCase( + model=CortexMConvTranspose2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), + "conv_transpose2d_hardtanh": McuTestCase( + model=CortexMConvTranspose2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), # Dilation variation 
"conv_transpose2d_dilation_2": McuTestCase( model=CortexMConvTranspose2D(2, 4, kernel_size=3, dilation=2), @@ -244,12 +311,14 @@ def test_dialect_conv_transpose2d(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) -# Implementation xfails: empty because unsupported configurations are now -# rejected at AOT time by the quantizer filter, so they fall back to portable -# ops and work correctly. Only xfails_dialect needs to track these. -xfails_implementation: dict[str, xfail_type] = {} +xfails_implementation: dict[str, xfail_type] = { + "conv_transpose2d_relu": "Fused transpose-conv + relu lowers correctly but current implementation is numerically incorrect.", + "conv_transpose2d_hardtanh": "Fused transpose-conv + hardtanh lowers correctly but current implementation is numerically incorrect.", +} @parametrize("test_case", test_cases, xfails=xfails_implementation) diff --git a/backends/cuda/runtime/shims/sort.cu b/backends/cuda/runtime/shims/sort.cu index 804b5a55959..8d4a9771e62 100644 --- a/backends/cuda/runtime/shims/sort.cu +++ b/backends/cuda/runtime/shims/sort.cu @@ -24,8 +24,8 @@ namespace executorch::backends::cuda { namespace c10_slim = executorch::backends::aoti::slim::c10; -// PyTorch ScalarType::Half = 5, not defined in slim ScalarType enum. -constexpr auto kHalf = static_cast(5); +// PyTorch ScalarType::Half = 5, now defined in slim ScalarType enum. +using c10_slim::kHalf; namespace { @@ -188,7 +188,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( case c10_slim::ScalarType::BFloat16: elem_size = sizeof(__nv_bfloat16); break; - case kHalf: + case c10_slim::ScalarType::Half: elem_size = sizeof(__half); break; default: @@ -387,7 +387,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( stream); break; } - case kHalf: { + case c10_slim::ScalarType::Half: { sort_slice_impl( static_cast<__half*>(values_base) + offset, idx_ptr, diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 352d7af5a14..103bdeb6b82 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -319,6 +319,15 @@ Error defineTensor( ET_CHECK_OR_RETURN_ERROR( tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null"); + // Validate that tensor_value->flags() is a subset of the allowed flags. + constexpr uint32_t kAllowedFlagsMask = + XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT; + ET_CHECK_OR_RETURN_ERROR( + (tensor_value->flags() & ~kAllowedFlagsMask) == 0, + InvalidProgram, + "Tensor value has unsupported flag bits 0x%x", + tensor_value->flags()); + // Get tensor dims, here we need to use a vector in order to properly // convert the uint32_t* to size_t*. Scalar tensors (rank 0) are permitted // to have a null dims vector; in that case dims_data is empty. 
diff --git a/docs/source/llm/run-with-c-plus-plus.md b/docs/source/llm/run-with-c-plus-plus.md index 217afad847b..b6c6082c3a6 100644 --- a/docs/source/llm/run-with-c-plus-plus.md +++ b/docs/source/llm/run-with-c-plus-plus.md @@ -183,13 +183,13 @@ struct GenerationConfig { int32_t num_eos = 0; // Number of EOS tokens to add // Helper method to resolve the actual max_new_tokens based on constraints - int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const; + int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const; }; ``` The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on: - The model's maximum context length -- The number of tokens in the prompt +- The number of token positions already occupied in the context window - The user-specified maximum sequence length and maximum new tokens ### Implementation Components diff --git a/examples/apple/coreml/llama/run_static_llm_multifunction.py b/examples/apple/coreml/llama/run_static_llm_multifunction.py index 517c54435f4..98d0cb0a763 100644 --- a/examples/apple/coreml/llama/run_static_llm_multifunction.py +++ b/examples/apple/coreml/llama/run_static_llm_multifunction.py @@ -22,14 +22,16 @@ import argparse import json import time -from typing import Any, Dict, List, Tuple +from typing import List import torch -import torch.utils._pytree as pytree +from executorch.examples.apple.coreml.llama.utils import ( + create_pte_wrapper, + setup_multifunction_managers, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.runner.generation import next_token -from executorch.examples.models.llama.static_attention import StaticAttentionIOManager from executorch.runtime import Runtime from pytorch_tokenizers import get_tokenizer @@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]: return [tokenizer.eos_id] -def create_pte_wrapper( - decode_method, - prefill_method, - mgr: "StaticAttentionIOManager", - prefill_seq_len: int, - prefill_mask: Dict[str, torch.Tensor], -): - """ - Create a wrapper function that adapts PTE execution to the interface - expected by StaticAttentionIOManager. - - This multifunction version selects between prefill and decode methods - based on the input sequence length. Both methods use the SAME cache_len, - so the cache buffer is shared directly without any slicing or copying. 
- - The wrapper: - - Takes (tokens, options_dict) like the eager model - - Selects prefill or decode method based on token count - - Uses the same cache buffer for both methods (no slicing needed) - - Flattens inputs using pytree - - Executes the appropriate PTE method - - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) - - Args: - decode_method: The PTE method for decode (seqlen=1) - prefill_method: The PTE method for prefill (seqlen=input_len) - mgr: StaticAttentionIOManager with caches sized for shared cache_len - prefill_seq_len: The sequence length for prefill - prefill_mask: Pre-computed mask tensor for prefill method - """ - - k_cache_keys = list(mgr.k_caches.keys()) - v_cache_keys = list(mgr.v_caches.keys()) - - timing_stats = { - "flatten_time": 0.0, - "execute_time": 0.0, - "reconstruct_time": 0.0, - "detection_time": 0.0, - "options_build_time": 0.0, - "call_count": 0, - } - - def wrapper( - tokens: torch.Tensor, options: Dict[str, Any] - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - import time as time_module - - timing_stats["call_count"] += 1 - - t0 = time_module.perf_counter() - - # Detect actual sequence length. - # StaticAttentionIOManager._run_once pads tokens with zeros on the right. - # For decode (1 actual token), positions 1+ are all zeros. - padded_seq_len = tokens.shape[1] - if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): - actual_seq_len = 1 - else: - actual_seq_len = padded_seq_len - - is_prefill = actual_seq_len == prefill_seq_len - - t1 = time_module.perf_counter() - timing_stats["detection_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - # Get the input cache state from options - in_k_caches, in_v_caches = options["in_cache_state"] - - # Both prefill and decode use the same cache_len, so no slicing needed! - # Just select the appropriate method and mask. - if is_prefill: - method = prefill_method - adapted_mask = prefill_mask - else: - method = decode_method - adapted_mask = mgr.masks - - adapted_options = { - "masks": adapted_mask, - "freqs_cos_override": options["freqs_cos_override"], - "freqs_sin_override": options["freqs_sin_override"], - "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! 
- } - - if "last_valid_token_pos" in options: - adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] - - inputs = (tokens, adapted_options) - - t1 = time_module.perf_counter() - timing_stats["options_build_time"] += t1 - t0 - - t0 = time_module.perf_counter() - flat_inputs, _ = pytree.tree_flatten(inputs) - t1 = time_module.perf_counter() - timing_stats["flatten_time"] += t1 - t0 - - t0 = time_module.perf_counter() - outputs = method.execute(flat_inputs) - t1 = time_module.perf_counter() - timing_stats["execute_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - logits = outputs[0] - - num_layers = len(k_cache_keys) - k_updates = outputs[1 : 1 + num_layers] - v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] - - k_cache_dict = dict(zip(k_cache_keys, k_updates)) - v_cache_dict = dict(zip(v_cache_keys, v_updates)) - - attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} - - t1 = time_module.perf_counter() - timing_stats["reconstruct_time"] += t1 - t0 - - return logits, attn_updates - - def print_timing_stats(): - n = timing_stats["call_count"] - if n > 0: - print(f"\n=== Wrapper Timing Stats ({n} calls) ===") - print( - f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" - ) - print( - f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" - ) - print( - f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" - ) - print( - f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" - ) - print( - f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" - ) - total = ( - timing_stats["detection_time"] - + timing_stats["options_build_time"] - + timing_stats["flatten_time"] - + timing_stats["execute_time"] - + timing_stats["reconstruct_time"] - ) - print( - f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" - ) - print( - f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" - ) - expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) - print(f" Expected tok/s from execute alone: {expected_tps:.1f}") - - wrapper.print_timing_stats = print_timing_stats - wrapper.timing_stats = timing_stats - - return wrapper - - def main(): parser = argparse.ArgumentParser( description="Run multifunction static attention Llama model" @@ -326,36 +164,16 @@ def main(): print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}") print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}") - # Create decode manager (input_len=1) - used for decode phase - mgr = StaticAttentionIOManager( - model_args, - input_len=decode_input_len, - cache_lens=shared_cache_len, - batch_size=1, - dtype=torch.float16, - style="smart_mask", - mask_val=float("-inf"), - ) - - # Create prefill manager (input_len=64) with the SAME cache_len. - # Since both use the same cache_len, we can share the cache buffer directly. - prefill_mgr = StaticAttentionIOManager( + # Create managers with shared cache buffers + mgr, prefill_mgr, prefill_mask = setup_multifunction_managers( model_args, - input_len=prefill_input_len, - cache_lens=shared_cache_len, # Same cache_len as decode! 
- batch_size=1, + prefill_input_len, + decode_input_len, + shared_cache_len, dtype=torch.float16, - style="smart_mask", mask_val=float("-inf"), ) - # Share cache buffers: point prefill_mgr's caches to mgr's caches. - # No copying needed since both managers use the same cache_len! - prefill_mgr.k_caches = mgr.k_caches - prefill_mgr.v_caches = mgr.v_caches - - prefill_mask = prefill_mgr.masks - # Load PTE model with multifunction support print(f"Loading multifunction model from {args.model}...") runtime = Runtime.get() diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..755a654b9df 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -4,7 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import time +from typing import Any, Dict, Tuple, TYPE_CHECKING + import torch +import torch.utils._pytree as pytree + +if TYPE_CHECKING: + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) class SplitLinearModule(torch.nn.Module): @@ -114,3 +123,212 @@ def replace_linear_with_split_linear( in_target_split_size, in_max_splits, ) + + +def setup_multifunction_managers( + config, + prefill_input_len: int, + decode_input_len: int, + shared_cache_len: int, + dtype: torch.dtype = torch.float16, + mask_val: float = float("-inf"), + style: str = "smart_mask", +): + """ + Create prefill and decode StaticAttentionIOManager instances with shared cache buffers. + + Both managers use the same cache_len so they share cache memory directly. + Returns (decode_mgr, prefill_mgr, prefill_mask). + """ + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) + + mgr = StaticAttentionIOManager( + config, + input_len=decode_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + prefill_mgr = StaticAttentionIOManager( + config, + input_len=prefill_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + # Share cache buffers — no copying needed + prefill_mgr.k_caches = mgr.k_caches + prefill_mgr.v_caches = mgr.v_caches + prefill_mask = prefill_mgr.masks + + return mgr, prefill_mgr, prefill_mask + + +def create_pte_wrapper( + decode_method, + prefill_method, + mgr: "StaticAttentionIOManager", + prefill_seq_len: int, + prefill_mask: Dict[str, torch.Tensor], +): + """ + Create a wrapper function that adapts PTE execution to the interface + expected by StaticAttentionIOManager. + + This multifunction version selects between prefill and decode methods + based on the input sequence length. Both methods use the SAME cache_len, + so the cache buffer is shared directly without any slicing or copying. 
+ + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Selects prefill or decode method based on token count + - Uses the same cache buffer for both methods (no slicing needed) + - Flattens inputs using pytree + - Executes the appropriate PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + + Args: + decode_method: The PTE method for decode (seqlen=1) + prefill_method: The PTE method for prefill (seqlen=input_len) + mgr: StaticAttentionIOManager with caches sized for shared cache_len + prefill_seq_len: The sequence length for prefill + prefill_mask: Pre-computed mask tensor for prefill method + """ + + k_cache_keys = list(mgr.k_caches.keys()) + v_cache_keys = list(mgr.v_caches.keys()) + + timing_stats = { + "flatten_time": 0.0, + "execute_time": 0.0, + "reconstruct_time": 0.0, + "detection_time": 0.0, + "options_build_time": 0.0, + "call_count": 0, + } + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + timing_stats["call_count"] += 1 + + t0 = time.perf_counter() + + # Detect actual sequence length. + # StaticAttentionIOManager._run_once pads tokens with zeros on the right. + # For decode (1 actual token), positions 1+ are all zeros. + padded_seq_len = tokens.shape[1] + if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): + actual_seq_len = 1 + else: + actual_seq_len = padded_seq_len + + is_prefill = actual_seq_len == prefill_seq_len + + t1 = time.perf_counter() + timing_stats["detection_time"] += t1 - t0 + + t0 = time.perf_counter() + + # Get the input cache state from options + in_k_caches, in_v_caches = options["in_cache_state"] + + # Both prefill and decode use the same cache_len, so no slicing needed! + # Just select the appropriate method and mask. + if is_prefill: + method = prefill_method + adapted_mask = prefill_mask + else: + method = decode_method + adapted_mask = mgr.masks + + adapted_options = { + "masks": adapted_mask, + "freqs_cos_override": options["freqs_cos_override"], + "freqs_sin_override": options["freqs_sin_override"], + "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! 
+ } + + if "last_valid_token_pos" in options: + adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] + + inputs = (tokens, adapted_options) + + t1 = time.perf_counter() + timing_stats["options_build_time"] += t1 - t0 + + t0 = time.perf_counter() + flat_inputs, _ = pytree.tree_flatten(inputs) + t1 = time.perf_counter() + timing_stats["flatten_time"] += t1 - t0 + + t0 = time.perf_counter() + outputs = method.execute(flat_inputs) + t1 = time.perf_counter() + timing_stats["execute_time"] += t1 - t0 + + t0 = time.perf_counter() + + logits = outputs[0] + + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + t1 = time.perf_counter() + timing_stats["reconstruct_time"] += t1 - t0 + + return logits, attn_updates + + def print_timing_stats(): + n = timing_stats["call_count"] + if n > 0: + print(f"\n=== Wrapper Timing Stats ({n} calls) ===") + print( + f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" + ) + print( + f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" + ) + print( + f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" + ) + print( + f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" + ) + print( + f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" + ) + total = ( + timing_stats["detection_time"] + + timing_stats["options_build_time"] + + timing_stats["flatten_time"] + + timing_stats["execute_time"] + + timing_stats["reconstruct_time"] + ) + print( + f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" + ) + print( + f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" + ) + expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) + print(f" Expected tok/s from execute alone: {expected_tps:.1f}") + + wrapper.print_timing_stats = print_timing_stats + wrapper.timing_stats = timing_stats + + return wrapper diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index d6dff173072..7556ef60e19 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, TypedDict @@ -52,6 +53,8 @@ def forward( ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} +_RECURRENT_GATED_DELTA_RULE_OP = None +_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False def register_attention(name: str): @@ -64,6 +67,38 @@ def decorator(cls: Type[Attention]): return decorator +def _get_recurrent_gated_delta_rule_op(): + global _RECURRENT_GATED_DELTA_RULE_OP + global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + + if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP: + return _RECURRENT_GATED_DELTA_RULE_OP + + _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + return _RECURRENT_GATED_DELTA_RULE_OP + except (AttributeError, RuntimeError): + pass + + try: + from 
executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + except (ImportError, OSError, RuntimeError): + logging.debug("Failed to import custom ops library", exc_info=True) + return None + + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + except (AttributeError, RuntimeError): + _RECURRENT_GATED_DELTA_RULE_OP = None + + return _RECURRENT_GATED_DELTA_RULE_OP + + class KVCache(nn.Module): def __init__( self, @@ -725,7 +760,7 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor: out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype) return out.transpose(1, 2).contiguous() - def _recurrent_gated_delta_rule( + def _gated_delta_rule_op( self, query: torch.Tensor, key: torch.Tensor, @@ -733,20 +768,35 @@ def _recurrent_gated_delta_rule( g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: - # query/key/value: (batch, seq_len, num_heads, head_dim) - # g/beta: (batch, seq_len, num_heads) - initial_dtype = query.dtype - query = _l2norm(query, dim=-1, eps=1e-6) - key = _l2norm(key, dim=-1, eps=1e-6) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] + batch_size = query.shape[0] + recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op() + if recurrent_gated_delta_rule_op is not None: + return recurrent_gated_delta_rule_op( + query, + key, + value, + g, + beta, + self.recurrent_state[:batch_size], + ) + return self._naive_gated_delta_rule_op( + query, + key, + value, + g, + beta, + ) - batch_size, num_heads, sequence_length, k_head_dim = key.shape + def _naive_gated_delta_rule_op( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + batch_size, num_heads, sequence_length, _ = key.shape v_head_dim = value.shape[-1] - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale core_attn_out = torch.zeros( batch_size, @@ -780,6 +830,36 @@ def _recurrent_gated_delta_rule( last_recurrent_state.to(self.recurrent_state.dtype) ) + return core_attn_out + + def _recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # query/key/value: (batch, seq_len, num_heads, head_dim) + # g/beta: (batch, seq_len, num_heads) + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = self._gated_delta_rule_op( + query, + key, + value, + g, + beta, + ) return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) def forward( diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0f38191d807..9cf1b4b4bf0 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -123,7 +123,7 @@ "qwen2_5_1_5b": "Qwen/Qwen2.5-1.5B", "qwen2_5_coder_32b": "Qwen/Qwen2.5-Coder-32B-Instruct", "phi_4_mini": "microsoft/Phi-4-mini-instruct", - "smollm2": "HuggingFaceTB/SmolLM-135M", + "smollm2": "HuggingFaceTB/SmolLM2-135M", "qwen3_0_6b": "Qwen/Qwen3-0.6B", "qwen3_1_7b": "Qwen/Qwen3-1.7B", "qwen3_4b": "Qwen/Qwen3-4B", diff --git a/examples/models/llama/tests/test_export_llama_lib.py 
b/examples/models/llama/tests/test_export_llama_lib.py index 130a55f658c..c96fea8c215 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json +import tempfile import unittest +from pathlib import Path from executorch.devtools.backend_debug import get_delegation_info @@ -25,6 +28,7 @@ from executorch.examples.models.llama.export_llama_lib import ( _export_llama, + _prepare_for_llama_export, build_args_parser, get_quantizer_and_quant_params, ) @@ -37,6 +41,39 @@ class ExportLlamaLibTest(unittest.TestCase): + def _make_tiny_qwen35_params(self) -> dict: + return { + "dim": 64, + "hidden_dim": 128, + "n_heads": 4, + "head_dim": 16, + "n_kv_heads": 2, + "n_layers": 4, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": False, + "vocab_size": 256, + "use_hf_rope": True, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": False, + "use_qk_norm": True, + "qk_norm_before_rope": True, + "attention_type": "mha", + "use_q_gate": True, + "rms_norm_add_unit_offset": True, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 4, + "linear_num_value_heads": 4, + "layer_types": [ + "linear_attention", + "full_attention", + "linear_attention", + "full_attention", + ], + } + def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops. @@ -66,6 +103,41 @@ def test_has_expected_ops_and_op_counts(self): for op, _op_info in delegation_info.delegation_by_operator.items(): self.assertTrue(op not in UNWANTED_OPS) + def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self): + with tempfile.TemporaryDirectory() as temp_dir: + params_path = Path(temp_dir) / "tiny_qwen35.json" + params_path.write_text(json.dumps(self._make_tiny_qwen35_params())) + + parser = build_args_parser() + args = parser.parse_args( + [ + "--model", + "qwen3_5_0_8b", + "--params", + str(params_path), + "--use_kv_cache", + "--disable_dynamic_shape", + "--max_seq_length", + "8", + "--max_context_length", + "8", + ] + ) + + llm_config = LlmConfig.from_args(args) + builder = _prepare_for_llama_export(llm_config).export() + assert builder.pre_autograd_graph_module is not None + + recurrent_nodes = [ + node + for node in builder.pre_autograd_graph_module.graph.nodes + if "auto_functionalized_v2" in str(node.target) + and node.args + and "llama.recurrent_gated_delta_rule" in str(node.args[0]) + ] + + self.assertEqual(len(recurrent_nodes), 2) + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self): llm_config = LlmConfig() diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 5a9f67d57cf..ba96a96aa43 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -6,7 +6,9 @@ import unittest +import executorch.examples.models.llama.attention as attention_module import torch + from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm @@ -123,6 +125,109 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self): 
torch.allclose(state_after_first, state_after_second, atol=1e-5) ) + def test_gated_deltanet_chunked_prefill_matches_full_sequence(self): + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked.load_state_dict(attn_full.state_dict()) + + x = torch.randn(1, 5, args.dim) + dummy_freq = torch.zeros(1, 1) + + full_output, _ = attn_full( + x, + dummy_freq, + dummy_freq, + input_pos=torch.tensor([0], dtype=torch.long), + ) + + chunk_outputs = [] + for start, end in ((0, 3), (3, 4), (4, 5)): + output, _ = attn_chunked( + x[:, start:end], + dummy_freq, + dummy_freq, + input_pos=torch.tensor([start], dtype=torch.long), + ) + chunk_outputs.append(output) + + chunked_output = torch.cat(chunk_outputs, dim=1) + + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5 + ) + ) + self.assertTrue( + torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5) + ) + + def test_gated_deltanet_custom_op_matches_fallback(self): + recurrent_op = attention_module._get_recurrent_gated_delta_rule_op() + if recurrent_op is None: + self.skipTest("llama::recurrent_gated_delta_rule is not available") + + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback.load_state_dict(attn_custom.state_dict()) + + query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim) + g = torch.randn(1, 3, attn_custom.num_v_heads) + beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads)) + + original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP + original_tried_loading = ( + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + ) + try: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + custom_output = attn_custom._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + + attention_module._RECURRENT_GATED_DELTA_RULE_OP = None + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + fallback_output = attn_fallback._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + finally: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = ( + original_tried_loading + ) + + self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5 + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index bae4cfc183c..00c91a685e1 100644 --- 
a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -58,11 +59,13 @@ int main(int argc, char** argv) { llm::Stats stats; +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory before load - size_t gpu_free_bytes, gpu_total_bytes; + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_total_bytes = gpu_total_bytes; stats.gpu_free_before_load_bytes = gpu_free_bytes; +#endif stats.model_load_start_ms = llm::time_in_ms(); @@ -127,9 +130,11 @@ int main(int argc, char** argv) { stats.model_load_end_ms = llm::time_in_ms(); +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory after load cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_free_after_load_bytes = gpu_free_bytes; +#endif // Get EOS ids auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); @@ -155,7 +160,7 @@ int main(int argc, char** argv) { } auto prompt_tokens = std::move(*encode_result); int64_t num_prompt_tokens = prompt_tokens.size(); - printf("Prompt tokens: %ld\n", num_prompt_tokens); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; stats.inference_start_ms = llm::time_in_ms(); @@ -209,7 +214,7 @@ int main(int argc, char** argv) { double prefill_ms = (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); printf( - "Prefill: %ld tokens in %.1f ms (%.1f tok/s)\n", + "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", num_prompt_tokens, prefill_ms, num_prompt_tokens * 1000.0 / prefill_ms); @@ -290,17 +295,19 @@ int main(int argc, char** argv) { double decode_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); printf( - "Decode: %ld tokens in %.1f ms (%.1f tok/s)\n", + "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", num_generated, decode_ms, num_generated * 1000.0 / decode_ms); - printf("Prompt tokens: %ld\n", num_prompt_tokens); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory after generation cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_free_after_generate_bytes = gpu_free_bytes; stats.gpu_peak_usage_mb = (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; +#endif llm::print_report(stats); diff --git a/examples/models/smollm2/135M_config.json b/examples/models/smollm2/135M_config.json index 604c7e94ab5..1e3bc8ee0cb 100644 --- a/examples/models/smollm2/135M_config.json +++ b/examples/models/smollm2/135M_config.json @@ -6,7 +6,7 @@ "n_kv_heads": 3, "n_layers": 30, "norm_eps": 1e-05, - "rope_theta": 10000.0, + "rope_theta": 100000.0, "use_scaled_rope": false, "vocab_size": 49152, "use_hf_rope": false, diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index ba91f444287..eb2b6f096a1 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -40,48 +39,49 @@ 
class ModuleInstrumentationTest { inputStream.close() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) + try { + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } @Test @Throws(IOException::class, URISyntaxException::class) fun testMethodMetadata() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + module.loadMethod(FORWARD_METHOD) - module.loadMethod(FORWARD_METHOD) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) + try { + val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -94,15 +94,18 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) + try { + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -135,9 +138,6 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +151,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -168,6 +168,7 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + module.destroy() } 
companion object { @@ -176,5 +177,8 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" + private val inputShape = longArrayOf(1, 3, 224, 224) + + private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e0fda73cc06..e72ed9e3d28 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -161,6 +161,11 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } + public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { + super(ErrorHelper.formatMessage(errorCode, details), cause); + this.errorCode = errorCode; + } + /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 2c0117dc576..94c0efff335 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -594,21 +595,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() called but runner_ is null. " - "The model runner was not created or failed to initialize due to a " - "previous configuration or initialization error. " - "Model type category: %d.", - model_type_category_); + std::stringstream ss; + ss << "Model runner was not created. model_type_category=" + << model_type_category_ + << ". 
Valid values: " << MODEL_TYPE_CATEGORY_LLM << " (LLM), " + << MODEL_TYPE_CATEGORY_MULTIMODAL << " (Multimodal)"; + executorch::jni_helper::throwExecutorchException( + static_cast(Error::InvalidState), ss.str().c_str()); return static_cast(Error::InvalidState); } const auto load_result = static_cast(runner_->load()); if (load_result != static_cast(Error::Ok)) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", - static_cast(load_result)); + executorch::jni_helper::throwExecutorchException( + static_cast(load_result), "Failed to load model runner"); } return load_result; } diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 8e1c2bf0143..67e7344330e 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -15,7 +15,8 @@ #pragma once #include #include -#if __cplusplus < 201703L +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) #error "This header requires C++17" #endif #include diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 9aacded4b4c..e0b009d7a13 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -11,7 +11,9 @@ # pyre-unsafe import logging +import os +from pathlib import Path from typing import Tuple import torch @@ -21,33 +23,84 @@ from torch.library import impl aten = torch.ops.aten +_CUSTOM_OPS_DLL_DIR_HANDLES = [] -try: - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None -except: - # This is needed to ensure that custom ops are registered - from executorch.extension.pybindings import portable_lib # noqa # usort: skip - # Ideally package is installed in only one location but usage of - # PYATHONPATH can result in multiple locations. - # ATM this is mainly used in CI for qnn runner. 
Will need to revisit this - from pathlib import Path +def _is_custom_ops_registered() -> bool: + try: + torch.ops.llama.sdpa_with_kv_cache.default + torch.ops.llama.fast_hadamard_transform.default + return True + except (AttributeError, RuntimeError): + return False + + +def _get_custom_ops_library_override() -> Path | None: + override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB") + if override is None: + return None + + lib_path = Path(override).expanduser().resolve() + if not lib_path.is_file(): + raise FileNotFoundError( + "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing " + f"custom_ops_aot_lib, but got {lib_path}" + ) + return lib_path + + +def _find_custom_ops_library() -> Path: + override = _get_custom_ops_library_override() + if override is not None: + return override package_path = Path(__file__).parent.resolve() - logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}") + candidates = [] + patterns = ( + "**/custom_ops_aot_lib.dll", + "**/libcustom_ops_aot_lib.so", + "**/libcustom_ops_aot_lib.dylib", + ) + + for pattern in patterns: + candidates.extend(package_path.glob(pattern)) + + libs = sorted({path.resolve() for path in candidates if path.is_file()}) + if not libs: + raise FileNotFoundError( + f"Could not find custom_ops_aot_lib under {package_path}" + ) + return max(libs, key=lambda path: path.stat().st_mtime) + + +def _load_custom_ops_library() -> None: + try: + # This is needed to ensure that custom ops are registered when + # portable_lib is available in the current environment. + from executorch.extension.pybindings import portable_lib # noqa # usort: skip + except ImportError: + portable_lib = None + + lib_path = _find_custom_ops_library() + logging.info(f"Loading custom ops library: {lib_path}") + + if os.name == "nt": + _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(lib_path.parent))) + torch_lib_dir = Path(torch.__file__).resolve().parent / "lib" + if torch_lib_dir.is_dir(): + _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(torch_lib_dir))) - libs = list(package_path.glob("**/*custom_ops_aot_lib.*")) + torch.ops.load_library(lib_path) - assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" - logging.info(f"Loading custom ops library: {libs[0]}") - torch.ops.load_library(libs[0]) - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None + # Keep the import alive to avoid lint complaints in environments where + # portable_lib is needed for symbol resolution. + _ = portable_lib + + +if not _is_custom_ops_registered(): + _load_custom_ops_library() + if not _is_custom_ops_registered(): + raise RuntimeError("Failed to register ExecuTorch custom ops library") custom_ops_lib = torch.library.Library("llama", "IMPL") @@ -271,6 +324,87 @@ def update_cache_with_indices_meta( return torch.empty((1,), dtype=value.dtype, device="meta") +def _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, +): + assert ( + query.dim() == 4 + ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions." + assert ( + key.dim() == 4 + ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions." + assert ( + value.dim() == 4 + ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions." + assert g.dim() == 3, f"Expected g to be 3 dimensional but got {g.dim()} dimensions." 
+ assert ( + beta.dim() == 3 + ), f"Expected beta to be 3 dimensional but got {beta.dim()} dimensions." + assert ( + recurrent_state.dim() == 4 + ), f"Expected recurrent_state to be 4 dimensional but got {recurrent_state.dim()} dimensions." + + for name, tensor in { + "query": query, + "key": key, + "value": value, + "g": g, + "beta": beta, + "recurrent_state": recurrent_state, + }.items(): + assert ( + tensor.dtype == torch.float32 + ), f"Expected {name} to be float32 but got {tensor.dtype}" + + assert ( + query.shape == key.shape + ), f"Expected query and key to have matching shapes but got {query.shape} and {key.shape}" + assert ( + query.shape[:3] == value.shape[:3] + ), f"Expected query and value to match in batch/head/sequence dims but got {query.shape} and {value.shape}" + assert ( + g.shape == query.shape[:3] + ), f"Expected g to match query batch/head/sequence dims but got {g.shape} and {query.shape}" + assert ( + beta.shape == query.shape[:3] + ), f"Expected beta to match query batch/head/sequence dims but got {beta.shape} and {query.shape}" + assert recurrent_state.shape == ( + query.size(0), + query.size(1), + query.size(3), + value.size(3), + ), ( + "Expected recurrent_state to have shape " + f"{(query.size(0), query.size(1), query.size(3), value.size(3))} " + f"but got {recurrent_state.shape}" + ) + + +@impl(custom_ops_lib, "recurrent_gated_delta_rule", "Meta") +def recurrent_gated_delta_rule_meta( + query, + key, + value, + g, + beta, + recurrent_state, +): + _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, + ) + return torch.empty_like(value) + + def _validate_quantized_sdpa_params( query, key, diff --git a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp index 146ac3cc298..d48f593868c 100644 --- a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp +++ b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp @@ -13,14 +13,40 @@ namespace torch::executor::native { namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} + Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) { executorch::aten::RuntimeContext context; return fast_hadamard_transform_out(context, vec, out); } + +at::Tensor& fast_hadamard_transform_out_aten( + const at::Tensor& vec, + at::Tensor& out) { + auto vec_et = to_et_arg(vec); + auto out_et = to_et_arg(out); + auto& et_result = + fast_hadamard_transform_out_no_context(vec_et.call(), out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) { auto out = at::empty_like(vec); - WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1) - (vec, out); + fast_hadamard_transform_out_aten(vec, out); return out; } } // namespace @@ -38,6 +64,5 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { torch::executor::native::fast_hadamard_transform_aten); m.impl( "fast_hadamard_transform.out", - WRAP_TO_ATEN( - torch::executor::native::fast_hadamard_transform_out_no_context, 1)); + torch::executor::native::fast_hadamard_transform_out_aten); } diff --git 
a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 72bddce7b5b..76ee9cb915f 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -15,6 +15,10 @@ #include // @lint-ignore CLANGTIDY facebook-unused-include-check #include +#include +#include +#include +#include #ifdef ET_USE_THREADPOOL #include @@ -178,6 +182,68 @@ bool validate_cache_params( return true; } +bool validate_recurrent_gated_delta_rule_args( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + const Tensor& recurrent_state) { + ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(g.dim() == 3, "g must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE(beta.dim() == 3, "beta must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.dim() == 4, "recurrent_state must be a 4D tensor"); + + ET_CHECK_OR_RETURN_FALSE( + query.scalar_type() == ScalarType::Float, "query must be float32"); + ET_CHECK_OR_RETURN_FALSE( + key.scalar_type() == ScalarType::Float, "key must be float32"); + ET_CHECK_OR_RETURN_FALSE( + value.scalar_type() == ScalarType::Float, "value must be float32"); + ET_CHECK_OR_RETURN_FALSE( + g.scalar_type() == ScalarType::Float, "g must be float32"); + ET_CHECK_OR_RETURN_FALSE( + beta.scalar_type() == ScalarType::Float, "beta must be float32"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.scalar_type() == ScalarType::Float, + "recurrent_state must be float32"); + + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == key.size(0) && query.size(1) == key.size(1) && + query.size(2) == key.size(2) && query.size(3) == key.size(3), + "query and key must have matching shapes"); + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == value.size(0) && query.size(1) == value.size(1) && + query.size(2) == value.size(2), + "query and value must match in batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + g.size(0) == query.size(0) && g.size(1) == query.size(1) && + g.size(2) == query.size(2), + "g must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + beta.size(0) == query.size(0) && beta.size(1) == query.size(1) && + beta.size(2) == query.size(2), + "beta must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.size(0) == query.size(0) && + recurrent_state.size(1) == query.size(1) && + recurrent_state.size(2) == query.size(3) && + recurrent_state.size(3) == value.size(3), + "recurrent_state shape must match [B, H, K, V]"); + + for (const Tensor* tensor : + {&query, &key, &value, &g, &beta, &recurrent_state}) { + ET_CHECK_OR_RETURN_FALSE( + is_contiguous_dim_order((*tensor).dim_order().data(), (*tensor).dim()), + "recurrent gated delta rule expects contiguous inputs"); + } + + return true; +} + // TODO: seq_length is not yet used for copy void update_cache( const Tensor& projected_value, @@ -610,6 +676,133 @@ Tensor& sdpa_with_kv_cache_out( return output; } + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor(output, value.sizes()) == Error::Ok, + InvalidArgument, + output, + "Failed to resize recurrent_gated_delta_rule output tensor."); + ET_KERNEL_CHECK( + 
ctx, + validate_recurrent_gated_delta_rule_args( + query, key, value, g, beta, recurrent_state), + InvalidArgument, + output); + ET_KERNEL_CHECK( + ctx, output.scalar_type() == ScalarType::Float, InvalidArgument, output); + ET_KERNEL_CHECK( + ctx, + is_contiguous_dim_order(output.dim_order().data(), output.dim()), + InvalidArgument, + output); + + const auto batch_size = query.size(0); + const auto num_heads = query.size(1); + const auto sequence_length = query.size(2); + const auto k_head_dim = query.size(3); + const auto v_head_dim = value.size(3); + + const auto q_batch_stride = num_heads * sequence_length * k_head_dim; + const auto q_head_stride = sequence_length * k_head_dim; + const auto q_seq_stride = k_head_dim; + + const auto value_batch_stride = num_heads * sequence_length * v_head_dim; + const auto value_head_stride = sequence_length * v_head_dim; + const auto value_seq_stride = v_head_dim; + + const auto gv_batch_stride = num_heads * sequence_length; + const auto gv_head_stride = sequence_length; + + const auto state_batch_stride = num_heads * k_head_dim * v_head_dim; + const auto state_head_stride = k_head_dim * v_head_dim; + + const auto* query_data = query.const_data_ptr(); + const auto* key_data = key.const_data_ptr(); + const auto* value_data = value.const_data_ptr(); + const auto* g_data = g.const_data_ptr(); + const auto* beta_data = beta.const_data_ptr(); + auto* recurrent_state_data = recurrent_state.mutable_data_ptr(); + auto* output_data = output.mutable_data_ptr(); + std::vector kv_mem(v_head_dim); + std::vector delta(v_head_dim); + + for (int64_t batch = 0; batch < batch_size; ++batch) { + for (int64_t head = 0; head < num_heads; ++head) { + const auto q_offset = batch * q_batch_stride + head * q_head_stride; + const auto value_offset = + batch * value_batch_stride + head * value_head_stride; + const auto gv_offset = batch * gv_batch_stride + head * gv_head_stride; + const auto state_offset = + batch * state_batch_stride + head * state_head_stride; + + const auto* q_head = query_data + q_offset; + const auto* k_head = key_data + q_offset; + const auto* value_head = value_data + value_offset; + const auto* g_head = g_data + gv_offset; + const auto* beta_head = beta_data + gv_offset; + auto* state_head = recurrent_state_data + state_offset; + auto* output_head = output_data + value_offset; + + for (int64_t token = 0; token < sequence_length; ++token) { + const auto* q_t = q_head + token * q_seq_stride; + const auto* k_t = k_head + token * q_seq_stride; + const auto* v_t = value_head + token * value_seq_stride; + auto* output_t = output_head + token * value_seq_stride; + + const float g_t = std::exp(g_head[token]); + const float beta_t = beta_head[token]; + + if (g_t != 1.0f) { + for (int64_t idx = 0; idx < state_head_stride; ++idx) { + state_head[idx] *= g_t; + } + } + + std::fill(kv_mem.begin(), kv_mem.end(), 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + kv_mem[v_idx] += state_row[v_idx] * key_value; + } + } + + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + delta[v_idx] = (v_t[v_idx] - kv_mem[v_idx]) * beta_t; + } + + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + state_row[v_idx] += key_value * delta[v_idx]; + } + } + + 
std::fill(output_t, output_t + v_head_dim, 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float query_value = q_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + output_t[v_idx] += state_row[v_idx] * query_value; + } + } + } + } + } + + return output; +} } // namespace native } // namespace executor } // namespace torch @@ -628,3 +821,36 @@ EXECUTORCH_LIBRARY( llama, "custom_quantized_sdpa.out", torch::executor::native::custom_quantized_sdpa_out); + +namespace { + +void recurrent_gated_delta_rule_out_boxed( + executorch::runtime::KernelRuntimeContext& ctx, + executorch::runtime::Span stack) { + ET_KERNEL_CHECK_MSG( + ctx, + stack.size() == 7, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + static_cast(7), + stack.size()); + + auto& query = stack[0]->toTensor(); + auto& key = stack[1]->toTensor(); + auto& value = stack[2]->toTensor(); + auto& g = stack[3]->toTensor(); + auto& beta = stack[4]->toTensor(); + auto& recurrent_state = stack[5]->toTensor(); + auto& output = stack[6]->toTensor(); + + (void)torch::executor::native::recurrent_gated_delta_rule_out( + ctx, query, key, value, g, beta, recurrent_state, output); +} + +const auto recurrent_gated_delta_rule_out_registration = + executorch::runtime::register_kernel(executorch::runtime::Kernel( + "llama::recurrent_gated_delta_rule.out", + recurrent_gated_delta_rule_out_boxed)); + +} // namespace diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 9d357eb6ea1..9f029f52f31 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -75,6 +75,16 @@ Tensor& custom_quantized_sdpa_out( const optional& v_scales, const bool is_seq_at_dim_1, Tensor& output); + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); } // namespace native } // namespace executor } // namespace torch diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 5bbf22d336e..d4d1122f614 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -17,6 +17,24 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -50,6 +68,20 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -77,6 
+109,17 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -118,6 +161,24 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2); +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output); + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -129,6 +190,12 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos); +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output); + // New functions for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -143,6 +210,39 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices); +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output); + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state); + +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output); + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -192,22 +292,59 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty_like(q_projected); - WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11) - (q_projected, - k_projected, - v_projected, - key_cache, - value_cache, - start_pos, - seq_len, - attn_mask, - dropout_p, - is_causal, - scale, - output); + sdpa_with_kv_cache_out_aten( + q_projected, + k_projected, + v_projected, + key_cache, + value_cache, + start_pos, + seq_len, + attn_mask, + dropout_p, + is_causal, + scale, + output); return output; } +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + 
at::Tensor& output) { + auto q_et = to_et_arg(q_projected); + auto k_et = to_et_arg(k_projected); + auto v_et = to_et_arg(v_projected); + auto key_cache_et = to_et_arg(key_cache); + auto value_cache_et = to_et_arg(value_cache); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = sdpa_with_kv_cache_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + key_cache_et.call(), + value_cache_et.call(), + start_pos, + seq_len, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -248,11 +385,40 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_sdpa_out_no_context, 8) - (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); + custom_sdpa_out_aten( + q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); return output; } +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = custom_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -314,26 +480,75 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15) - (q, - k, - v, - start_pos, - attn_mask, - dropout_p, - is_causal, - scale, - q_zero_points, - q_scales, - k_zero_points, - k_scales, - v_zero_points, - v_scales, - is_seq_at_dim_2, - output); + custom_quantized_sdpa_out_aten( + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales, + is_seq_at_dim_2, + output); return output; } +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto q_zero_points_et = to_et_arg>(q_zero_points); + auto q_scales_et = to_et_arg>(q_scales); + auto k_zero_points_et = to_et_arg>(k_zero_points); + auto k_scales_et = to_et_arg>(k_scales); + auto v_zero_points_et = to_et_arg>(v_zero_points); + auto v_scales_et = 
to_et_arg>(v_scales); + auto output_et = to_et_arg(output); + auto& et_result = custom_quantized_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + q_zero_points_et.call(), + q_scales_et.call(), + k_zero_points_et.call(), + k_scales_et.call(), + v_zero_points_et.call(), + v_scales_et.call(), + is_seq_at_dim_2, + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -349,11 +564,23 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_out_no_context, 3) - (value, cache, start_pos, output); + update_cache_out_aten(value, cache, start_pos, output); return output; } +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_out_no_context( + value_et.call(), cache_et.call(), start_pos, output_et.call()); + return copy_et_result_to_out(et_result, output); +} + // Implementations for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -372,11 +599,81 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4) - (value, cache, start_pos, indices, output); + update_cache_with_indices_out_aten(value, cache, start_pos, indices, output); return output; } +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto indices_et = to_et_arg(indices); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_with_indices_out_no_context( + value_et.call(), + cache_et.call(), + start_pos, + indices_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + executorch::aten::RuntimeContext context{}; + return torch::executor::native::recurrent_gated_delta_rule_out( + context, query, key, value, g, beta, recurrent_state, output); +} + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state) { + auto output = at::empty_like(value); + recurrent_gated_delta_rule_out_aten( + query, key, value, g, beta, recurrent_state, output); + return output; +} + +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output) { + auto query_et = to_et_arg(query); + auto key_et = to_et_arg(key); + auto value_et = to_et_arg(value); + auto g_et = to_et_arg(g); + auto beta_et = to_et_arg(beta); + auto recurrent_state_et = to_et_arg(recurrent_state); + auto output_et = to_et_arg(output); + auto& et_result = 
recurrent_gated_delta_rule_out_no_context( + query_et.call(), + key_et.call(), + value_et.call(), + g_et.call(), + beta_et.call(), + recurrent_state_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + } // namespace native } // namespace executor } // namespace torch @@ -410,6 +707,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, " "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)"); + m.def( + "recurrent_gated_delta_rule(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state) -> Tensor"); + m.def( + "recurrent_gated_delta_rule.out(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -430,29 +733,27 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten); m.impl( "sdpa_with_kv_cache.out", - WRAP_TO_ATEN( - torch::executor::native::sdpa_with_kv_cache_out_no_context, 11)); + torch::executor::native::sdpa_with_kv_cache_out_aten); m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten); - m.impl( - "custom_sdpa.out", - WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8)); + m.impl("custom_sdpa.out", torch::executor::native::custom_sdpa_out_aten); m.impl("update_cache", torch::executor::native::update_cache_aten); - m.impl( - "update_cache.out", - WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); + m.impl("update_cache.out", torch::executor::native::update_cache_out_aten); m.impl( "update_cache_with_indices", torch::executor::native::update_cache_with_indices_aten); m.impl( "update_cache_with_indices.out", - WRAP_TO_ATEN( - torch::executor::native::update_cache_with_indices_out_no_context, - 4)); + torch::executor::native::update_cache_with_indices_out_aten); + m.impl( + "recurrent_gated_delta_rule", + torch::executor::native::recurrent_gated_delta_rule_aten); + m.impl( + "recurrent_gated_delta_rule.out", + torch::executor::native::recurrent_gated_delta_rule_out_aten); m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); m.impl( "custom_quantized_sdpa.out", - WRAP_TO_ATEN( - torch::executor::native::custom_quantized_sdpa_out_no_context, 15)); + torch::executor::native::custom_quantized_sdpa_out_aten); } diff --git a/extension/llm/custom_ops/op_tile_crop_aot.cpp b/extension/llm/custom_ops/op_tile_crop_aot.cpp index 5aa98ee8d4a..7d89c462e1d 100644 --- a/extension/llm/custom_ops/op_tile_crop_aot.cpp +++ b/extension/llm/custom_ops/op_tile_crop_aot.cpp @@ -16,10 +16,30 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out); +at::Tensor& +tile_crop_out_aten(const at::Tensor& input, int64_t tile_size, at::Tensor& out); + Tensor& 
tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { executorch::aten::RuntimeContext context{}; @@ -28,12 +48,21 @@ tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size); +at::Tensor& tile_crop_out_aten( + const at::Tensor& input, + int64_t tile_size, + at::Tensor& out) { + auto input_et = to_et_arg(input); + auto out_et = to_et_arg(out); + auto& et_result = + tile_crop_out_no_context(input_et.call(), tile_size, out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) { // max_num_tiles = 4, num_channels = 3. auto output = at::empty({4, 3, tile_size, tile_size}); - - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2) - (input, tile_size, output); + tile_crop_out_aten(input, tile_size, output); return output; } @@ -49,7 +78,5 @@ TORCH_LIBRARY(preprocess, m) { TORCH_LIBRARY_IMPL(preprocess, CompositeExplicitAutograd, m) { m.impl("tile_crop", torch::executor::native::tile_crop_aten); - m.impl( - "tile_crop.out", - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2)); + m.impl("tile_crop.out", torch::executor::native::tile_crop_out_aten); } diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py index 84a349c97f0..7edd273d8b9 100644 --- a/extension/llm/custom_ops/test_update_cache.py +++ b/extension/llm/custom_ops/test_update_cache.py @@ -431,3 +431,155 @@ def test_batched_update_kv_cache_more_updates(self): self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) + + +class RecurrentGatedDeltaRuleTest(unittest.TestCase): + def _make_inputs( + self, + batch_size: int = 2, + num_heads: int = 3, + seq_len: int = 4, + k_head_dim: int = 5, + v_head_dim: int = 6, + ): + query = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + key = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + value = torch.randn(batch_size, num_heads, seq_len, v_head_dim) + g = torch.randn(batch_size, num_heads, seq_len) + beta = torch.sigmoid(torch.randn(batch_size, num_heads, seq_len)) + recurrent_state = torch.randn(batch_size, num_heads, k_head_dim, v_head_dim) + return query, key, value, g, beta, recurrent_state + + def _reference_recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + recurrent_state: torch.Tensor, + ): + state = recurrent_state.clone() + output = torch.zeros_like(value) + + for token in range(query.size(2)): + g_t = g[:, :, token].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, token].unsqueeze(-1) + k_t = key[:, :, token] + v_t = value[:, :, token] + q_t = query[:, :, token] + + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output[:, :, token] = (state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output, state + + def test_recurrent_gated_delta_rule_matches_reference(self): + torch.manual_seed(0) + + test_cases = ( + (2, 3, 4, 5, 6), + (1, 4, 7, 8, 3), + ) + + for case in test_cases: + with self.subTest(case=case): + ( + query, + key, + value, + g, + beta, + recurrent_state, + ) = self._make_inputs(*case) + + expected_output, expected_state = ( + self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + ) 
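+                # The op mutates recurrent_state in place (Tensor(a!) in the
+                # schema), so run it on a fresh clone and check both the
+                # returned output and the updated state against the reference.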
+ + actual_state = recurrent_state.clone() + actual_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + actual_state, + ) + + self.assertTrue( + torch.allclose(actual_output, expected_output, atol=1e-5) + ) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_out_matches_reference(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs() + expected_output, expected_state = self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + + actual_state = recurrent_state.clone() + actual_output = torch.empty_like(value) + returned_output = torch.ops.llama.recurrent_gated_delta_rule.out( + query, + key, + value, + g, + beta, + actual_state, + out=actual_output, + ) + + self.assertEqual(returned_output.data_ptr(), actual_output.data_ptr()) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-5)) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_chunked_matches_full_sequence(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs(seq_len=6) + + full_state = recurrent_state.clone() + full_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + full_state, + ) + + chunk_state = recurrent_state.clone() + chunk_outputs = [] + for start, end in ((0, 2), (2, 5), (5, 6)): + chunk_outputs.append( + torch.ops.llama.recurrent_gated_delta_rule( + query[:, :, start:end, :], + key[:, :, start:end, :], + value[:, :, start:end, :], + g[:, :, start:end], + beta[:, :, start:end], + chunk_state, + ) + ) + + chunked_output = torch.cat(chunk_outputs, dim=2) + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue(torch.allclose(chunk_state, full_state, atol=1e-5)) diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 20333578763..271cf1e1540 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -47,14 +47,15 @@ class GenerationConfig: ... def resolve_max_new_tokens( - self, max_context_len: int, num_prompt_tokens: int + self, max_context_len: int, num_tokens_occupied: int ) -> int: """ Resolve the maximum number of new tokens to generate based on constraints. Args: max_context_len: The maximum context length supported by the model - num_prompt_tokens: The number of tokens in the input prompt + num_tokens_occupied: The number of token positions already occupied + in the context window (e.g. pos after prefill) Returns: The resolved maximum number of new tokens to generate diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index 0fcce1f37e4..bb7dd767fea 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -65,36 +66,41 @@ struct GenerationConfig { * * This method calculates the maximum number of new tokens that can be * generated considering both seq_len and max_new_tokens constraints, as well - * as the model's maximum context length and the number of tokens in the - * prompt. + * as the model's maximum context length and how many token positions are + * already occupied (e.g. by prior turns and the current prompt). 
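+   * For example, with max_context_len = 2048, num_tokens_occupied = 1500,
+   * seq_len = -1, and max_new_tokens = 1024, this resolves to
+   * min(1024, 2048 - 1500) = 548 new tokens.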
* * @param max_context_len The maximum context length supported by the model - * @param num_prompt_tokens The number of tokens in the input prompt + * @param num_tokens_occupied The number of token positions already occupied + * in the context window (e.g. pos_ after prefill) * @return The resolved maximum number of new tokens to generate */ int32_t resolve_max_new_tokens( - int32_t max_context_len, - int32_t num_prompt_tokens) const { - int32_t result; + int64_t max_context_len, + int64_t num_tokens_occupied) const { + int64_t result; if (seq_len == -1 && max_new_tokens == -1) { - // Both are -1, use max context len minus prompt tokens - result = max_context_len - num_prompt_tokens; + // Both are -1, use max context len minus occupied tokens + result = max_context_len - num_tokens_occupied; } else if (seq_len == -1 && max_new_tokens != -1) { // Only max_new_tokens is specified - result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); + result = std::min( + static_cast(max_new_tokens), + max_context_len - num_tokens_occupied); } else if (seq_len != -1 && max_new_tokens == -1) { // Only seq_len is specified - result = std::min(seq_len, max_context_len) - num_prompt_tokens; + result = std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied; } else { // Both are specified result = std::min( - std::min(seq_len, max_context_len) - num_prompt_tokens, - max_new_tokens); + std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied, + static_cast(max_new_tokens)); } // Ensure result is not negative - return std::max(0, result); + return static_cast(std::max(static_cast(0), result)); } }; diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index ecd49e6341a..3188b5390c4 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) { "resolve_max_new_tokens", &GenerationConfig::resolve_max_new_tokens, py::arg("max_context_len"), - py::arg("num_prompt_tokens"), + py::arg("num_tokens_occupied"), "Resolve the maximum number of new tokens to generate based on constraints") .def("__repr__", [](const GenerationConfig& config) { return " executorch::aten::ArrayRef> BoxedEvalueList>::get() const { @@ -27,5 +31,26 @@ BoxedEvalueList>::get() const { return executorch::aten::ArrayRef>{ unwrapped_vals_, wrapped_vals_.size()}; } + +template <> +Result>> +BoxedEvalueList>::tryGet() const { + for (typename executorch::aten::ArrayRef< + std::optional>::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + unwrapped_vals_[i] = std::nullopt; + continue; + } + auto r = wrapped_vals_[i]->tryToOptional(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef>{ + unwrapped_vals_, wrapped_vals_.size()}; +} } // namespace runtime } // namespace executorch diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h index 8d75b1ace97..eed52bb74f7 100644 --- a/runtime/core/evalue.h +++ b/runtime/core/evalue.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -71,6 +72,16 @@ class BoxedEvalueList { */ executorch::aten::ArrayRef get() const; + /** + * Result-returning counterpart of get(). Validates each wrapped EValue's + * tag before materializing; returns Error::InvalidType if any element's + * tag does not match T and Error::InvalidState if any element pointer is + * null. 
Use this when materializing lists from untrusted .pte data so that + * a malformed program cannot force a process abort inside to() / + * ET_CHECK. + */ + Result> tryGet() const; + /** * Destroys the unwrapped elements without re-dereferencing wrapped_vals_. * This is safe to call during EValue destruction because it does not @@ -107,6 +118,10 @@ template <> executorch::aten::ArrayRef> BoxedEvalueList>::get() const; +template <> +Result>> +BoxedEvalueList>::tryGet() const; + // Aggregate typing system similar to IValue only slimmed down with less // functionality, no dependencies on atomic, and fewer supported types to better // suit embedded systems (ie no intrusive ptr) @@ -193,6 +208,13 @@ struct EValue { return payload.copyable_union.as_int; } + Result tryToInt() const { + if (!isInt()) { + return Error::InvalidType; + } + return payload.copyable_union.as_int; + } + /****** Double Type ******/ /*implicit*/ EValue(double d) : tag(Tag::Double) { payload.copyable_union.as_double = d; @@ -207,6 +229,13 @@ struct EValue { return payload.copyable_union.as_double; } + Result tryToDouble() const { + if (!isDouble()) { + return Error::InvalidType; + } + return payload.copyable_union.as_double; + } + /****** Bool Type ******/ /*implicit*/ EValue(bool b) : tag(Tag::Bool) { payload.copyable_union.as_bool = b; @@ -221,6 +250,13 @@ struct EValue { return payload.copyable_union.as_bool; } + Result tryToBool() const { + if (!isBool()) { + return Error::InvalidType; + } + return payload.copyable_union.as_bool; + } + /****** Scalar Type ******/ /// Construct an EValue using the implicit value of a Scalar. /*implicit*/ EValue(executorch::aten::Scalar s) { @@ -256,6 +292,19 @@ struct EValue { } } + Result tryToScalar() const { + if (isDouble()) { + return executorch::aten::Scalar(payload.copyable_union.as_double); + } + if (isInt()) { + return executorch::aten::Scalar(payload.copyable_union.as_int); + } + if (isBool()) { + return executorch::aten::Scalar(payload.copyable_union.as_bool); + } + return Error::InvalidType; + } + /****** Tensor Type ******/ /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { // When built in aten mode, at::Tensor has a non trivial constructor @@ -305,6 +354,16 @@ struct EValue { return payload.as_tensor; } + // Returns a copy of the Tensor handle (one intrusive_ptr refcount bump in + // ATen mode; free in lean mode). Unlike toTensor()'s const& / & overloads, + // tryToTensor() cannot return a reference — Result wraps by value. 
+ Result tryToTensor() const { + if (!isTensor()) { + return Error::InvalidType; + } + return payload.as_tensor; + } + /****** String Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* s) : tag(Tag::String) { ET_CHECK_MSG(s != nullptr, "ArrayRef pointer cannot be null"); @@ -325,6 +384,18 @@ struct EValue { payload.copyable_union.as_string_ptr->size()); } + Result tryToString() const { + if (!isString()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_string_ptr == nullptr) { + return Error::InvalidState; + } + return std::string_view( + payload.copyable_union.as_string_ptr->data(), + payload.copyable_union.as_string_ptr->size()); + } + /****** Int List Type ******/ /*implicit*/ EValue(BoxedEvalueList* i) : tag(Tag::ListInt) { ET_CHECK_MSG( @@ -344,6 +415,16 @@ struct EValue { return (payload.copyable_union.as_int_list_ptr)->get(); } + Result> tryToIntList() const { + if (!isIntList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_int_list_ptr == nullptr) { + return Error::InvalidState; + } + return (payload.copyable_union.as_int_list_ptr)->tryGet(); + } + /****** Bool List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* b) : tag(Tag::ListBool) { @@ -363,6 +444,16 @@ struct EValue { return *(payload.copyable_union.as_bool_list_ptr); } + Result> tryToBoolList() const { + if (!isBoolList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_bool_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_bool_list_ptr); + } + /****** Double List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* d) : tag(Tag::ListDouble) { @@ -382,6 +473,16 @@ struct EValue { return *(payload.copyable_union.as_double_list_ptr); } + Result> tryToDoubleList() const { + if (!isDoubleList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_double_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_double_list_ptr); + } + /****** Tensor List Type ******/ /*implicit*/ EValue(BoxedEvalueList* t) : tag(Tag::ListTensor) { @@ -402,6 +503,17 @@ struct EValue { return payload.copyable_union.as_tensor_list_ptr->get(); } + Result> tryToTensorList() + const { + if (!isTensorList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_tensor_list_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_tensor_list_ptr->tryGet(); + } + /****** List Optional Tensor Type ******/ /*implicit*/ EValue( BoxedEvalueList>* t) @@ -426,6 +538,17 @@ struct EValue { return payload.copyable_union.as_list_optional_tensor_ptr->get(); } + Result>> + tryToListOptionalTensor() const { + if (!isListOptionalTensor()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_list_optional_tensor_ptr->tryGet(); + } + /****** ScalarType Type ******/ executorch::aten::ScalarType toScalarType() const { ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); @@ -433,6 +556,14 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToScalarType() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** MemoryFormat Type ******/ executorch::aten::MemoryFormat toMemoryFormat() const { ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); @@ -440,12 +571,27 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToMemoryFormat() 
const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** Layout Type ******/ executorch::aten::Layout toLayout() const { ET_CHECK_MSG(isInt(), "EValue is not a Layout."); return static_cast(payload.copyable_union.as_int); } + Result tryToLayout() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast(payload.copyable_union.as_int); + } + /****** Device Type ******/ executorch::aten::Device toDevice() const { ET_CHECK_MSG(isInt(), "EValue is not a Device."); @@ -455,6 +601,16 @@ struct EValue { -1); } + Result tryToDevice() const { + if (!isInt()) { + return Error::InvalidType; + } + return executorch::aten::Device( + static_cast( + payload.copyable_union.as_int), + -1); + } + template T to() &&; template @@ -462,6 +618,15 @@ struct EValue { template typename internal::evalue_to_ref_overload_return::type to() &; + /** + * Result-returning equivalent of `to()`. Tag mismatch returns + * `Error::InvalidType`; a null list/string payload returns + * `Error::InvalidState`. Specializations are defined below via + * `EVALUE_DEFINE_TRY_TO`. + */ + template + Result tryTo() const; + /** * Converts the EValue to an optional object that can represent both T and * an uninitialized state. @@ -474,6 +639,23 @@ struct EValue { return this->to(); } + /** + * Result-returning equivalent of `toOptional()`. None maps to an empty + * optional; any other tag that doesn't match T propagates `tryTo()`'s + * error (`Error::InvalidType`). + */ + template + inline Result> tryToOptional() const { + if (this->isNone()) { + return std::optional(std::nullopt); + } + auto r = this->tryTo(); + if (!r.ok()) { + return r.error(); + } + return std::optional(std::move(r.get())); + } + private: // Pre cond: the payload value has had its destructor called void clearToNone() noexcept { @@ -591,6 +773,59 @@ EVALUE_DEFINE_TO( toListOptionalTensor) #undef EVALUE_DEFINE_TO +#define EVALUE_DEFINE_TRY_TO(T, method_name) \ + template <> \ + inline Result EValue::tryTo() const { \ + return this->method_name(); \ + } + +EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar) +EVALUE_DEFINE_TRY_TO(int64_t, tryToInt) +EVALUE_DEFINE_TRY_TO(bool, tryToBool) +EVALUE_DEFINE_TRY_TO(double, tryToDouble) +EVALUE_DEFINE_TRY_TO(std::string_view, tryToString) +EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType) +EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat) +EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout) +EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice) +// Tensor and Optional Tensor +EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor) +EVALUE_DEFINE_TRY_TO( + std::optional, + tryToOptional) + +// IntList and Optional IntList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToIntList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// DoubleList and Optional DoubleList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToDoubleList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// BoolList and Optional BoolList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToBoolList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// TensorList and Optional TensorList +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef, + tryToTensorList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// List of Optional Tensor +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef>, + tryToListOptionalTensor) +#undef EVALUE_DEFINE_TRY_TO + 
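To show how these accessors compose, here is a minimal caller-side sketch. It is illustrative only: the function name materialize_args, its arguments, and the include path are assumptions rather than part of this patch, but it uses only the tryTo / tryToOptional specializations defined above.

#include <executorch/runtime/core/evalue.h>

using executorch::runtime::Error;
using executorch::runtime::EValue;

// Materialize a Tensor and an optional int list from untrusted EValues,
// propagating Result errors instead of aborting on a malformed program.
Error materialize_args(const EValue& tensor_val, const EValue& sizes_val) {
  auto tensor = tensor_val.tryTo<executorch::aten::Tensor>();
  if (!tensor.ok()) {
    return tensor.error(); // Error::InvalidType on a tag mismatch.
  }
  auto sizes = sizes_val.tryToOptional<executorch::aten::ArrayRef<int64_t>>();
  if (!sizes.ok()) {
    return sizes.error(); // InvalidType, or InvalidState for a null list.
  }
  // ... use tensor.get() and sizes.get() here ...
  return Error::Ok;
}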
template executorch::aten::ArrayRef BoxedEvalueList::get() const { for (typename executorch::aten::ArrayRef::size_type i = 0; @@ -602,6 +837,23 @@ executorch::aten::ArrayRef BoxedEvalueList::get() const { return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; } +template +Result> BoxedEvalueList::tryGet() const { + for (typename executorch::aten::ArrayRef::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + return Error::InvalidState; + } + auto r = wrapped_vals_[i]->template tryTo(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; +} + } // namespace runtime } // namespace executorch diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp index edf6a1b12c1..1b0b86c1392 100644 --- a/runtime/core/test/evalue_test.cpp +++ b/runtime/core/test/evalue_test.cpp @@ -16,8 +16,12 @@ using namespace ::testing; +using executorch::aten::DeviceType; +using executorch::aten::Layout; +using executorch::aten::MemoryFormat; using executorch::aten::ScalarType; using executorch::runtime::BoxedEvalueList; +using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::Tag; using executorch::runtime::testing::TensorFactory; @@ -214,6 +218,56 @@ TEST_F(EValueTest, BoxedEvalueList) { EXPECT_EQ(unwrapped[2], 3); } +TEST_F(EValueTest, BoxedEvalueListTryGetSuccess) { + EValue values[3] = { + EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 3); + EXPECT_EQ((*result)[0], 1); + EXPECT_EQ((*result)[1], 2); + EXPECT_EQ((*result)[2], 3); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetWrongElementTag) { + // Second element is a Double, not an Int; tryGet should reject it rather + // than abort inside to(). + EValue values[3] = {EValue((int64_t)1), EValue(3.14), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetNullElement) { + // A null value is a malformed program for non-optional lists. + EValue a((int64_t)1); + EValue c((int64_t)3); + EValue* values_p[3] = {&a, nullptr, &c}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidState); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetOptionalTensorNullIsNone) { + // For optional, null value is valid. + EValue a; + EValue* values_p[2] = {&a, nullptr}; + std::optional storage[2]; + BoxedEvalueList> x{ + values_p, storage, 2}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 2); + EXPECT_FALSE((*result)[0].has_value()); + EXPECT_FALSE((*result)[1].has_value()); +} + TEST_F(EValueTest, toOptionalTensorList) { // create list, empty evalue ctor gets tag::None EValue values[2] = {EValue(), EValue()}; @@ -417,3 +471,116 @@ TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) { EXPECT_TRUE(e.isListOptionalTensor()); ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "pointer is null"); } + +// Per-type tryTo* coverage. 
+// For each type: +// - success and failure for named method tryTo[Int/Double/Bool/Tensor/..] +// - success and failure for templated tryTo() specialization + +TEST_F(EValueTest, TryToInt) { + EValue e_int(static_cast(42)); + EValue e_mismatch(3.14); + EXPECT_EQ(e_int.tryToInt().get(), 42); + EXPECT_EQ(e_mismatch.tryToInt().error(), Error::InvalidType); + EXPECT_EQ(e_int.tryTo().get(), 42); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDouble) { + EValue e_double(3.14); + EValue e_mismatch(static_cast(42)); + EXPECT_DOUBLE_EQ(e_double.tryToDouble().get(), 3.14); + EXPECT_EQ(e_mismatch.tryToDouble().error(), Error::InvalidType); + EXPECT_DOUBLE_EQ(e_double.tryTo().get(), 3.14); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToBool) { + EValue e_bool(true); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_bool.tryToBool().get(), true); + EXPECT_EQ(e_mismatch.tryToBool().error(), Error::InvalidType); + EXPECT_EQ(e_bool.tryTo().get(), true); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_tensor.tryToTensor()->numel(), 6); + EXPECT_EQ(e_mismatch.tryToTensor().error(), Error::InvalidType); + EXPECT_EQ(e_tensor.tryTo()->numel(), 6); + EXPECT_EQ( + e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToOptionalTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_none; + EValue e_mismatch(static_cast(42)); + // Named tryToOptional: value, None, mismatch. + auto r_val = e_tensor.tryToOptional(); + EXPECT_TRUE(r_val->has_value()); + EXPECT_EQ(r_val->value().numel(), 6); + EXPECT_FALSE(e_none.tryToOptional()->has_value()); + EXPECT_EQ( + e_mismatch.tryToOptional().error(), + Error::InvalidType); + // Templated tryTo>: None path. + EXPECT_FALSE( + e_none.tryTo>()->has_value()); +} + +TEST_F(EValueTest, TryToScalar) { + EValue e_int(static_cast(7)); + EValue e_double(2.5); + EValue e_bool(true); + EValue e_none; + EXPECT_EQ(e_int.tryToScalar()->to(), 7); + EXPECT_DOUBLE_EQ(e_double.tryToScalar()->to(), 2.5); + EXPECT_EQ(e_bool.tryToScalar()->to(), true); + // None is neither Int/Double/Bool. 
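+  // so tryToScalar() reports InvalidType rather than aborting.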
+ EXPECT_EQ(e_none.tryToScalar().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToScalarType) { + EValue e(static_cast(ScalarType::Float)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToScalarType().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryToScalarType().error(), Error::InvalidType); + EXPECT_EQ(e.tryTo().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToMemoryFormat) { + EValue e(static_cast(MemoryFormat::Contiguous)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToMemoryFormat().get(), MemoryFormat::Contiguous); + EXPECT_EQ(e_mismatch.tryToMemoryFormat().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToLayout) { + EValue e(static_cast(Layout::Strided)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToLayout().get(), Layout::Strided); + EXPECT_EQ(e_mismatch.tryToLayout().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDevice) { + EValue e(static_cast(DeviceType::CPU)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToDevice().get().type(), DeviceType::CPU); + EXPECT_EQ(e_mismatch.tryToDevice().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensorList) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToTensorList().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToListOptionalTensor) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToListOptionalTensor().error(), Error::InvalidType); +} diff --git a/third-party/CMakeLists.txt b/third-party/CMakeLists.txt index 904213a1fec..fdc11d8c782 100644 --- a/third-party/CMakeLists.txt +++ b/third-party/CMakeLists.txt @@ -24,8 +24,11 @@ endif() if(WIN32) set(_executorch_external_project_additional_args) else() - # Always use Make to avoid needing to codesign flatc if the project is using Xcode. - set(_executorch_external_project_additional_args CMAKE_GENERATOR "Unix Makefiles") + # Always use Make to avoid needing to codesign flatc if the project is using + # Xcode. + set(_executorch_external_project_additional_args CMAKE_GENERATOR + "Unix Makefiles" + ) endif() # We use ExternalProject to build flatc from source to force it target the host. @@ -35,94 +38,120 @@ ExternalProject_Add( PREFIX ${CMAKE_CURRENT_BINARY_DIR}/flatc_ep BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatc_ep/src/build SOURCE_DIR ${PROJECT_SOURCE_DIR}/third-party/flatbuffers - CMAKE_ARGS -DFLATBUFFERS_BUILD_FLATC=ON - -DFLATBUFFERS_INSTALL=ON - -DFLATBUFFERS_BUILD_FLATHASH=OFF - -DFLATBUFFERS_BUILD_FLATLIB=OFF - -DFLATBUFFERS_BUILD_TESTS=OFF - -DCMAKE_INSTALL_PREFIX:PATH= - -DCMAKE_CXX_FLAGS="-DFLATBUFFERS_MAX_ALIGNMENT=${EXECUTORCH_FLATBUFFERS_MAX_ALIGNMENT}" - # Unset the toolchain to build for the host instead of the toolchain set for the project. - -DCMAKE_TOOLCHAIN_FILE= - # If building for iOS, "unset" these variables to rely on the host (macOS) defaults. - $<$,$>>:-DCMAKE_OSX_SYSROOT=> - -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} + CMAKE_ARGS + -DFLATBUFFERS_BUILD_FLATC=ON + -DFLATBUFFERS_INSTALL=ON + -DFLATBUFFERS_BUILD_FLATHASH=OFF + -DFLATBUFFERS_BUILD_FLATLIB=OFF + -DFLATBUFFERS_BUILD_TESTS=OFF + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_CXX_FLAGS="-DFLATBUFFERS_MAX_ALIGNMENT=${EXECUTORCH_FLATBUFFERS_MAX_ALIGNMENT}" + # Unset the toolchain to build for the host instead of the toolchain set for + # the project. + -DCMAKE_TOOLCHAIN_FILE= + # If building for iOS, "unset" these variables to rely on the host (macOS) + # defaults. 
+ $<$,$>>:-DCMAKE_OSX_SYSROOT=> + -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} BUILD_BYPRODUCTS /bin/flatc - ${_executorch_external_project_additional_args} + ${_executorch_external_project_additional_args} ) ExternalProject_Get_Property(flatbuffers_ep INSTALL_DIR) add_executable(flatc IMPORTED GLOBAL) add_dependencies(flatc flatbuffers_ep) if(WIN32 AND NOT CMAKE_CROSSCOMPILING) - # flatbuffers does not use CMAKE_BUILD_TYPE. Internally, the build forces Release - # config, but from CMake's perspective the build type is always Debug. - set_target_properties(flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc.exe) + # flatbuffers does not use CMAKE_BUILD_TYPE. Internally, the build forces + # Release config, but from CMake's perspective the build type is always Debug. + set_target_properties( + flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc.exe + ) else() - set_target_properties(flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc) + set_target_properties( + flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc + ) endif() # TODO: re-enable once flatbuffers is added as a subdirectory. -# set(FLATBUFFERS_BUILD_FLATC OFF) -# set(FLATBUFFERS_INSTALL OFF) -# set(FLATBUFFERS_BUILD_FLATHASH OFF) -# set(FLATBUFFERS_BUILD_FLATLIB OFF) +# set(FLATBUFFERS_BUILD_FLATC OFF) set(FLATBUFFERS_INSTALL OFF) +# set(FLATBUFFERS_BUILD_FLATHASH OFF) set(FLATBUFFERS_BUILD_FLATLIB OFF) # set(FLATBUFFERS_BUILD_TESTS OFF) # MARK: - flatcc if(WIN32) # For some reason, when configuring the external project during build - # CMAKE_C_SIMULATE_ID is set to MSVC, but CMAKE_CXX_SIMULATE_ID is not set. - # To make sure the external project is configured correctly, set it explicitly + # CMAKE_C_SIMULATE_ID is set to MSVC, but CMAKE_CXX_SIMULATE_ID is not set. To + # make sure the external project is configured correctly, set it explicitly # here. set(_flatcc_extra_cmake_args -DCMAKE_CXX_SIMULATE_ID=MSVC) else() set(_flatcc_extra_cmake_args) endif() -# Similar to flatbuffers, we want to build flatcc for the host. See inline comments -# in the flatbuffers ExternalProject_Add for more details. +# Similar to flatbuffers, we want to build flatcc for the host. See inline +# comments in the flatbuffers ExternalProject_Add for more details. 
ExternalProject_Add( flatcc_ep PREFIX ${CMAKE_CURRENT_BINARY_DIR}/flatcc_ep SOURCE_DIR ${PROJECT_SOURCE_DIR}/third-party/flatcc BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatcc_ep/src/build - CMAKE_ARGS -DFLATCC_RTONLY=OFF - -DFLATCC_TEST=OFF - -DFLATCC_REFLECTION=OFF - -DFLATCC_DEBUG_CLANG_SANITIZE=OFF - -DFLATCC_ALLOW_WERROR=OFF - -DFLATCC_INSTALL=ON - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 - -DCMAKE_INSTALL_PREFIX:PATH= - -DCMAKE_POSITION_INDEPENDENT_CODE=ON - -DCMAKE_TOOLCHAIN_FILE= - $<$,$>>:-DCMAKE_OSX_SYSROOT=> - -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} - ${_flatcc_extra_cmake_args} + CMAKE_ARGS + -DFLATCC_RTONLY=OFF + -DFLATCC_TEST=OFF + -DFLATCC_REFLECTION=OFF + -DFLATCC_DEBUG_CLANG_SANITIZE=OFF + -DFLATCC_ALLOW_WERROR=OFF + -DFLATCC_INSTALL=ON + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_TOOLCHAIN_FILE= + $<$,$>>:-DCMAKE_OSX_SYSROOT=> + -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} + ${_flatcc_extra_cmake_args} BUILD_BYPRODUCTS /bin/flatcc - {_executorch_external_project_additional_args} + ${_executorch_external_project_additional_args} ) file(REMOVE_RECURSE ${PROJECT_SOURCE_DIR}/third-party/flatcc/lib) ExternalProject_Get_Property(flatcc_ep INSTALL_DIR) add_executable(flatcc_cli IMPORTED GLOBAL) add_dependencies(flatcc_cli flatcc_ep) if(WIN32 AND NOT CMAKE_CROSSCOMPILING) - set_target_properties(flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc.exe) + set_target_properties( + flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc.exe + ) else() - set_target_properties(flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc) + set_target_properties( + flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc + ) endif() -set(FLATCC_RTONLY ON CACHE BOOL "") -set(FLATCC_TEST OFF CACHE BOOL "") -set(FLATCC_REFLECTION OFF CACHE BOOL "") -set(FLATCC_DEBUG_CLANG_SANITIZE OFF CACHE BOOL "") -set(FLATCC_INSTALL OFF CACHE BOOL "") +set(FLATCC_RTONLY + ON + CACHE BOOL "" +) +set(FLATCC_TEST + OFF + CACHE BOOL "" +) +set(FLATCC_REFLECTION + OFF + CACHE BOOL "" +) +set(FLATCC_DEBUG_CLANG_SANITIZE + OFF + CACHE BOOL "" +) +set(FLATCC_INSTALL + OFF + CACHE BOOL "" +) add_subdirectory(flatcc) # Unfortunately flatcc writes libs directly in to the source tree [1]. So to # ensure the target lib is created last, force flatcc_cli to build first. # -# [1] https://github.com/dvidelabs/flatcc/blob/896db54787e8b730a6be482c69324751f3f5f117/CMakeLists.txt#L168 +# [1] +# https://github.com/dvidelabs/flatcc/blob/896db54787e8b730a6be482c69324751f3f5f117/CMakeLists.txt#L168 add_dependencies(flatccrt flatcc_cli) # Fix for "relocation R_X86_64_32 against `.rodata' can not be used when making # a shared object; recompile with -fPIC" when building on some x86 linux @@ -130,7 +159,4 @@ add_dependencies(flatccrt flatcc_cli) # # Learn more: https://github.com/pytorch/executorch/pull/2467 set_property(TARGET flatccrt PROPERTY POSITION_INDEPENDENT_CODE ON) -install( - TARGETS flatccrt - DESTINATION ${CMAKE_BINARY_DIR}/lib -) +install(TARGETS flatccrt DESTINATION ${CMAKE_BINARY_DIR}/lib)