From e7b38a30498fc14a410c55767258a6cbf2bbd9ca Mon Sep 17 00:00:00 2001 From: Xingguo Li <100689130+xingguo01@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:08:23 +0100 Subject: [PATCH 01/17] Cortex-M backend: Support standalone clamp-type activations (#18767) - Add support for quantized clamp-type activations in the Cortex-M pipeline by canonicalizing relu/hardtanh/clamp to quantized aten.clamp.default for standalone int8 paths - Extend activation fusion to cover max_pool2d. @freddan80 @per @zingo @oscarandersson8218 @digantdesai @Sebastian-Larsson @AdrianLundell @psiddh cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell Signed-off-by: Xingguo Li --- backends/cortex_m/passes/__init__.py | 1 + .../cortex_m/passes/activation_fusion_pass.py | 70 ++++--- .../cortex_m/passes/cortex_m_pass_manager.py | 2 + backends/cortex_m/passes/passes_utils.py | 51 +++++ .../passes/quantized_clamp_activation_pass.py | 129 ++++++++++++ .../quantizer/quantization_configs.py | 40 ++++ .../cortex_m/quantizer/quantizer_support.py | 24 +++ backends/cortex_m/test/build_test_runner.sh | 1 + .../cortex_m/test/misc/test_portable_int8.py | 6 - .../cortex_m/test/models/test_nn_modules.py | 2 +- backends/cortex_m/test/ops/test_activation.py | 186 +++++++++++++++++- .../cortex_m/test/ops/test_conv_transpose.py | 77 +++++++- 12 files changed, 551 insertions(+), 38 deletions(-) create mode 100644 backends/cortex_m/passes/quantized_clamp_activation_pass.py diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 3ef5fc02adb..19665f37083 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -8,6 +8,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .decompose_hardswish_pass import DecomposeHardswishPass # noqa from .decompose_mean_pass import DecomposeMeanPass # noqa +from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py index a53c065aaa4..ff61f3493dd 100644 --- a/backends/cortex_m/passes/activation_fusion_pass.py +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +8,10 @@ import executorch.backends.cortex_m.ops.operators # noqa: F401 from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.cortex_m.passes.passes_utils import quantize_val +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass): """Fuse activations into preceding Cortex-M quantized operators. Supported activation patterns: - q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + q-> [conv2d, linear, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq Fusing works by clamping the quantized output range (and zero-point when required) of the preceding Cortex-M operator, then removing the activation @@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass): exir_ops.edge.aten.clamp.default, } + MAX_POOL_OPS = { + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, + } + FUSE_OPS = { exir_ops.edge.aten.linear.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, } def _get_validated_qparams(self, node, input_node): @@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node): ) return None - match node.target: - case exir_ops.edge.aten.relu.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = qmax - case exir_ops.edge.aten.hardtanh.default: - quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax) - quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax) - case exir_ops.edge.aten.hardsigmoid.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = quantize_val(1, scale, zp, qmin, qmax) - case exir_ops.edge.aten.clamp.default: - quantized_min_val = ( - quantize_val(node.args[1], scale, zp, qmin, qmax) - if node.args[1] is not None - else qmin - ) - # Last arg is removed if none, so check length of args here - quantized_max_val = ( - quantize_val(node.args[2], scale, zp, qmin, qmax) - if len(node.args) == 3 - else qmax + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot fuse activation %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min_val = ( + quantize_val(min_val, scale, zp, qmin, qmax) + if min_val is not None + else qmin + ) + quantized_max_val = ( + quantize_val(max_val, scale, zp, qmin, qmax) + if max_val is not None + else qmax + ) + + if input_node.target in self.MAX_POOL_OPS: + if node.target == exir_ops.edge.aten.hardsigmoid.default: + logger.warning( + "Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams.", + node.name, ) - case _: - raise RuntimeError(f"Unexpected target {node.target}.") + return None + # Max-pool keeps scale and zero-point unchanged and lowers fused + # activation bounds separately, so only qmin/qmax need updating here. + qparams_dict["qmin"] = int(quantized_min_val) + qparams_dict["qmax"] = int(quantized_max_val) + return qparams_dict # If the minimal quantized value is larger than the qmin, it means that the quantized range contains # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 9fef167ef09..074eb6118d0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -28,6 +28,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass from .decompose_hardswish_pass import DecomposeHardswishPass from .decompose_mean_pass import DecomposeMeanPass +from .quantized_clamp_activation_pass import QuantizedClampActivationPass from .quantized_op_fusion_pass import QuantizedOpFusionPass from .replace_quant_nodes_pass import ReplaceQuantNodesPass @@ -42,6 +43,7 @@ class CortexMPassManager(PassManager): ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, ActivationFusionPass, + QuantizedClampActivationPass, DecomposeHardswishPass, QuantizedOpFusionPass, ConvertToCortexMPass, diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index a6f68022430..fcbfa301b06 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Any import torch @@ -21,6 +22,56 @@ def quantize_val(val, scale, zp, qmin, qmax): return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax)) +def extract_constant_scalar(arg: Any) -> float | None: + if arg is None: + return None + if isinstance(arg, (int, float)): + return float(arg) + if isinstance(arg, Node): + if arg.op == "call_function" and arg.target in { + exir_ops.edge.aten.full_like.default, + exir_ops.edge.aten.full.default, + torch.ops.aten.full_like.default, + torch.ops.aten.full.default, + }: + fill_arg = arg.args[1] if len(arg.args) > 1 else None + return extract_constant_scalar(fill_arg) + val = arg.meta.get("val") + if val is None: + return None + return extract_constant_scalar(val) + return None + + +def get_activation_bounds(node: Node) -> tuple[float | None, float | None] | None: + bounds: tuple[float | None, float | None] + match node.target: + case exir_ops.edge.aten.relu.default | exir_ops.edge.aten.relu_.default: + bounds = (0.0, None) + case exir_ops.edge.aten.hardsigmoid.default: + bounds = (0.0, 1.0) + case exir_ops.edge.aten.hardtanh.default | exir_ops.edge.aten.hardtanh_.default: + bounds = ( + extract_constant_scalar(node.args[1]), + extract_constant_scalar(node.args[2]), + ) + case exir_ops.edge.aten.clamp.default | exir_ops.edge.aten.clamp.Tensor: + bounds = ( + extract_constant_scalar(node.args[1]) if len(node.args) > 1 else None, + extract_constant_scalar(node.args[2]) if len(node.args) > 2 else None, + ) + case _: + return None + + min_val, max_val = bounds + if len(node.args) > 1 and min_val is None and node.args[1] is not None: + return None + if len(node.args) > 2 and max_val is None and node.args[2] is not None: + return None + + return bounds + + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int ) -> torch.Tensor: diff --git a/backends/cortex_m/passes/quantized_clamp_activation_pass.py b/backends/cortex_m/passes/quantized_clamp_activation_pass.py new file mode 100644 index 00000000000..2ba003dbc01 --- /dev/null +++ b/backends/cortex_m/passes/quantized_clamp_activation_pass.py @@ -0,0 +1,129 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class QuantizedClampActivationPass(ExportPass): + """Canonicalize remaining clamp-like activations on quantized tensors. + + This pass runs after activation fusion, so any remaining relu/hardtanh/clamp + still needs to execute in the quantized domain. It rewrites relu and + hardtanh variants to `aten.clamp.default` and quantizes the clamp bounds so + the portable kernel consumes and produces int8 tensors. + """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.relu_.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardtanh_.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + } + + def _get_quantized_bounds( + self, node: Node, qparams_dict: dict[str, Any] + ) -> tuple[int | None, int | None] | None: + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot rewrite %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min = ( + int(quantize_val(min_val, scale, zp, qmin, qmax)) + if min_val is not None + else None + ) + quantized_max = ( + int(quantize_val(max_val, scale, zp, qmin, qmax)) + if max_val is not None + else None + ) + return quantized_min, quantized_max + + def _is_quantized_int8_activation(self, node: Node) -> bool: + input_node = node.args[0] if len(node.args) > 0 else None + if not isinstance(input_node, Node): + return False + try: + tensor = get_first_fake_tensor(input_node) + except Exception: + return False + if tensor is None or tensor.dtype != torch.int8: + return False + + try: + qparams_dict = get_output_qparams(node)[0]._asdict() + except (ValueError, KeyError): + logger.warning( + "Cannot quantize clamp bounds for %s without output qparams.", + node.name, + ) + return False + + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + "Cannot quantize clamp bounds for %s with non per-tensor qparams.", + node.name, + ) + return False + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + if not self._is_quantized_int8_activation(node): + continue + + qparams_dict = get_output_qparams(node)[0]._asdict() + + quantized_bounds = self._get_quantized_bounds(node, qparams_dict) + if quantized_bounds is None: + continue + + quantized_min, quantized_max = quantized_bounds + node.target = exir_ops.edge.aten.clamp.default + node.args = (node.args[0], quantized_min, quantized_max) + modified = True + + if modified: + graph_module = super().call(graph_module).graph_module + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 9bc13c05e9d..0f10bd6afef 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import Any, Callable import torch @@ -86,10 +87,45 @@ torch.ops.aten.max_pool2d_with_indices.default, } +POOL_FUSED_ACTIVATION_TARGETS = { + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_.default, +} + class CortexMQuantizationConfig(QuantizationConfig): """Configures quantization, while enforcing cortex-m specific constraints.""" + @staticmethod + def _get_shared_pool_input(node: Node | None) -> Node | None: + if node is None or len(node.args) == 0: + return None + + input_node = node.args[0] + if not isinstance(input_node, Node): + return None + + if input_node.target in POOL_SHARE_OUTPUT_TARGETS: + if len(input_node.args) > 0 and isinstance(input_node.args[0], Node): + return input_node.args[0] + return None + + if input_node.target == operator.getitem and len(input_node.args) > 0: + pool_node = input_node.args[0] + if ( + isinstance(pool_node, Node) + and pool_node.target in POOL_SHARE_OUTPUT_TARGETS + and len(pool_node.args) > 0 + and isinstance(pool_node.args[0], Node) + ): + return pool_node.args[0] + + return None + def get_input_act_qspec( self, node: Node | None = None, input_node: Node | None = None ) -> QuantizationSpecBase | None: @@ -117,6 +153,10 @@ def get_output_act_qspec( if isinstance(input_node, Node): return SharedQuantizationSpec((input_node, node)) return super().get_output_act_qspec() + if node is not None and node.target in POOL_FUSED_ACTIVATION_TARGETS: + shared_pool_input = self._get_shared_pool_input(node) + if shared_pool_input is not None: + return SharedQuantizationSpec(shared_pool_input) return super().get_output_act_qspec() def get_weight_qspec(self, node: Node | None = None) -> QuantizationSpecBase | None: diff --git a/backends/cortex_m/quantizer/quantizer_support.py b/backends/cortex_m/quantizer/quantizer_support.py index 2cf0483f74b..3dfbb67638a 100644 --- a/backends/cortex_m/quantizer/quantizer_support.py +++ b/backends/cortex_m/quantizer/quantizer_support.py @@ -122,7 +122,31 @@ POOL_OP_PATTERNS = { (torch.ops.aten.avg_pool2d.default,): CortexMAvgPool2DCheck, (torch.ops.aten.max_pool2d.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, (torch.ops.aten.max_pool2d_with_indices.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, } BMM_OP_PATTERNS = { diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index 6ac9aa55e73..2505f83c9da 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -21,6 +21,7 @@ build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_cor select_ops_list="\ aten::add.out,\ +aten::clamp.out,\ aten::mul.out,\ aten::convolution.out,\ dim_order_ops::_clone_dim_order.out,\ diff --git a/backends/cortex_m/test/misc/test_portable_int8.py b/backends/cortex_m/test/misc/test_portable_int8.py index 82b719230eb..4e3b5f41561 100644 --- a/backends/cortex_m/test/misc/test_portable_int8.py +++ b/backends/cortex_m/test/misc/test_portable_int8.py @@ -662,12 +662,6 @@ def _quantize_and_export( xfails: dict[str, xfail_type] = { "contiguous": "MLETORCH-1863: Contiguos no-op is removed in to-edge, leading to unnecessary Q-DQ-Q-DQ chain.", - "clamp": "MLETORCH-1864: Support non-fused clamp-type activations.", - "clamp_tensor": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh_": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu_": "MLETORCH-1864: Support non-fused clamp-type activations.", "eq_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ne_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ge_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", diff --git a/backends/cortex_m/test/models/test_nn_modules.py b/backends/cortex_m/test/models/test_nn_modules.py index 4a92fd578ff..303b481d4bc 100644 --- a/backends/cortex_m/test/models/test_nn_modules.py +++ b/backends/cortex_m/test/models/test_nn_modules.py @@ -1,6 +1,6 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py index 8886a05a84b..0934386d67c 100644 --- a/backends/cortex_m/test/ops/test_activation.py +++ b/backends/cortex_m/test/ops/test_activation.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -398,6 +398,154 @@ def forward(self, x): return torch.clamp(self.linear(x), min=None, max=6.0) +class CortexMStandaloneReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.relu(x) + + +class CortexMStandaloneHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.nn.functional.hardtanh(x, -1.0, 1.0) + + +class CortexMStandaloneClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.clamp(x, -1.0, 1.0) + + +class CortexMStandaloneClampTensor(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_full_like_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.ops.aten.clamp.Tensor( + x, torch.full_like(x, -1.0), torch.full_like(x, 1.0) + ) + + +class CortexMMaxPool2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.pool(x)) + + +class CortexMMaxPool2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.nn.functional.hardtanh(self.pool(x), self.min_val, self.max_val) + + +class CortexMMaxPool2DClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_clamp_default"] + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.clamp(self.pool(x), self.min_val, self.max_val) + + test_cases = { # Linear + activation tests with various data ranges "linear_relu_small_range": McuTestCase( @@ -509,6 +657,40 @@ def forward(self, x): model=CortexMLinearClamp(in_features=4, out_features=3), example_inputs=(ramp_tensor(-10, 10, (1, 4)),), ), + "standalone_relu": McuTestCase( + model=CortexMStandaloneReLU(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_hardtanh": McuTestCase( + model=CortexMStandaloneHardtanh(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp": McuTestCase( + model=CortexMStandaloneClamp(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp_tensor": McuTestCase( + model=CortexMStandaloneClampTensor(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "maxpool_relu": McuTestCase( + model=CortexMMaxPool2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_hardtanh": McuTestCase( + model=CortexMMaxPool2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_clamp": McuTestCase( + model=CortexMMaxPool2DClamp(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), } @@ -520,6 +702,8 @@ def test_dialect_activation(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) @parametrize("test_case", test_cases) diff --git a/backends/cortex_m/test/ops/test_conv_transpose.py b/backends/cortex_m/test/ops/test_conv_transpose.py index 7a91c5e1b6b..8202e3dc999 100644 --- a/backends/cortex_m/test/ops/test_conv_transpose.py +++ b/backends/cortex_m/test/ops/test_conv_transpose.py @@ -60,6 +60,61 @@ def forward(self, x): return self.conv_transpose(x) +class CortexMConvTranspose2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv_transpose(x)) + + +class CortexMConvTranspose2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + + def forward(self, x): + return torch.nn.functional.hardtanh(self.conv_transpose(x), -0.5, 0.5) + + # Test cases covering various configurations test_cases = { # Basic test case @@ -123,6 +178,18 @@ def forward(self, x): ramp_tensor(0, 50, (1, 5, 4, 4)).to(memory_format=torch.channels_last), ), ), + "conv_transpose2d_relu": McuTestCase( + model=CortexMConvTranspose2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), + "conv_transpose2d_hardtanh": McuTestCase( + model=CortexMConvTranspose2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), # Dilation variation "conv_transpose2d_dilation_2": McuTestCase( model=CortexMConvTranspose2D(2, 4, kernel_size=3, dilation=2), @@ -244,12 +311,14 @@ def test_dialect_conv_transpose2d(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) -# Implementation xfails: empty because unsupported configurations are now -# rejected at AOT time by the quantizer filter, so they fall back to portable -# ops and work correctly. Only xfails_dialect needs to track these. -xfails_implementation: dict[str, xfail_type] = {} +xfails_implementation: dict[str, xfail_type] = { + "conv_transpose2d_relu": "Fused transpose-conv + relu lowers correctly but current implementation is numerically incorrect.", + "conv_transpose2d_hardtanh": "Fused transpose-conv + hardtanh lowers correctly but current implementation is numerically incorrect.", +} @parametrize("test_case", test_cases, xfails=xfails_implementation) From 2d995bccf075646884d95f97e6fa0ca080efe31f Mon Sep 17 00:00:00 2001 From: Per Held Date: Fri, 17 Apr 2026 13:09:49 +0200 Subject: [PATCH 02/17] Arm backend: Fix quantized constant-folding for aten.cat lists (#18971) FuseConstantArgsPass resolved input_qparams by flattened input-node index, while FoldAndAnnotateQParamsPass stores them by top-level argument index. For aten.cat with a list-valued tensor argument, this caused only the first tensor to be dequantized before folding, which corrupted the fused constant. Resolve qparams by top-level argument index and propagate that qparam through nested list and tuple arguments. Add a regression test for quantized aten.cat constant folding with list-valued tensor inputs. Signed-off-by: Per Held Change-Id: I6e1a012d82a5dbeecb403c440a2944953dd5cba7 --- .../arm/_passes/fuse_constant_ops_pass.py | 19 ++-- .../passes/test_fuse_constant_ops_pass.py | 96 +++++++++++++++++++ 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 6fd9b145988..d6fd4b18b53 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -83,21 +83,24 @@ def _fuse_nodes(self, node) -> bool: input_nodes = list(node.all_input_nodes) qparams = node.meta.get("input_qparams", None) - def resolve_arg(arg): + def resolve_arg(arg, arg_index=None): + qparam = ( + qparams.get(arg_index) if qparams and arg_index is not None else None + ) if isinstance(arg, torch.fx.Node) and arg in input_nodes: - idx = input_nodes.index(arg) t = get_param_tensor(self.exported_program, arg) - # Check if qparams exist for this arg - if qparams and idx in qparams.keys(): - t = qparams[idx].dequantize_value(t) + if qparam is not None: + t = qparam.dequantize_value(t) return t if isinstance(arg, tuple): - return tuple(resolve_arg(x) for x in arg) + return tuple(resolve_arg(x, arg_index) for x in arg) if isinstance(arg, list): - return [resolve_arg(x) for x in arg] + return [resolve_arg(x, arg_index) for x in arg] return arg - new_args = tuple(resolve_arg(a) for a in node.args) + new_args = tuple( + resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args) + ) new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()} data = node.target(*new_args, **new_kwargs) diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 0f281dba24b..785744c1b37 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -11,8 +11,12 @@ ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.test.harness.stages import StageType input_t = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] @@ -116,6 +120,52 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.cat((a, b), dim=0) +class QuantizedCatConstantBuffers(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer( + "horizontal_ramp", + torch.tensor( + [ + [ + [ + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + [-95, -32, 32, 95, 0], + ] + ] + ], + dtype=torch.int8, + ), + ) + self.register_buffer( + "vertical_ramp", + torch.tensor( + [ + [ + [ + [-95, -95, -95, -95, -95], + [-32, -32, -32, -32, -32], + [32, 32, 32, 32, 32], + [95, 95, 95, 95, 95], + ] + ] + ], + dtype=torch.int8, + ), + ) + + def forward(self) -> torch.Tensor: + return torch.cat( + ( + cast(torch.Tensor, self.horizontal_ramp), + cast(torch.Tensor, self.vertical_ramp), + ), + dim=1, + ) + + modules: Dict[str, ModuleWithFuseAttrs] = { "fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()), "fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()), @@ -174,3 +224,49 @@ def test_fuse_constant_args_tosa_INT_cat(module: ModuleWithFuseAttrs) -> None: ], ) pipeline.run() + + +def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: + qargs = QuantArgs( + scale=1.0 / 127.0, + zp=0, + qmin=-127, + qmax=127, + dtype=torch.int8, + ) + module = QuantizedCatConstantBuffers() + compile_spec = common.get_tosa_compile_spec( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ) + tester = ArmTester(module, example_inputs=(), compile_spec=compile_spec) + tester.export().to_edge() + exported_program = tester.get_artifact(StageType.TO_EDGE).exported_program() + + cat_node = next( + node + for node in exported_program.graph_module.graph.nodes + if node.op == "call_function" + ) + cat_node.meta["input_qparams"] = {0: qargs} + cat_node.meta["output_qparams"] = {0: qargs} + + pass_result = FuseConstantArgsPass(exported_program).call( + exported_program.graph_module + ) + + assert list(exported_program.state_dict) == ["aten_cat_default_fused_const"] + torch.testing.assert_close( + exported_program.state_dict["aten_cat_default_fused_const"], + torch.cat( + ( + cast(torch.Tensor, module.horizontal_ramp), + cast(torch.Tensor, module.vertical_ramp), + ), + dim=1, + ), + ) + assert [ + node.name + for node in pass_result.graph_module.graph.nodes + if node.op == "placeholder" + ] == ["aten_cat_default_fused_const"] From 8a77f9bb3de58f870e2d1d421588dc37ae44317c Mon Sep 17 00:00:00 2001 From: Akshara Bhardwaj Date: Fri, 24 Apr 2026 01:58:39 +0530 Subject: [PATCH 03/17] Format third-party/CMakeLists.txt using cmake-format (#18533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #10736 Formats `third-party/CMakeLists.txt` using `cmake-format` to improve readability and consistency. **Changes:** - Reformatted `ExternalProject_Add(...)` blocks for `flatbuffers` and `flatcc` - Reflowed `set_target_properties(...)`, `set(...)` cache variables, and `install(...)` calls - No functional changes — formatting only --- third-party/CMakeLists.txt | 130 ++++++++++++++++++++++--------------- 1 file changed, 78 insertions(+), 52 deletions(-) diff --git a/third-party/CMakeLists.txt b/third-party/CMakeLists.txt index 93ce08bdc7d..bafcd74ec77 100644 --- a/third-party/CMakeLists.txt +++ b/third-party/CMakeLists.txt @@ -24,8 +24,11 @@ endif() if(WIN32) set(_executorch_external_project_additional_args) else() - # Always use Make to avoid needing to codesign flatc if the project is using Xcode. - set(_executorch_external_project_additional_args CMAKE_GENERATOR "Unix Makefiles") + # Always use Make to avoid needing to codesign flatc if the project is using + # Xcode. + set(_executorch_external_project_additional_args CMAKE_GENERATOR + "Unix Makefiles" + ) endif() # We use ExternalProject to build flatc from source to force it target the host. @@ -35,93 +38,119 @@ ExternalProject_Add( PREFIX ${CMAKE_CURRENT_BINARY_DIR}/flatc_ep BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatc_ep/src/build SOURCE_DIR ${PROJECT_SOURCE_DIR}/third-party/flatbuffers - CMAKE_ARGS -DFLATBUFFERS_BUILD_FLATC=ON - -DFLATBUFFERS_INSTALL=ON - -DFLATBUFFERS_BUILD_FLATHASH=OFF - -DFLATBUFFERS_BUILD_FLATLIB=OFF - -DFLATBUFFERS_BUILD_TESTS=OFF - -DCMAKE_INSTALL_PREFIX:PATH= - -DCMAKE_CXX_FLAGS="-DFLATBUFFERS_MAX_ALIGNMENT=${EXECUTORCH_FLATBUFFERS_MAX_ALIGNMENT}" - # Unset the toolchain to build for the host instead of the toolchain set for the project. - -DCMAKE_TOOLCHAIN_FILE= - # If building for iOS, "unset" these variables to rely on the host (macOS) defaults. - $<$,$>>:-DCMAKE_OSX_SYSROOT=> - -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} + CMAKE_ARGS + -DFLATBUFFERS_BUILD_FLATC=ON + -DFLATBUFFERS_INSTALL=ON + -DFLATBUFFERS_BUILD_FLATHASH=OFF + -DFLATBUFFERS_BUILD_FLATLIB=OFF + -DFLATBUFFERS_BUILD_TESTS=OFF + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_CXX_FLAGS="-DFLATBUFFERS_MAX_ALIGNMENT=${EXECUTORCH_FLATBUFFERS_MAX_ALIGNMENT}" + # Unset the toolchain to build for the host instead of the toolchain set for + # the project. + -DCMAKE_TOOLCHAIN_FILE= + # If building for iOS, "unset" these variables to rely on the host (macOS) + # defaults. + $<$,$>>:-DCMAKE_OSX_SYSROOT=> + -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} BUILD_BYPRODUCTS /bin/flatc - ${_executorch_external_project_additional_args} + ${_executorch_external_project_additional_args} ) ExternalProject_Get_Property(flatbuffers_ep INSTALL_DIR) add_executable(flatc IMPORTED GLOBAL) add_dependencies(flatc flatbuffers_ep) if(WIN32 AND NOT CMAKE_CROSSCOMPILING) - # flatbuffers does not use CMAKE_BUILD_TYPE. Internally, the build forces Release - # config, but from CMake's perspective the build type is always Debug. - set_target_properties(flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc.exe) + # flatbuffers does not use CMAKE_BUILD_TYPE. Internally, the build forces + # Release config, but from CMake's perspective the build type is always Debug. + set_target_properties( + flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc.exe + ) else() - set_target_properties(flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc) + set_target_properties( + flatc PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatc + ) endif() # TODO: re-enable once flatbuffers is added as a subdirectory. -# set(FLATBUFFERS_BUILD_FLATC OFF) -# set(FLATBUFFERS_INSTALL OFF) -# set(FLATBUFFERS_BUILD_FLATHASH OFF) -# set(FLATBUFFERS_BUILD_FLATLIB OFF) +# set(FLATBUFFERS_BUILD_FLATC OFF) set(FLATBUFFERS_INSTALL OFF) +# set(FLATBUFFERS_BUILD_FLATHASH OFF) set(FLATBUFFERS_BUILD_FLATLIB OFF) # set(FLATBUFFERS_BUILD_TESTS OFF) # MARK: - flatcc if(WIN32) # For some reason, when configuring the external project during build - # CMAKE_C_SIMULATE_ID is set to MSVC, but CMAKE_CXX_SIMULATE_ID is not set. - # To make sure the external project is configured correctly, set it explicitly + # CMAKE_C_SIMULATE_ID is set to MSVC, but CMAKE_CXX_SIMULATE_ID is not set. To + # make sure the external project is configured correctly, set it explicitly # here. set(_flatcc_extra_cmake_args -DCMAKE_CXX_SIMULATE_ID=MSVC) else() set(_flatcc_extra_cmake_args) endif() -# Similar to flatbuffers, we want to build flatcc for the host. See inline comments -# in the flatbuffers ExternalProject_Add for more details. +# Similar to flatbuffers, we want to build flatcc for the host. See inline +# comments in the flatbuffers ExternalProject_Add for more details. ExternalProject_Add( flatcc_ep PREFIX ${CMAKE_CURRENT_BINARY_DIR}/flatcc_ep SOURCE_DIR ${PROJECT_SOURCE_DIR}/third-party/flatcc BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/flatcc_ep/src/build - CMAKE_ARGS -DFLATCC_RTONLY=OFF - -DFLATCC_TEST=OFF - -DFLATCC_REFLECTION=OFF - -DFLATCC_DEBUG_CLANG_SANITIZE=OFF - -DFLATCC_INSTALL=ON - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 - -DCMAKE_INSTALL_PREFIX:PATH= - -DCMAKE_POSITION_INDEPENDENT_CODE=ON - -DCMAKE_TOOLCHAIN_FILE= - $<$,$>>:-DCMAKE_OSX_SYSROOT=> - -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} - ${_flatcc_extra_cmake_args} + CMAKE_ARGS + -DFLATCC_RTONLY=OFF + -DFLATCC_TEST=OFF + -DFLATCC_REFLECTION=OFF + -DFLATCC_DEBUG_CLANG_SANITIZE=OFF + -DFLATCC_INSTALL=ON + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_TOOLCHAIN_FILE= + $<$,$>>:-DCMAKE_OSX_SYSROOT=> + -DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=${CMAKE_OSX_DEPLOYMENT_TARGET} + ${_flatcc_extra_cmake_args} BUILD_BYPRODUCTS /bin/flatcc - {_executorch_external_project_additional_args} + ${_executorch_external_project_additional_args} ) file(REMOVE_RECURSE ${PROJECT_SOURCE_DIR}/third-party/flatcc/lib) ExternalProject_Get_Property(flatcc_ep INSTALL_DIR) add_executable(flatcc_cli IMPORTED GLOBAL) add_dependencies(flatcc_cli flatcc_ep) if(WIN32 AND NOT CMAKE_CROSSCOMPILING) - set_target_properties(flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc.exe) + set_target_properties( + flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc.exe + ) else() - set_target_properties(flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc) + set_target_properties( + flatcc_cli PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/bin/flatcc + ) endif() -set(FLATCC_RTONLY ON CACHE BOOL "") -set(FLATCC_TEST OFF CACHE BOOL "") -set(FLATCC_REFLECTION OFF CACHE BOOL "") -set(FLATCC_DEBUG_CLANG_SANITIZE OFF CACHE BOOL "") -set(FLATCC_INSTALL OFF CACHE BOOL "") +set(FLATCC_RTONLY + ON + CACHE BOOL "" +) +set(FLATCC_TEST + OFF + CACHE BOOL "" +) +set(FLATCC_REFLECTION + OFF + CACHE BOOL "" +) +set(FLATCC_DEBUG_CLANG_SANITIZE + OFF + CACHE BOOL "" +) +set(FLATCC_INSTALL + OFF + CACHE BOOL "" +) add_subdirectory(flatcc) # Unfortunately flatcc writes libs directly in to the source tree [1]. So to # ensure the target lib is created last, force flatcc_cli to build first. # -# [1] https://github.com/dvidelabs/flatcc/blob/896db54787e8b730a6be482c69324751f3f5f117/CMakeLists.txt#L168 +# [1] +# https://github.com/dvidelabs/flatcc/blob/896db54787e8b730a6be482c69324751f3f5f117/CMakeLists.txt#L168 add_dependencies(flatccrt flatcc_cli) # Fix for "relocation R_X86_64_32 against `.rodata' can not be used when making # a shared object; recompile with -fPIC" when building on some x86 linux @@ -129,7 +158,4 @@ add_dependencies(flatccrt flatcc_cli) # # Learn more: https://github.com/pytorch/executorch/pull/2467 set_property(TARGET flatccrt PROPERTY POSITION_INDEPENDENT_CODE ON) -install( - TARGETS flatccrt - DESTINATION ${CMAKE_BINARY_DIR}/lib -) +install(TARGETS flatccrt DESTINATION ${CMAKE_BINARY_DIR}/lib) From 3ec63f41e6fed20192e853e1aa35262d6dfdbf0b Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 23 Apr 2026 14:18:33 -0700 Subject: [PATCH 04/17] Ignored Module tests: provide required input tensor (#19028) All 4 tests failed because they called forward() with zero arguments on mobilenet_v2 which expects a [1,3,224,224] float input. This was a test bug, not a runtime bug. Add a dummyInput() helper that creates a Tensor.ones with the correct shape, and remove all @Ignore annotations. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../executorch/ModuleInstrumentationTest.kt | 70 ++++++++++--------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index ba91f444287..eb2b6f096a1 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -40,48 +39,49 @@ class ModuleInstrumentationTest { inputStream.close() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) + try { + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } @Test @Throws(IOException::class, URISyntaxException::class) fun testMethodMetadata() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + module.loadMethod(FORWARD_METHOD) - module.loadMethod(FORWARD_METHOD) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) + try { + val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -94,15 +94,18 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) + try { + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -135,9 +138,6 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +151,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -168,6 +168,7 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + module.destroy() } companion object { @@ -176,5 +177,8 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" + private val inputShape = longArrayOf(1, 3, 224, 224) + + private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } From 6d23e4168f27180ffcc8be421c57c4fb6011283c Mon Sep 17 00:00:00 2001 From: YIWENX14 <164585414+YIWENX14@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:36:25 -0700 Subject: [PATCH 05/17] Extract shared multifunction PTE utilities to utils.py (#19035) Differential Revision: D101887672 Pull Request resolved: https://github.com/pytorch/executorch/pull/19035 --- .../llama/run_static_llm_multifunction.py | 202 +--------------- examples/apple/coreml/llama/utils.py | 218 ++++++++++++++++++ 2 files changed, 228 insertions(+), 192 deletions(-) diff --git a/examples/apple/coreml/llama/run_static_llm_multifunction.py b/examples/apple/coreml/llama/run_static_llm_multifunction.py index 517c54435f4..98d0cb0a763 100644 --- a/examples/apple/coreml/llama/run_static_llm_multifunction.py +++ b/examples/apple/coreml/llama/run_static_llm_multifunction.py @@ -22,14 +22,16 @@ import argparse import json import time -from typing import Any, Dict, List, Tuple +from typing import List import torch -import torch.utils._pytree as pytree +from executorch.examples.apple.coreml.llama.utils import ( + create_pte_wrapper, + setup_multifunction_managers, +) from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.runner.generation import next_token -from executorch.examples.models.llama.static_attention import StaticAttentionIOManager from executorch.runtime import Runtime from pytorch_tokenizers import get_tokenizer @@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]: return [tokenizer.eos_id] -def create_pte_wrapper( - decode_method, - prefill_method, - mgr: "StaticAttentionIOManager", - prefill_seq_len: int, - prefill_mask: Dict[str, torch.Tensor], -): - """ - Create a wrapper function that adapts PTE execution to the interface - expected by StaticAttentionIOManager. - - This multifunction version selects between prefill and decode methods - based on the input sequence length. Both methods use the SAME cache_len, - so the cache buffer is shared directly without any slicing or copying. - - The wrapper: - - Takes (tokens, options_dict) like the eager model - - Selects prefill or decode method based on token count - - Uses the same cache buffer for both methods (no slicing needed) - - Flattens inputs using pytree - - Executes the appropriate PTE method - - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) - - Args: - decode_method: The PTE method for decode (seqlen=1) - prefill_method: The PTE method for prefill (seqlen=input_len) - mgr: StaticAttentionIOManager with caches sized for shared cache_len - prefill_seq_len: The sequence length for prefill - prefill_mask: Pre-computed mask tensor for prefill method - """ - - k_cache_keys = list(mgr.k_caches.keys()) - v_cache_keys = list(mgr.v_caches.keys()) - - timing_stats = { - "flatten_time": 0.0, - "execute_time": 0.0, - "reconstruct_time": 0.0, - "detection_time": 0.0, - "options_build_time": 0.0, - "call_count": 0, - } - - def wrapper( - tokens: torch.Tensor, options: Dict[str, Any] - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - import time as time_module - - timing_stats["call_count"] += 1 - - t0 = time_module.perf_counter() - - # Detect actual sequence length. - # StaticAttentionIOManager._run_once pads tokens with zeros on the right. - # For decode (1 actual token), positions 1+ are all zeros. - padded_seq_len = tokens.shape[1] - if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): - actual_seq_len = 1 - else: - actual_seq_len = padded_seq_len - - is_prefill = actual_seq_len == prefill_seq_len - - t1 = time_module.perf_counter() - timing_stats["detection_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - # Get the input cache state from options - in_k_caches, in_v_caches = options["in_cache_state"] - - # Both prefill and decode use the same cache_len, so no slicing needed! - # Just select the appropriate method and mask. - if is_prefill: - method = prefill_method - adapted_mask = prefill_mask - else: - method = decode_method - adapted_mask = mgr.masks - - adapted_options = { - "masks": adapted_mask, - "freqs_cos_override": options["freqs_cos_override"], - "freqs_sin_override": options["freqs_sin_override"], - "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! - } - - if "last_valid_token_pos" in options: - adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] - - inputs = (tokens, adapted_options) - - t1 = time_module.perf_counter() - timing_stats["options_build_time"] += t1 - t0 - - t0 = time_module.perf_counter() - flat_inputs, _ = pytree.tree_flatten(inputs) - t1 = time_module.perf_counter() - timing_stats["flatten_time"] += t1 - t0 - - t0 = time_module.perf_counter() - outputs = method.execute(flat_inputs) - t1 = time_module.perf_counter() - timing_stats["execute_time"] += t1 - t0 - - t0 = time_module.perf_counter() - - logits = outputs[0] - - num_layers = len(k_cache_keys) - k_updates = outputs[1 : 1 + num_layers] - v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] - - k_cache_dict = dict(zip(k_cache_keys, k_updates)) - v_cache_dict = dict(zip(v_cache_keys, v_updates)) - - attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} - - t1 = time_module.perf_counter() - timing_stats["reconstruct_time"] += t1 - t0 - - return logits, attn_updates - - def print_timing_stats(): - n = timing_stats["call_count"] - if n > 0: - print(f"\n=== Wrapper Timing Stats ({n} calls) ===") - print( - f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" - ) - print( - f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" - ) - print( - f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" - ) - print( - f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" - ) - print( - f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" - ) - total = ( - timing_stats["detection_time"] - + timing_stats["options_build_time"] - + timing_stats["flatten_time"] - + timing_stats["execute_time"] - + timing_stats["reconstruct_time"] - ) - print( - f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" - ) - print( - f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" - ) - expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) - print(f" Expected tok/s from execute alone: {expected_tps:.1f}") - - wrapper.print_timing_stats = print_timing_stats - wrapper.timing_stats = timing_stats - - return wrapper - - def main(): parser = argparse.ArgumentParser( description="Run multifunction static attention Llama model" @@ -326,36 +164,16 @@ def main(): print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}") print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}") - # Create decode manager (input_len=1) - used for decode phase - mgr = StaticAttentionIOManager( - model_args, - input_len=decode_input_len, - cache_lens=shared_cache_len, - batch_size=1, - dtype=torch.float16, - style="smart_mask", - mask_val=float("-inf"), - ) - - # Create prefill manager (input_len=64) with the SAME cache_len. - # Since both use the same cache_len, we can share the cache buffer directly. - prefill_mgr = StaticAttentionIOManager( + # Create managers with shared cache buffers + mgr, prefill_mgr, prefill_mask = setup_multifunction_managers( model_args, - input_len=prefill_input_len, - cache_lens=shared_cache_len, # Same cache_len as decode! - batch_size=1, + prefill_input_len, + decode_input_len, + shared_cache_len, dtype=torch.float16, - style="smart_mask", mask_val=float("-inf"), ) - # Share cache buffers: point prefill_mgr's caches to mgr's caches. - # No copying needed since both managers use the same cache_len! - prefill_mgr.k_caches = mgr.k_caches - prefill_mgr.v_caches = mgr.v_caches - - prefill_mask = prefill_mgr.masks - # Load PTE model with multifunction support print(f"Loading multifunction model from {args.model}...") runtime = Runtime.get() diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..755a654b9df 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -4,7 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import time +from typing import Any, Dict, Tuple, TYPE_CHECKING + import torch +import torch.utils._pytree as pytree + +if TYPE_CHECKING: + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) class SplitLinearModule(torch.nn.Module): @@ -114,3 +123,212 @@ def replace_linear_with_split_linear( in_target_split_size, in_max_splits, ) + + +def setup_multifunction_managers( + config, + prefill_input_len: int, + decode_input_len: int, + shared_cache_len: int, + dtype: torch.dtype = torch.float16, + mask_val: float = float("-inf"), + style: str = "smart_mask", +): + """ + Create prefill and decode StaticAttentionIOManager instances with shared cache buffers. + + Both managers use the same cache_len so they share cache memory directly. + Returns (decode_mgr, prefill_mgr, prefill_mask). + """ + from executorch.examples.models.llama.static_attention import ( + StaticAttentionIOManager, + ) + + mgr = StaticAttentionIOManager( + config, + input_len=decode_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + prefill_mgr = StaticAttentionIOManager( + config, + input_len=prefill_input_len, + cache_lens=shared_cache_len, + batch_size=1, + dtype=dtype, + style=style, + mask_val=mask_val, + ) + + # Share cache buffers — no copying needed + prefill_mgr.k_caches = mgr.k_caches + prefill_mgr.v_caches = mgr.v_caches + prefill_mask = prefill_mgr.masks + + return mgr, prefill_mgr, prefill_mask + + +def create_pte_wrapper( + decode_method, + prefill_method, + mgr: "StaticAttentionIOManager", + prefill_seq_len: int, + prefill_mask: Dict[str, torch.Tensor], +): + """ + Create a wrapper function that adapts PTE execution to the interface + expected by StaticAttentionIOManager. + + This multifunction version selects between prefill and decode methods + based on the input sequence length. Both methods use the SAME cache_len, + so the cache buffer is shared directly without any slicing or copying. + + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Selects prefill or decode method based on token count + - Uses the same cache buffer for both methods (no slicing needed) + - Flattens inputs using pytree + - Executes the appropriate PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + + Args: + decode_method: The PTE method for decode (seqlen=1) + prefill_method: The PTE method for prefill (seqlen=input_len) + mgr: StaticAttentionIOManager with caches sized for shared cache_len + prefill_seq_len: The sequence length for prefill + prefill_mask: Pre-computed mask tensor for prefill method + """ + + k_cache_keys = list(mgr.k_caches.keys()) + v_cache_keys = list(mgr.v_caches.keys()) + + timing_stats = { + "flatten_time": 0.0, + "execute_time": 0.0, + "reconstruct_time": 0.0, + "detection_time": 0.0, + "options_build_time": 0.0, + "call_count": 0, + } + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + timing_stats["call_count"] += 1 + + t0 = time.perf_counter() + + # Detect actual sequence length. + # StaticAttentionIOManager._run_once pads tokens with zeros on the right. + # For decode (1 actual token), positions 1+ are all zeros. + padded_seq_len = tokens.shape[1] + if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): + actual_seq_len = 1 + else: + actual_seq_len = padded_seq_len + + is_prefill = actual_seq_len == prefill_seq_len + + t1 = time.perf_counter() + timing_stats["detection_time"] += t1 - t0 + + t0 = time.perf_counter() + + # Get the input cache state from options + in_k_caches, in_v_caches = options["in_cache_state"] + + # Both prefill and decode use the same cache_len, so no slicing needed! + # Just select the appropriate method and mask. + if is_prefill: + method = prefill_method + adapted_mask = prefill_mask + else: + method = decode_method + adapted_mask = mgr.masks + + adapted_options = { + "masks": adapted_mask, + "freqs_cos_override": options["freqs_cos_override"], + "freqs_sin_override": options["freqs_sin_override"], + "in_cache_state": (in_k_caches, in_v_caches), # Same cache for both! + } + + if "last_valid_token_pos" in options: + adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] + + inputs = (tokens, adapted_options) + + t1 = time.perf_counter() + timing_stats["options_build_time"] += t1 - t0 + + t0 = time.perf_counter() + flat_inputs, _ = pytree.tree_flatten(inputs) + t1 = time.perf_counter() + timing_stats["flatten_time"] += t1 - t0 + + t0 = time.perf_counter() + outputs = method.execute(flat_inputs) + t1 = time.perf_counter() + timing_stats["execute_time"] += t1 - t0 + + t0 = time.perf_counter() + + logits = outputs[0] + + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + t1 = time.perf_counter() + timing_stats["reconstruct_time"] += t1 - t0 + + return logits, attn_updates + + def print_timing_stats(): + n = timing_stats["call_count"] + if n > 0: + print(f"\n=== Wrapper Timing Stats ({n} calls) ===") + print( + f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" + ) + print( + f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" + ) + print( + f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" + ) + print( + f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" + ) + print( + f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" + ) + total = ( + timing_stats["detection_time"] + + timing_stats["options_build_time"] + + timing_stats["flatten_time"] + + timing_stats["execute_time"] + + timing_stats["reconstruct_time"] + ) + print( + f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" + ) + print( + f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" + ) + expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) + print(f" Expected tok/s from execute alone: {expected_tps:.1f}") + + wrapper.print_timing_stats = print_timing_stats + wrapper.timing_stats = timing_stats + + return wrapper From 7b5dcc18a3911a5e74547d0c8ebece1d6e0cb5f1 Mon Sep 17 00:00:00 2001 From: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:49:34 -0700 Subject: [PATCH 06/17] Add add-relu fusion in the quantizer Differential Revision: D102189156 Pull Request resolved: https://github.com/pytorch/executorch/pull/19077 --- backends/cadence/aot/quantizer/fusion_pass.py | 18 +++++- backends/cadence/aot/quantizer/patterns.py | 55 +++++++++++++++++++ backends/cadence/aot/quantizer/quantizer.py | 4 ++ .../cadence/aot/tests/test_quantizer_ops.py | 53 ++++++++++++++++++ 4 files changed, 129 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 6b7990c0f2c..5375367b929 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -15,6 +15,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -63,6 +65,7 @@ Conv2dReluPattern0, Conv2dReluPattern1, ) +AddReluPatterns = (AddReluPattern0, AddReluPattern1) def get_args_and_kwargs_add( @@ -616,7 +619,20 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 inputs_inputs + weights_inputs + other_inputs + bias_inputs ) kwargs = {} - if isinstance(pattern, AddPattern): + if isinstance(pattern, AddReluPatterns): + # For AddReLU, we are fusing Add+ReLU. + # The quantized_add op performs requantization, + # so the relu is implicit in the output quant params. + check_out_zero_point_is_min_range( + quant_node.args[2], quant_node.args[5] + ) + args, kwargs = get_args_and_kwargs_add( + graph_module, + inputs_inputs, + dequants_inputs, + quant_node, + ) + elif isinstance(pattern, AddPattern): args, kwargs = get_args_and_kwargs_add( graph_module, inputs_inputs, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 2ce50871fc0..07aad18e36a 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -153,6 +153,61 @@ def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor +# This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops +class AddReluBasePattern(QuantizationPattern): + @abstractmethod + def partition_types(self) -> List[OpOverload]: + pass + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + ) -> Tuple[PartitionAnchors, fx.Node]: + # The first node should be add, the second should be relu + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + add_node = fused_partition[0].nodes[-1] + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... + relu_node = fused_partition[1].nodes[-1] + + # Bail if: + # - the add node is not a tensor add + # - the add node has kwargs (e.g. alpha) + is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance( + add_node.args[1], fx.Node + ) + if not is_tensor_add or len(add_node.kwargs) > 0: + return ( + PartitionAnchors( + empty=True, + ), + add_node, + ) + + return ( + PartitionAnchors( + inputs=[(add_node, 0), (add_node, 1)], + weights=[], + biases=[], + output=[(relu_node,)], # Output is from the relu node + ), + relu_node, + ) + + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_add.per_tensor + + +# Add + regular relu op fusion +class AddReluPattern0(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu.default] + + +# Add + alternate relu op fusion +class AddReluPattern1(AddReluBasePattern): + def partition_types(self) -> List[OpOverload]: + return [torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default] + + class BmmPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.bmm.default] diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4edcd96e132..d521b9f83cf 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -13,6 +13,8 @@ from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, + AddReluPattern0, + AddReluPattern1, BmmPattern, CatPattern, Conv1dPattern, @@ -398,6 +400,8 @@ def __init__( quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym)) quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern0(), a8w8)) + quantizers.append(CadenceAtenQuantizer(AddReluPattern1(), a8w8)) quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat) quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8)) quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8)) diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index 06e2c08f4f4..dde26f06b7b 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -215,6 +215,15 @@ [qconfig_A8W8.input_activation], ), # CadenceFusedConvReluQuantizer test cases + ( + "fused_add_relu_A8W8", + lambda self: self._build_add_relu_graph(), + CadenceFusedConvReluQuantizer(), + torch.ops.aten.relu.default, + qconfig_A8W8.output_activation, + # For fused add+relu: both inputs are activations from add node + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], + ), ( "fused_conv1d_relu_A8W8sym", lambda self: self._build_conv1d_relu_graph(), @@ -508,6 +517,50 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: ) return gm, max_pool_nodes[0] + def _build_add_relu_graph( + self, + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: + """Build a graph with an add followed by relu (fused pattern). + + Returns: + A tuple of (graph_module, relu_node, add_node). + The relu_node is the target node where the annotation is placed. + The add_node is the input source node whose args contain the quantized inputs. + """ + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 10)) + y = builder.placeholder("y", torch.randn(1, 10)) + add = builder.call_operator( + op=torch.ops.aten.add.Tensor, + args=(x, y), + meta=NodeMetadata( + {"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]} + ), + ) + relu = builder.call_operator( + op=torch.ops.aten.relu.default, + args=(add,), + meta=NodeMetadata( + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} + ), + ) + builder.output([relu]) + gm = builder.get_graph_module() + + relu_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.relu.default, + ) + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") + + add_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.add.Tensor, + ) + self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") + + return gm, relu_nodes[0], add_nodes[0] + def _build_conv2d_relu_graph( self, ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: From f9f29e7f50035078f96b8cdc1331ea183cf57814 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 23 Apr 2026 15:48:46 -0700 Subject: [PATCH 07/17] Android: improve error diagnostics for LlmModule and exceptions (#19092) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add cause-chaining constructor to ExecutorchRuntimeException so wrapped exceptions preserve the original cause in the stack trace. Restore detailed native error messages in LlmModule.load() — the null runner case now reports the model_type_category and valid values instead of a generic message. Load failures now throw from JNI with the specific error code and description. This commit was authored with the help of Claude. --- .../ExecutorchRuntimeException.java | 5 +++++ extension/android/jni/jni_layer_llama.cpp | 21 +++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e0fda73cc06..e72ed9e3d28 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -161,6 +161,11 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } + public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { + super(ErrorHelper.formatMessage(errorCode, details), cause); + this.errorCode = errorCode; + } + /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 2c0117dc576..94c0efff335 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -594,21 +595,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() called but runner_ is null. " - "The model runner was not created or failed to initialize due to a " - "previous configuration or initialization error. " - "Model type category: %d.", - model_type_category_); + std::stringstream ss; + ss << "Model runner was not created. model_type_category=" + << model_type_category_ + << ". Valid values: " << MODEL_TYPE_CATEGORY_LLM << " (LLM), " + << MODEL_TYPE_CATEGORY_MULTIMODAL << " (Multimodal)"; + executorch::jni_helper::throwExecutorchException( + static_cast(Error::InvalidState), ss.str().c_str()); return static_cast(Error::InvalidState); } const auto load_result = static_cast(runner_->load()); if (load_result != static_cast(Error::Ok)) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", - static_cast(load_result)); + executorch::jni_helper::throwExecutorchException( + static_cast(load_result), "Failed to load model runner"); } return load_result; } From 4a69750f557210ba2d222ae6f79be294c28c5e17 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Thu, 23 Apr 2026 18:37:19 -0500 Subject: [PATCH 08/17] Add Half (float16) support to slim ScalarType enum (#18959) (#18959) Summary: The CUDA runtime shims for sort operations use Half (float16) dtype, but it was not defined in the slim ScalarType enum, causing compiler warnings treated as errors (-Werror=switch). This adds proper Half support to the slim ScalarType enum so switch statements can use the enum value directly instead of casting to the underlying type. Differential Revision: D101218928 --- backends/aoti/slim/c10/core/ScalarType.h | 11 ++++-- .../slim/c10/core/test/test_scalar_type.cpp | 35 +++++++++++++++++++ backends/cuda/runtime/shims/sort.cu | 8 ++--- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h index c1499a83f39..9a99aecf992 100644 --- a/backends/aoti/slim/c10/core/ScalarType.h +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -28,7 +28,7 @@ enum class ScalarType : int8_t { Short = 2, // int16_t Int = 3, // int32_t Long = 4, // int64_t - // Half = 5, // float16 - not currently needed + Half = 5, // float16 Float = 6, // float // Double = 7, // double - not currently needed // ComplexHalf = 8, @@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char; constexpr ScalarType kShort = ScalarType::Short; constexpr ScalarType kInt = ScalarType::Int; constexpr ScalarType kLong = ScalarType::Long; +constexpr ScalarType kHalf = ScalarType::Half; constexpr ScalarType kFloat = ScalarType::Float; constexpr ScalarType kBool = ScalarType::Bool; constexpr ScalarType kBFloat16 = ScalarType::BFloat16; @@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) { return sizeof(int32_t); case ScalarType::Long: return sizeof(int64_t); + case ScalarType::Half: + return 2; // sizeof(__half) = 2 bytes case ScalarType::Float: return sizeof(float); case ScalarType::Bool: @@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) { return "Int"; case ScalarType::Long: return "Long"; + case ScalarType::Half: + return "Half"; case ScalarType::Float: return "Float"; case ScalarType::Bool: @@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) { /// @param t The scalar type to check. /// @return true if the scalar type is floating point, false otherwise. inline bool isFloatingType(ScalarType t) { - return t == ScalarType::Float || t == ScalarType::BFloat16; + return t == ScalarType::Half || t == ScalarType::Float || + t == ScalarType::BFloat16; } /// Checks if the scalar type is an integral type (including bool optionally). @@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) { case ScalarType::Short: case ScalarType::Int: case ScalarType::Long: + case ScalarType::Half: case ScalarType::Float: case ScalarType::Bool: case ScalarType::BFloat16: diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp index 332f5d7d264..4c06f7ef101 100644 --- a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -36,6 +36,7 @@ const std::vector kAllScalarTypes = { {ScalarType::Short, 2, 2, "Short", false, true, true, false}, {ScalarType::Int, 3, 4, "Int", false, true, true, false}, {ScalarType::Long, 4, 8, "Long", false, true, true, false}, + {ScalarType::Half, 5, 2, "Half", true, false, false, false}, {ScalarType::Float, 6, 4, "Float", true, false, false, false}, {ScalarType::Bool, 11, 1, "Bool", false, false, true, true}, {ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false}, @@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) { EXPECT_EQ(kLong, ScalarType::Long); } +TEST_F(ScalarTypeConstantsTest, KHalfConstant) { + EXPECT_EQ(kHalf, ScalarType::Half); +} + TEST_F(ScalarTypeConstantsTest, KFloatConstant) { EXPECT_EQ(kFloat, ScalarType::Float); } @@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) { EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t)); } +TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) { + EXPECT_EQ(elementSize(ScalarType::Half), 2); +} + TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) { EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); } @@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) { TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) { EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16)); } + +// ============================================================================= +// isValidScalarType Tests +// ============================================================================= + +class IsValidScalarTypeTest : public ::testing::Test {}; + +TEST_F(IsValidScalarTypeTest, HalfIsValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); +} + +TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) { + EXPECT_TRUE(isValidScalarType(ScalarType::Byte)); + EXPECT_TRUE(isValidScalarType(ScalarType::Char)); + EXPECT_TRUE(isValidScalarType(ScalarType::Short)); + EXPECT_TRUE(isValidScalarType(ScalarType::Int)); + EXPECT_TRUE(isValidScalarType(ScalarType::Long)); + EXPECT_TRUE(isValidScalarType(ScalarType::Half)); + EXPECT_TRUE(isValidScalarType(ScalarType::Float)); + EXPECT_TRUE(isValidScalarType(ScalarType::Bool)); + EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16)); +} + +TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) { + EXPECT_FALSE(isValidScalarType(ScalarType::Undefined)); +} diff --git a/backends/cuda/runtime/shims/sort.cu b/backends/cuda/runtime/shims/sort.cu index 804b5a55959..8d4a9771e62 100644 --- a/backends/cuda/runtime/shims/sort.cu +++ b/backends/cuda/runtime/shims/sort.cu @@ -24,8 +24,8 @@ namespace executorch::backends::cuda { namespace c10_slim = executorch::backends::aoti::slim::c10; -// PyTorch ScalarType::Half = 5, not defined in slim ScalarType enum. -constexpr auto kHalf = static_cast(5); +// PyTorch ScalarType::Half = 5, now defined in slim ScalarType enum. +using c10_slim::kHalf; namespace { @@ -188,7 +188,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( case c10_slim::ScalarType::BFloat16: elem_size = sizeof(__nv_bfloat16); break; - case kHalf: + case c10_slim::ScalarType::Half: elem_size = sizeof(__half); break; default: @@ -387,7 +387,7 @@ AOTITorchError aoti_torch_cuda_sort_stable( stream); break; } - case kHalf: { + case c10_slim::ScalarType::Half: { sort_slice_impl( static_cast<__half*>(values_base) + offset, idx_ptr, From edb8c98af029138bfc4c6ed0c74ae4aedbda26f5 Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 23 Apr 2026 17:12:22 -0700 Subject: [PATCH 09/17] Validate XNNPACK tensor flags are valid (#19102) 1. Attacker sets that flag on an external tensor. 2. xnnpack thinks the tensor is owned by itself, and frees it inside the backend. 3. et runtime also frees it at method destruction. Test Plan: Build and run executor runner against problematic PTE file: ``` # Build executor runner: cmake -B cmake-out \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON cmake --build cmake-out -j16 --target executor_runner # Output (executorch) [lfq@devvm11764.nha0 /data/users/lfq/security/executorch (f9f29e7f)]$ ./cmake-out/executor_runner --model_path=/data/users/lfq/security/executorch_repros/TOB-EXECUTORCH-44.pte ``` Previous ``` (executorch) [lfq@devvm11764.nha0 /data/users/lfq/security/executorch (security44)]$ ./cmake-out/executor_runner --model_path=/data/users/lfq/security/executorch_repros/TOB-EXECUTORCH-44.pte Note (XNNPACK): l1_data_cache_bytes=32768, l1_data_cache_line_size=64, l1_data_cache_associativity=8, l1_data_cache_num_sets=64. (init_hardware_config, /data/users/lfq/security/executorch/backends/xnnpack/third-party/XNNPACK/src/configs/hardware-config.c:417) Note (XNNPACK): l2_data_cache_bytes=1048576, l2_data_cache_line_size=64, l2_data_cache_associativity=8, l2_data_cache_num_sets=2048. (init_hardware_config, /data/users/lfq/security/executorch/backends/xnnpack/third-party/XNNPACK/src/configs/hardware-config.c:436) I 00:00:00.002612 executorch:cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version I 00:00:00.002640 executorch:cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.002657 executorch:cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.002664 executorch:cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.002671 executorch:cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match. I 00:00:00.002672 executorch:executor_runner.cpp:223] Resetting threadpool with num threads = 0 I 00:00:00.002722 executorch:executor_runner.cpp:374] Model file /data/users/lfq/security/executorch_repros/TOB-EXECUTORCH-44.pte is loaded. I 00:00:00.002729 executorch:executor_runner.cpp:384] Using method forward I 00:00:00.002739 executorch:executor_runner.cpp:435] Setting up planned buffer 0, size 112. E 00:00:00.002806 executorch:XNNCompiler.cpp:331] Tensor value has unsupported flag bits 0xffffff00 E 00:00:00.002824 executorch:XNNPACKBackend.cpp:122] XNNCompiler::compileModel failed: 0x23 E 00:00:00.002827 executorch:method.cpp:127] Init failed for backend XnnpackBackend: 0x23 F 00:00:00.002830 executorch:executor_runner.cpp:459] In function main(), assert failed (method.ok()): Loading of method forward failed with status 0x23 Aborted (core dumped) ``` After, graceful error ``` (executorch) [lfq@devvm11764.nha0 /data/users/lfq/security/executorch (security44)]$ ./cmake-out/executor_runner --model_path=/data/users/lfq/security/executorch_repros/TOB-EXECUTORCH-44.pte Note (XNNPACK): l1_data_cache_bytes=32768, l1_data_cache_line_size=64, l1_data_cache_associativity=8, l1_data_cache_num_sets=64. (init_hardware_config, /data/users/lfq/security/executorch/backends/xnnpack/third-party/XNNPACK/src/configs/hardware-config.c:417) Note (XNNPACK): l2_data_cache_bytes=1048576, l2_data_cache_line_size=64, l2_data_cache_associativity=8, l2_data_cache_num_sets=2048. (init_hardware_config, /data/users/lfq/security/executorch/backends/xnnpack/third-party/XNNPACK/src/configs/hardware-config.c:436) I 00:00:00.002562 executorch:cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version I 00:00:00.002595 executorch:cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version I 00:00:00.002607 executorch:cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.002618 executorch:cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1 I 00:00:00.002623 executorch:cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match. I 00:00:00.002628 executorch:executor_runner.cpp:223] Resetting threadpool with num threads = 0 I 00:00:00.002672 executorch:executor_runner.cpp:374] Model file /data/users/lfq/security/executorch_repros/TOB-EXECUTORCH-44.pte is loaded. I 00:00:00.002678 executorch:executor_runner.cpp:384] Using method forward I 00:00:00.002688 executorch:executor_runner.cpp:435] Setting up planned buffer 0, size 112. E 00:00:00.002750 executorch:XNNCompiler.cpp:331] Tensor value has unsupported flag bits 0xffffff00 E 00:00:00.002761 executorch:XNNPACKBackend.cpp:122] XNNCompiler::compileModel failed: 0x23 E 00:00:00.002769 executorch:method.cpp:127] Init failed for backend XnnpackBackend: 0x23 F 00:00:00.002772 executorch:executor_runner.cpp:459] In function main(), assert failed (method.ok()): Loading of method forward failed with status 0x23 ``` Co-authored-by: Github Executorch Co-authored-by: Claude --- backends/xnnpack/runtime/XNNCompiler.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 352d7af5a14..103bdeb6b82 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -319,6 +319,15 @@ Error defineTensor( ET_CHECK_OR_RETURN_ERROR( tensor_value != nullptr, InvalidProgram, "Deserialized tensor is null"); + // Validate that tensor_value->flags() is a subset of the allowed flags. + constexpr uint32_t kAllowedFlagsMask = + XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT; + ET_CHECK_OR_RETURN_ERROR( + (tensor_value->flags() & ~kAllowedFlagsMask) == 0, + InvalidProgram, + "Tensor value has unsupported flag bits 0x%x", + tensor_value->flags()); + // Get tensor dims, here we need to use a vector in order to properly // convert the uint32_t* to size_t*. Scalar tensors (rank 0) are permitted // to have a null dims vector; in that case dims_data is empty. From 75b31bbcf9eb95781b1c50aa16f47b96bd2240b0 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 23 Apr 2026 17:20:14 -0700 Subject: [PATCH 10/17] Fix smollm2 alias to point at SmolLM2-135M (v2) instead of SmolLM-135M (v1) (#18859) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The original SmolLM2 PR (#9354) started as v1 support, was renamed to `smollm2` during review, but the repo ID and `rope_theta` were never updated to v2 values. The two checkpoints are genuinely different models (0/272 tensors match). - `HUGGING_FACE_REPO_IDS["smollm2"]`: `HuggingFaceTB/SmolLM-135M` → `HuggingFaceTB/SmolLM2-135M` - `examples/models/smollm2/135M_config.json`: `rope_theta` `10000.0` → `100000.0` (matches [SmolLM2-135M HF config](https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json)) ### Test plan Data-only change (one string, one number). Verified values match the upstream HuggingFace SmolLM2-135M config. --- examples/models/llama/export_llama_lib.py | 2 +- examples/models/smollm2/135M_config.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0f38191d807..9cf1b4b4bf0 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -123,7 +123,7 @@ "qwen2_5_1_5b": "Qwen/Qwen2.5-1.5B", "qwen2_5_coder_32b": "Qwen/Qwen2.5-Coder-32B-Instruct", "phi_4_mini": "microsoft/Phi-4-mini-instruct", - "smollm2": "HuggingFaceTB/SmolLM-135M", + "smollm2": "HuggingFaceTB/SmolLM2-135M", "qwen3_0_6b": "Qwen/Qwen3-0.6B", "qwen3_1_7b": "Qwen/Qwen3-1.7B", "qwen3_4b": "Qwen/Qwen3-4B", diff --git a/examples/models/smollm2/135M_config.json b/examples/models/smollm2/135M_config.json index 604c7e94ab5..1e3bc8ee0cb 100644 --- a/examples/models/smollm2/135M_config.json +++ b/examples/models/smollm2/135M_config.json @@ -6,7 +6,7 @@ "n_kv_heads": 3, "n_layers": 30, "norm_eps": 1e-05, - "rope_theta": 10000.0, + "rope_theta": 100000.0, "use_scaled_rope": false, "vocab_size": 49152, "use_hf_rope": false, From c3f3d127671f420ccbe560b695ad16c8a8117af9 Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 23 Apr 2026 17:23:09 -0700 Subject: [PATCH 11/17] Add tryTo evalue accessors (#19039) Add tryTo accessors for each value. Previously, `toTensor` etc. abort with ET_CHECK_MSG when the type mismatches. API additions: - Per-type: tryToInt, tryToDouble, tryToBool, tryToScalar, tryToString, tryToTensor (already present, kept), tryToIntList, tryToBoolList, tryToDoubleList, tryToTensorList, tryToListOptionalTensor, tryToScalarType, tryToMemoryFormat, tryToLayout, tryToDevice. Tag mismatch returns Error::InvalidType; null list/string payload returns Error::InvalidState. - Templated tryTo() dispatcher mirroring to(), via a new EVALUE_DEFINE_TRY_TO macro kept adjacent to EVALUE_DEFINE_TO so drift between the two surfaces is visible at review time. - tryToOptional() widened from Tensor-only to generic, delegating to tryTo() so it works for any supported payload type. Tests cover success + mismatch paths for each new accessor, plus the widened tryToOptional() path. Authored-with: Claude --------- Co-authored-by: Github Executorch --- runtime/core/evalue.cpp | 25 +++ runtime/core/evalue.h | 252 ++++++++++++++++++++++++++++++ runtime/core/test/evalue_test.cpp | 167 ++++++++++++++++++++ 3 files changed, 444 insertions(+) diff --git a/runtime/core/evalue.cpp b/runtime/core/evalue.cpp index 121a9a29fa2..6fd118dadd0 100644 --- a/runtime/core/evalue.cpp +++ b/runtime/core/evalue.cpp @@ -10,6 +10,10 @@ namespace executorch { namespace runtime { + +// Specialize for list of optional tensors, as nullptr is a valid std::nullopt. +// For non-optional types, nullptr is invalid. + template <> executorch::aten::ArrayRef> BoxedEvalueList>::get() const { @@ -27,5 +31,26 @@ BoxedEvalueList>::get() const { return executorch::aten::ArrayRef>{ unwrapped_vals_, wrapped_vals_.size()}; } + +template <> +Result>> +BoxedEvalueList>::tryGet() const { + for (typename executorch::aten::ArrayRef< + std::optional>::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + unwrapped_vals_[i] = std::nullopt; + continue; + } + auto r = wrapped_vals_[i]->tryToOptional(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef>{ + unwrapped_vals_, wrapped_vals_.size()}; +} } // namespace runtime } // namespace executorch diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h index 8d75b1ace97..eed52bb74f7 100644 --- a/runtime/core/evalue.h +++ b/runtime/core/evalue.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -71,6 +72,16 @@ class BoxedEvalueList { */ executorch::aten::ArrayRef get() const; + /** + * Result-returning counterpart of get(). Validates each wrapped EValue's + * tag before materializing; returns Error::InvalidType if any element's + * tag does not match T and Error::InvalidState if any element pointer is + * null. Use this when materializing lists from untrusted .pte data so that + * a malformed program cannot force a process abort inside to() / + * ET_CHECK. + */ + Result> tryGet() const; + /** * Destroys the unwrapped elements without re-dereferencing wrapped_vals_. * This is safe to call during EValue destruction because it does not @@ -107,6 +118,10 @@ template <> executorch::aten::ArrayRef> BoxedEvalueList>::get() const; +template <> +Result>> +BoxedEvalueList>::tryGet() const; + // Aggregate typing system similar to IValue only slimmed down with less // functionality, no dependencies on atomic, and fewer supported types to better // suit embedded systems (ie no intrusive ptr) @@ -193,6 +208,13 @@ struct EValue { return payload.copyable_union.as_int; } + Result tryToInt() const { + if (!isInt()) { + return Error::InvalidType; + } + return payload.copyable_union.as_int; + } + /****** Double Type ******/ /*implicit*/ EValue(double d) : tag(Tag::Double) { payload.copyable_union.as_double = d; @@ -207,6 +229,13 @@ struct EValue { return payload.copyable_union.as_double; } + Result tryToDouble() const { + if (!isDouble()) { + return Error::InvalidType; + } + return payload.copyable_union.as_double; + } + /****** Bool Type ******/ /*implicit*/ EValue(bool b) : tag(Tag::Bool) { payload.copyable_union.as_bool = b; @@ -221,6 +250,13 @@ struct EValue { return payload.copyable_union.as_bool; } + Result tryToBool() const { + if (!isBool()) { + return Error::InvalidType; + } + return payload.copyable_union.as_bool; + } + /****** Scalar Type ******/ /// Construct an EValue using the implicit value of a Scalar. /*implicit*/ EValue(executorch::aten::Scalar s) { @@ -256,6 +292,19 @@ struct EValue { } } + Result tryToScalar() const { + if (isDouble()) { + return executorch::aten::Scalar(payload.copyable_union.as_double); + } + if (isInt()) { + return executorch::aten::Scalar(payload.copyable_union.as_int); + } + if (isBool()) { + return executorch::aten::Scalar(payload.copyable_union.as_bool); + } + return Error::InvalidType; + } + /****** Tensor Type ******/ /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { // When built in aten mode, at::Tensor has a non trivial constructor @@ -305,6 +354,16 @@ struct EValue { return payload.as_tensor; } + // Returns a copy of the Tensor handle (one intrusive_ptr refcount bump in + // ATen mode; free in lean mode). Unlike toTensor()'s const& / & overloads, + // tryToTensor() cannot return a reference — Result wraps by value. + Result tryToTensor() const { + if (!isTensor()) { + return Error::InvalidType; + } + return payload.as_tensor; + } + /****** String Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* s) : tag(Tag::String) { ET_CHECK_MSG(s != nullptr, "ArrayRef pointer cannot be null"); @@ -325,6 +384,18 @@ struct EValue { payload.copyable_union.as_string_ptr->size()); } + Result tryToString() const { + if (!isString()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_string_ptr == nullptr) { + return Error::InvalidState; + } + return std::string_view( + payload.copyable_union.as_string_ptr->data(), + payload.copyable_union.as_string_ptr->size()); + } + /****** Int List Type ******/ /*implicit*/ EValue(BoxedEvalueList* i) : tag(Tag::ListInt) { ET_CHECK_MSG( @@ -344,6 +415,16 @@ struct EValue { return (payload.copyable_union.as_int_list_ptr)->get(); } + Result> tryToIntList() const { + if (!isIntList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_int_list_ptr == nullptr) { + return Error::InvalidState; + } + return (payload.copyable_union.as_int_list_ptr)->tryGet(); + } + /****** Bool List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* b) : tag(Tag::ListBool) { @@ -363,6 +444,16 @@ struct EValue { return *(payload.copyable_union.as_bool_list_ptr); } + Result> tryToBoolList() const { + if (!isBoolList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_bool_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_bool_list_ptr); + } + /****** Double List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* d) : tag(Tag::ListDouble) { @@ -382,6 +473,16 @@ struct EValue { return *(payload.copyable_union.as_double_list_ptr); } + Result> tryToDoubleList() const { + if (!isDoubleList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_double_list_ptr == nullptr) { + return Error::InvalidState; + } + return *(payload.copyable_union.as_double_list_ptr); + } + /****** Tensor List Type ******/ /*implicit*/ EValue(BoxedEvalueList* t) : tag(Tag::ListTensor) { @@ -402,6 +503,17 @@ struct EValue { return payload.copyable_union.as_tensor_list_ptr->get(); } + Result> tryToTensorList() + const { + if (!isTensorList()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_tensor_list_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_tensor_list_ptr->tryGet(); + } + /****** List Optional Tensor Type ******/ /*implicit*/ EValue( BoxedEvalueList>* t) @@ -426,6 +538,17 @@ struct EValue { return payload.copyable_union.as_list_optional_tensor_ptr->get(); } + Result>> + tryToListOptionalTensor() const { + if (!isListOptionalTensor()) { + return Error::InvalidType; + } + if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) { + return Error::InvalidState; + } + return payload.copyable_union.as_list_optional_tensor_ptr->tryGet(); + } + /****** ScalarType Type ******/ executorch::aten::ScalarType toScalarType() const { ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); @@ -433,6 +556,14 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToScalarType() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** MemoryFormat Type ******/ executorch::aten::MemoryFormat toMemoryFormat() const { ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); @@ -440,12 +571,27 @@ struct EValue { payload.copyable_union.as_int); } + Result tryToMemoryFormat() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast( + payload.copyable_union.as_int); + } + /****** Layout Type ******/ executorch::aten::Layout toLayout() const { ET_CHECK_MSG(isInt(), "EValue is not a Layout."); return static_cast(payload.copyable_union.as_int); } + Result tryToLayout() const { + if (!isInt()) { + return Error::InvalidType; + } + return static_cast(payload.copyable_union.as_int); + } + /****** Device Type ******/ executorch::aten::Device toDevice() const { ET_CHECK_MSG(isInt(), "EValue is not a Device."); @@ -455,6 +601,16 @@ struct EValue { -1); } + Result tryToDevice() const { + if (!isInt()) { + return Error::InvalidType; + } + return executorch::aten::Device( + static_cast( + payload.copyable_union.as_int), + -1); + } + template T to() &&; template @@ -462,6 +618,15 @@ struct EValue { template typename internal::evalue_to_ref_overload_return::type to() &; + /** + * Result-returning equivalent of `to()`. Tag mismatch returns + * `Error::InvalidType`; a null list/string payload returns + * `Error::InvalidState`. Specializations are defined below via + * `EVALUE_DEFINE_TRY_TO`. + */ + template + Result tryTo() const; + /** * Converts the EValue to an optional object that can represent both T and * an uninitialized state. @@ -474,6 +639,23 @@ struct EValue { return this->to(); } + /** + * Result-returning equivalent of `toOptional()`. None maps to an empty + * optional; any other tag that doesn't match T propagates `tryTo()`'s + * error (`Error::InvalidType`). + */ + template + inline Result> tryToOptional() const { + if (this->isNone()) { + return std::optional(std::nullopt); + } + auto r = this->tryTo(); + if (!r.ok()) { + return r.error(); + } + return std::optional(std::move(r.get())); + } + private: // Pre cond: the payload value has had its destructor called void clearToNone() noexcept { @@ -591,6 +773,59 @@ EVALUE_DEFINE_TO( toListOptionalTensor) #undef EVALUE_DEFINE_TO +#define EVALUE_DEFINE_TRY_TO(T, method_name) \ + template <> \ + inline Result EValue::tryTo() const { \ + return this->method_name(); \ + } + +EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar) +EVALUE_DEFINE_TRY_TO(int64_t, tryToInt) +EVALUE_DEFINE_TRY_TO(bool, tryToBool) +EVALUE_DEFINE_TRY_TO(double, tryToDouble) +EVALUE_DEFINE_TRY_TO(std::string_view, tryToString) +EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType) +EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat) +EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout) +EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice) +// Tensor and Optional Tensor +EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor) +EVALUE_DEFINE_TRY_TO( + std::optional, + tryToOptional) + +// IntList and Optional IntList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToIntList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// DoubleList and Optional DoubleList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToDoubleList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// BoolList and Optional BoolList +EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef, tryToBoolList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// TensorList and Optional TensorList +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef, + tryToTensorList) +EVALUE_DEFINE_TRY_TO( + std::optional>, + tryToOptional>) + +// List of Optional Tensor +EVALUE_DEFINE_TRY_TO( + executorch::aten::ArrayRef>, + tryToListOptionalTensor) +#undef EVALUE_DEFINE_TRY_TO + template executorch::aten::ArrayRef BoxedEvalueList::get() const { for (typename executorch::aten::ArrayRef::size_type i = 0; @@ -602,6 +837,23 @@ executorch::aten::ArrayRef BoxedEvalueList::get() const { return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; } +template +Result> BoxedEvalueList::tryGet() const { + for (typename executorch::aten::ArrayRef::size_type i = 0; + i < wrapped_vals_.size(); + i++) { + if (wrapped_vals_[i] == nullptr) { + return Error::InvalidState; + } + auto r = wrapped_vals_[i]->template tryTo(); + if (!r.ok()) { + return r.error(); + } + unwrapped_vals_[i] = std::move(r.get()); + } + return executorch::aten::ArrayRef{unwrapped_vals_, wrapped_vals_.size()}; +} + } // namespace runtime } // namespace executorch diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp index edf6a1b12c1..1b0b86c1392 100644 --- a/runtime/core/test/evalue_test.cpp +++ b/runtime/core/test/evalue_test.cpp @@ -16,8 +16,12 @@ using namespace ::testing; +using executorch::aten::DeviceType; +using executorch::aten::Layout; +using executorch::aten::MemoryFormat; using executorch::aten::ScalarType; using executorch::runtime::BoxedEvalueList; +using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::Tag; using executorch::runtime::testing::TensorFactory; @@ -214,6 +218,56 @@ TEST_F(EValueTest, BoxedEvalueList) { EXPECT_EQ(unwrapped[2], 3); } +TEST_F(EValueTest, BoxedEvalueListTryGetSuccess) { + EValue values[3] = { + EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 3); + EXPECT_EQ((*result)[0], 1); + EXPECT_EQ((*result)[1], 2); + EXPECT_EQ((*result)[2], 3); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetWrongElementTag) { + // Second element is a Double, not an Int; tryGet should reject it rather + // than abort inside to(). + EValue values[3] = {EValue((int64_t)1), EValue(3.14), EValue((int64_t)3)}; + EValue* values_p[3] = {&values[0], &values[1], &values[2]}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetNullElement) { + // A null value is a malformed program for non-optional lists. + EValue a((int64_t)1); + EValue c((int64_t)3); + EValue* values_p[3] = {&a, nullptr, &c}; + int64_t storage[3] = {0, 0, 0}; + BoxedEvalueList x{values_p, storage, 3}; + auto result = x.tryGet(); + EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidState); +} + +TEST_F(EValueTest, BoxedEvalueListTryGetOptionalTensorNullIsNone) { + // For optional, null value is valid. + EValue a; + EValue* values_p[2] = {&a, nullptr}; + std::optional storage[2]; + BoxedEvalueList> x{ + values_p, storage, 2}; + auto result = x.tryGet(); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 2); + EXPECT_FALSE((*result)[0].has_value()); + EXPECT_FALSE((*result)[1].has_value()); +} + TEST_F(EValueTest, toOptionalTensorList) { // create list, empty evalue ctor gets tag::None EValue values[2] = {EValue(), EValue()}; @@ -417,3 +471,116 @@ TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) { EXPECT_TRUE(e.isListOptionalTensor()); ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "pointer is null"); } + +// Per-type tryTo* coverage. +// For each type: +// - success and failure for named method tryTo[Int/Double/Bool/Tensor/..] +// - success and failure for templated tryTo() specialization + +TEST_F(EValueTest, TryToInt) { + EValue e_int(static_cast(42)); + EValue e_mismatch(3.14); + EXPECT_EQ(e_int.tryToInt().get(), 42); + EXPECT_EQ(e_mismatch.tryToInt().error(), Error::InvalidType); + EXPECT_EQ(e_int.tryTo().get(), 42); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDouble) { + EValue e_double(3.14); + EValue e_mismatch(static_cast(42)); + EXPECT_DOUBLE_EQ(e_double.tryToDouble().get(), 3.14); + EXPECT_EQ(e_mismatch.tryToDouble().error(), Error::InvalidType); + EXPECT_DOUBLE_EQ(e_double.tryTo().get(), 3.14); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToBool) { + EValue e_bool(true); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_bool.tryToBool().get(), true); + EXPECT_EQ(e_mismatch.tryToBool().error(), Error::InvalidType); + EXPECT_EQ(e_bool.tryTo().get(), true); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_mismatch(static_cast(42)); + EXPECT_EQ(e_tensor.tryToTensor()->numel(), 6); + EXPECT_EQ(e_mismatch.tryToTensor().error(), Error::InvalidType); + EXPECT_EQ(e_tensor.tryTo()->numel(), 6); + EXPECT_EQ( + e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToOptionalTensor) { + TensorFactory tf; + EValue e_tensor(tf.ones({3, 2})); + EValue e_none; + EValue e_mismatch(static_cast(42)); + // Named tryToOptional: value, None, mismatch. + auto r_val = e_tensor.tryToOptional(); + EXPECT_TRUE(r_val->has_value()); + EXPECT_EQ(r_val->value().numel(), 6); + EXPECT_FALSE(e_none.tryToOptional()->has_value()); + EXPECT_EQ( + e_mismatch.tryToOptional().error(), + Error::InvalidType); + // Templated tryTo>: None path. + EXPECT_FALSE( + e_none.tryTo>()->has_value()); +} + +TEST_F(EValueTest, TryToScalar) { + EValue e_int(static_cast(7)); + EValue e_double(2.5); + EValue e_bool(true); + EValue e_none; + EXPECT_EQ(e_int.tryToScalar()->to(), 7); + EXPECT_DOUBLE_EQ(e_double.tryToScalar()->to(), 2.5); + EXPECT_EQ(e_bool.tryToScalar()->to(), true); + // None is neither Int/Double/Bool. + EXPECT_EQ(e_none.tryToScalar().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToScalarType) { + EValue e(static_cast(ScalarType::Float)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToScalarType().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryToScalarType().error(), Error::InvalidType); + EXPECT_EQ(e.tryTo().get(), ScalarType::Float); + EXPECT_EQ(e_mismatch.tryTo().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToMemoryFormat) { + EValue e(static_cast(MemoryFormat::Contiguous)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToMemoryFormat().get(), MemoryFormat::Contiguous); + EXPECT_EQ(e_mismatch.tryToMemoryFormat().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToLayout) { + EValue e(static_cast(Layout::Strided)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToLayout().get(), Layout::Strided); + EXPECT_EQ(e_mismatch.tryToLayout().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToDevice) { + EValue e(static_cast(DeviceType::CPU)); + EValue e_mismatch(3.14); + EXPECT_EQ(e.tryToDevice().get().type(), DeviceType::CPU); + EXPECT_EQ(e_mismatch.tryToDevice().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToTensorList) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToTensorList().error(), Error::InvalidType); +} + +TEST_F(EValueTest, TryToListOptionalTensor) { + EValue e(static_cast(42)); + EXPECT_EQ(e.tryToListOptionalTensor().error(), Error::InvalidType); +} From eef79219045d947d8608222d5fd4471c66ffc118 Mon Sep 17 00:00:00 2001 From: Hansong Zhang <107070759+kirklandsign@users.noreply.github.com> Date: Thu, 23 Apr 2026 21:29:32 -0700 Subject: [PATCH 12/17] Widen resolve_max_new_tokens parameters to int64_t and rename for clarity (#18917) Differential Revision: D99769848 Pull Request resolved: https://github.com/pytorch/executorch/pull/18917 --- docs/source/llm/run-with-c-plus-plus.md | 4 +- extension/llm/runner/_llm_runner.pyi | 5 ++- extension/llm/runner/irunner.h | 32 +++++++++------ extension/llm/runner/pybindings.cpp | 2 +- .../runner/test/test_generation_config.cpp | 40 +++++++++---------- .../llm/runner/test/test_runner_pybindings.py | 8 ++++ 6 files changed, 53 insertions(+), 38 deletions(-) diff --git a/docs/source/llm/run-with-c-plus-plus.md b/docs/source/llm/run-with-c-plus-plus.md index 217afad847b..b6c6082c3a6 100644 --- a/docs/source/llm/run-with-c-plus-plus.md +++ b/docs/source/llm/run-with-c-plus-plus.md @@ -183,13 +183,13 @@ struct GenerationConfig { int32_t num_eos = 0; // Number of EOS tokens to add // Helper method to resolve the actual max_new_tokens based on constraints - int32_t resolve_max_new_tokens(int32_t max_context_len, int32_t num_prompt_tokens) const; + int32_t resolve_max_new_tokens(int64_t max_context_len, int64_t num_tokens_occupied) const; }; ``` The `resolve_max_new_tokens` method handles the logic of determining how many tokens can be generated based on: - The model's maximum context length -- The number of tokens in the prompt +- The number of token positions already occupied in the context window - The user-specified maximum sequence length and maximum new tokens ### Implementation Components diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 20333578763..271cf1e1540 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -47,14 +47,15 @@ class GenerationConfig: ... def resolve_max_new_tokens( - self, max_context_len: int, num_prompt_tokens: int + self, max_context_len: int, num_tokens_occupied: int ) -> int: """ Resolve the maximum number of new tokens to generate based on constraints. Args: max_context_len: The maximum context length supported by the model - num_prompt_tokens: The number of tokens in the input prompt + num_tokens_occupied: The number of token positions already occupied + in the context window (e.g. pos after prefill) Returns: The resolved maximum number of new tokens to generate diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index 0fcce1f37e4..bb7dd767fea 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -65,36 +66,41 @@ struct GenerationConfig { * * This method calculates the maximum number of new tokens that can be * generated considering both seq_len and max_new_tokens constraints, as well - * as the model's maximum context length and the number of tokens in the - * prompt. + * as the model's maximum context length and how many token positions are + * already occupied (e.g. by prior turns and the current prompt). * * @param max_context_len The maximum context length supported by the model - * @param num_prompt_tokens The number of tokens in the input prompt + * @param num_tokens_occupied The number of token positions already occupied + * in the context window (e.g. pos_ after prefill) * @return The resolved maximum number of new tokens to generate */ int32_t resolve_max_new_tokens( - int32_t max_context_len, - int32_t num_prompt_tokens) const { - int32_t result; + int64_t max_context_len, + int64_t num_tokens_occupied) const { + int64_t result; if (seq_len == -1 && max_new_tokens == -1) { - // Both are -1, use max context len minus prompt tokens - result = max_context_len - num_prompt_tokens; + // Both are -1, use max context len minus occupied tokens + result = max_context_len - num_tokens_occupied; } else if (seq_len == -1 && max_new_tokens != -1) { // Only max_new_tokens is specified - result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); + result = std::min( + static_cast(max_new_tokens), + max_context_len - num_tokens_occupied); } else if (seq_len != -1 && max_new_tokens == -1) { // Only seq_len is specified - result = std::min(seq_len, max_context_len) - num_prompt_tokens; + result = std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied; } else { // Both are specified result = std::min( - std::min(seq_len, max_context_len) - num_prompt_tokens, - max_new_tokens); + std::min(static_cast(seq_len), max_context_len) - + num_tokens_occupied, + static_cast(max_new_tokens)); } // Ensure result is not negative - return std::max(0, result); + return static_cast(std::max(static_cast(0), result)); } }; diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index ecd49e6341a..3188b5390c4 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -297,7 +297,7 @@ PYBIND11_MODULE(_llm_runner, m) { "resolve_max_new_tokens", &GenerationConfig::resolve_max_new_tokens, py::arg("max_context_len"), - py::arg("num_prompt_tokens"), + py::arg("num_tokens_occupied"), "Resolve the maximum number of new tokens to generate based on constraints") .def("__repr__", [](const GenerationConfig& config) { return " Date: Thu, 23 Apr 2026 23:12:15 -0700 Subject: [PATCH 13/17] skip cuda operations when running qwen 3.5 moe on other backend (#19095) This PR makes GPU related operator cuda-backend specific, to bring metal qwen 3.5 moe ci back --- examples/models/qwen3_5_moe/main.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index bae4cfc183c..00c91a685e1 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -58,11 +59,13 @@ int main(int argc, char** argv) { llm::Stats stats; +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory before load - size_t gpu_free_bytes, gpu_total_bytes; + size_t gpu_free_bytes = 0, gpu_total_bytes = 0; cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_total_bytes = gpu_total_bytes; stats.gpu_free_before_load_bytes = gpu_free_bytes; +#endif stats.model_load_start_ms = llm::time_in_ms(); @@ -127,9 +130,11 @@ int main(int argc, char** argv) { stats.model_load_end_ms = llm::time_in_ms(); +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory after load cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_free_after_load_bytes = gpu_free_bytes; +#endif // Get EOS ids auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); @@ -155,7 +160,7 @@ int main(int argc, char** argv) { } auto prompt_tokens = std::move(*encode_result); int64_t num_prompt_tokens = prompt_tokens.size(); - printf("Prompt tokens: %ld\n", num_prompt_tokens); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); stats.num_prompt_tokens = num_prompt_tokens; stats.inference_start_ms = llm::time_in_ms(); @@ -209,7 +214,7 @@ int main(int argc, char** argv) { double prefill_ms = (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); printf( - "Prefill: %ld tokens in %.1f ms (%.1f tok/s)\n", + "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", num_prompt_tokens, prefill_ms, num_prompt_tokens * 1000.0 / prefill_ms); @@ -290,17 +295,19 @@ int main(int argc, char** argv) { double decode_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); printf( - "Decode: %ld tokens in %.1f ms (%.1f tok/s)\n", + "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", num_generated, decode_ms, num_generated * 1000.0 / decode_ms); - printf("Prompt tokens: %ld\n", num_prompt_tokens); + printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); +#ifdef EXECUTORCH_BUILD_CUDA // GPU memory after generation cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); stats.gpu_free_after_generate_bytes = gpu_free_bytes; stats.gpu_peak_usage_mb = (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0; +#endif llm::print_report(stats); From b6a47aa93c34385cc2d7507e453962ceb94ad01d Mon Sep 17 00:00:00 2001 From: Oscar Andersson <87121123+oscarandersson8218@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:35:18 +0200 Subject: [PATCH 14/17] Arm backend: Disable fusing of TOSA ops (#19066) Disable fusing of ops that have symbolic shapes as arguments. Also disable fusing of TOSA dialect ops. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson --- .../arm/_passes/fuse_constant_ops_pass.py | 49 +++++++++--- .../passes/test_fuse_constant_ops_pass.py | 77 ++++++++++++++++++- 2 files changed, 116 insertions(+), 10 deletions(-) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index d6fd4b18b53..f54ed851240 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Set, Type +from collections.abc import Mapping +from typing import Sequence, Set, Type import torch._export.utils import torch.fx @@ -18,6 +19,7 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) +from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, @@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.exported_program = exported_program + @staticmethod + def _is_tosa_dialect_op(target) -> bool: + target_str = str(target) + return ( + "executorch.exir.dialects.backend._ops.tosa." in target_str + or " bool: + if isinstance(arg, torch.fx.Node): + if meta_has_shape_mark(arg.meta): + return True + return FuseConstantArgsPass._arg_contains_symbolic_shape( + arg.meta.get("val") + ) + if isinstance(arg, torch.SymInt): + return True + if isinstance(arg, Mapping): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(k) + or FuseConstantArgsPass._arg_contains_symbolic_shape(v) + for k, v in arg.items() + ) + if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg + ) + return False + def _propagate_special_dtype(self, from_nodes, to_node, data): """Propagate special dtype meta if it exists.""" special_dtypes = set() @@ -142,13 +174,13 @@ def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": continue - if node.target in [ - exir_ops.backend.tosa.MATMUL.default, - exir_ops.backend.tosa.RESCALE.default, - exir_ops.backend.tosa.RESIZE.default, - exir_ops.backend.tosa.TABLE.default, - exir_ops.backend.tosa.TRANSPOSE.default, - ]: + # Don't fuse TOSA dialect ops as they do not have eager forward functions. + # Also don't fuse ops whose explicit args/kwargs include symbolic shape values. + if ( + self._is_tosa_dialect_op(node.target) + or self._arg_contains_symbolic_shape(node.args) + or self._arg_contains_symbolic_shape(node.kwargs) + ): continue input_nodes = node.all_input_nodes @@ -164,7 +196,6 @@ def call(self, graph_module): ) if not all(input_nodes_constant): continue - try: did_fuse = self._fuse_nodes(node) if did_fuse: diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 785744c1b37..d915b4ecba0 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -6,6 +6,7 @@ import operator from typing import cast, ClassVar, Dict, Protocol, Tuple +import executorch.backends.arm.tosa.dialect # noqa: F401 import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, @@ -15,8 +16,15 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) from executorch.backends.test.harness.stages import StageType +from executorch.backends.test.program_builder import ProgramBuilder +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind input_t = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] @@ -270,3 +278,70 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: for node in pass_result.graph_module.graph.nodes if node.op == "placeholder" ] == ["aten_cat_default_fused_const"] + + +def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: + class FakeTosaTarget: + def __str__(self) -> str: + return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default" + + assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget()) + assert FuseConstantArgsPass._is_tosa_dialect_op( + exir_ops.backend.tosa.GATHER.default + ) + assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) + + +def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: + graph = torch.fx.Graph() + shape_node = graph.placeholder("shape") + shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + + assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2])) + assert not FuseConstantArgsPass._arg_contains_symbolic_shape( + ([1, 2], {"pad": (0, 0)}) + ) + + +def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): + builder = ProgramBuilder() + values = builder.placeholder( + "values", + torch.randn(1, 4, 3), + input_kind=InputKind.CONSTANT_TENSOR, + ) + indices = builder.placeholder( + "indices", + torch.tensor([[0, 2]], dtype=torch.int32), + input_kind=InputKind.CONSTANT_TENSOR, + ) + gather = builder.call_operator( + exir_ops.backend.tosa.GATHER.default, + (values, indices), + ) + builder.output([gather]) + + exported_program = builder.get_program() + graph_module = exported_program.graph_module + + with caplog.at_level("WARNING"): + FuseConstantArgsPass(exported_program)(graph_module) + + warning_messages = [ + record.getMessage() + for record in caplog.records + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" + ] + assert not any( + "Failed to fuse constant op" in message and "GATHER" in message + for message in warning_messages + ) + assert ( + sum( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.GATHER.default + for node in graph_module.graph.nodes + ) + == 1 + ) From c5c5b3a5d59263868748e757cf73376eb57a2450 Mon Sep 17 00:00:00 2001 From: Oscar Andersson <87121123+oscarandersson8218@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:35:43 +0200 Subject: [PATCH 15/17] Arm backend: Add util for symbolic range eval (#19108) Adds util for computing a value range from a symbolic expression. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell Signed-off-by: Oscar Andersson --- backends/arm/_passes/rewrite_conv_pass.py | 15 +- .../arm/_passes/size_adjust_input_pass.py | 7 + backends/arm/_passes/symbolic_value_range.py | 138 ++++++++++++++++++ .../test/passes/test_symbolic_value_range.py | 69 +++++++++ 4 files changed, 226 insertions(+), 3 deletions(-) create mode 100644 backends/arm/_passes/symbolic_value_range.py create mode 100644 backends/arm/test/passes/test_symbolic_value_range.py diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index e4be0b5dc25..8244dc2558b 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -21,6 +21,9 @@ get_input_qparams, get_output_qparams, ) +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import get_context_shape_env @@ -83,8 +86,14 @@ def _adjust_pad_if_needed( if isinstance(mod_remainder, torch.SymInt): shape_env = get_context_shape_env() - value_ranges = shape_env.bound_sympy(mod_remainder.node.expr) - mod_remainder_upper = int(value_ranges.upper) + exact_values = evaluate_symbolic_expr_values( + mod_remainder.node.expr, shape_env + ) + if exact_values is not None: + mod_remainder_upper = max(exact_values) + else: + value_ranges = shape_env.bound_sympy(mod_remainder.node.expr) + mod_remainder_upper = int(value_ranges.upper) if mod_remainder_upper == 0: mod_remainder = 0 else: @@ -92,7 +101,7 @@ def _adjust_pad_if_needed( if mod_remainder_upper > pad: raise RuntimeError( - "This case should be handled by the SizeAdjustInputPass, is it enabled?" + "This case should be handled by the SizeAdjustInputPass, is it enabled?\n" ) return pad - mod_remainder diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 233c93340b8..bf50306f5d6 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import cast, Sequence, Set, Type, TypeAlias +import torch import torch.fx from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( @@ -12,6 +13,9 @@ ) from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) from executorch.backends.arm.tosa.specification import get_context_shape_env from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool: """Returns whether an int or SymInt is greater than another value.""" if isinstance(input, torch.SymInt): shape_env = get_context_shape_env() + exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env) + if exact_values is not None: + return max(exact_values) > other value_ranges = shape_env.bound_sympy(input.node.expr) return value_ranges.upper > other else: diff --git a/backends/arm/_passes/symbolic_value_range.py b/backends/arm/_passes/symbolic_value_range.py new file mode 100644 index 00000000000..0753fefa270 --- /dev/null +++ b/backends/arm/_passes/symbolic_value_range.py @@ -0,0 +1,138 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import sympy # type: ignore[import-untyped] +import torch +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._sympy.interp import sympy_interp + +_MAX_SET_SIZE = 256 +_ExactValues = Optional[frozenset[sympy.Basic]] + + +def _expr_to_int(sym_expr: sympy.Basic) -> Optional[int]: + if isinstance(sym_expr, int): + return sym_expr + if isinstance(sym_expr, sympy.Integer): + return int(sym_expr) + if getattr(sym_expr, "is_integer", False) and sym_expr.is_number: + return int(sym_expr) + return None + + +def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues: + value_range = shape_env.var_to_range.get(symbol) + if value_range is None or not value_range.is_int: + return None + + lower = _expr_to_int(value_range.lower) + upper = _expr_to_int(value_range.upper) + if lower is None or upper is None or upper < lower: + return None + if upper - lower + 1 > _MAX_SET_SIZE: + return None + + return frozenset(sympy.Integer(value) for value in range(lower, upper + 1)) + + +def _map_values(values: _ExactValues, fn) -> _ExactValues: + if values is None: + return None + + result = {sympy.simplify(fn(value)) for value in values} + if len(result) > _MAX_SET_SIZE: + return None + return frozenset(result) + + +def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues: + if lhs is None or rhs is None: + return None + if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE: + return None + + result = {sympy.simplify(fn(a, b)) for a in lhs for b in rhs} + if len(result) > _MAX_SET_SIZE: + return None + return frozenset(result) + + +class _ExactValueAnalysis: + @staticmethod + def constant(value, dtype) -> frozenset[sympy.Basic]: + return frozenset({sympy.sympify(value)}) + + @staticmethod + def add(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a + b) + + @staticmethod + def mul(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a * b) + + @staticmethod + def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + if rhs is None or any(value == 0 for value in rhs): + return None + return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b)) + + @staticmethod + def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues: + return _combine_values(lhs, rhs, lambda a, b: a**b) + + @staticmethod + def floor_to_int(values: _ExactValues, dtype) -> _ExactValues: + return _map_values(values, sympy.floor) + + @staticmethod + def sym_sum(args: list[_ExactValues]) -> _ExactValues: + acc: _ExactValues = frozenset({sympy.Integer(0)}) + for arg in args: + acc = _ExactValueAnalysis.add(acc, arg) + if acc is None: + return None + return acc + + +def evaluate_symbolic_expr_values( + expr: sympy.Basic | torch.SymInt, + shape_env: ShapeEnv, +) -> Optional[set[int]]: + """Return a best-effort finite set of possible integer values. + + The helper first relies on ``bound_sympy`` for cheap singleton detection. + When interval bounds are not precise enough, it falls back to a small + exact-set analysis over bounded symbols using ``sympy_interp``. + + """ + root_expr = sympy.simplify( + expr.node.expr if isinstance(expr, torch.SymInt) else expr + ) + value_range = shape_env.bound_sympy(root_expr) + if value_range.is_int and value_range.is_singleton(): + singleton = _expr_to_int(value_range.lower) + return {singleton} if singleton is not None else None + + exact_values = sympy_interp( + _ExactValueAnalysis, + { + symbol: _symbol_values(symbol, shape_env) + for symbol in root_expr.free_symbols + }, + root_expr, + missing_handler=lambda symbol: _symbol_values(symbol, shape_env), + ) + if exact_values is None: + return None + + result: set[int] = set() + for value in exact_values: + integer_value = _expr_to_int(sympy.simplify(value)) + if integer_value is None: + return None + result.add(integer_value) + return result diff --git a/backends/arm/test/passes/test_symbolic_value_range.py b/backends/arm/test/passes/test_symbolic_value_range.py new file mode 100644 index 00000000000..8d3c970f0ab --- /dev/null +++ b/backends/arm/test/passes/test_symbolic_value_range.py @@ -0,0 +1,69 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm._passes.symbolic_value_range import ( + evaluate_symbolic_expr_values, +) +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _make_shape_env( + *, + symbol_name: str = "s89", + hint: int = 2, + compiler_min: int = 1, + compiler_max: int = 2, +) -> tuple[ShapeEnv, torch.SymInt]: + shape_env = ShapeEnv() + symint = shape_env.create_symintnode(sympy.Symbol(symbol_name), hint=hint) + shape_env.constrain_symbol_range( + symint.node.expr, + compiler_min=compiler_min, + compiler_max=compiler_max, + ) + return shape_env, symint + + +def test_evaluate_symbolic_expr_values_returns_singleton_for_constant_expr() -> None: + shape_env, symint = _make_shape_env() + + assert evaluate_symbolic_expr_values( + symint.node.expr - symint.node.expr, shape_env + ) == {0} + assert evaluate_symbolic_expr_values( + sympy.floor(symint.node.expr / symint.node.expr), shape_env + ) == {1} + + +def test_evaluate_symbolic_expr_values_returns_singleton_for_singleton_symint() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=3, compiler_max=3) + + assert evaluate_symbolic_expr_values(symint, shape_env) == {3} + assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {3} + + +def test_evaluate_symbolic_expr_values_enumerates_non_singleton_symint() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6) + + assert evaluate_symbolic_expr_values(symint, shape_env) == {2, 3, 4, 5, 6} + assert evaluate_symbolic_expr_values(symint.node.expr, shape_env) == {2, 3, 4, 5, 6} + + +def test_evaluate_symbolic_expr_values_tracks_exact_modulo_residue() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=2, compiler_max=6) + expr = sympy.Mod(16 * symint.node.expr - 7, 4) + + value_range = shape_env.bound_sympy(expr) + assert value_range.lower == 0 + assert value_range.upper == 3 + assert evaluate_symbolic_expr_values(expr, shape_env) == {1} + + +def test_evaluate_symbolic_expr_values_bails_out_for_large_symbol_ranges() -> None: + shape_env, symint = _make_shape_env(hint=3, compiler_min=1, compiler_max=400) + + assert evaluate_symbolic_expr_values(symint, shape_env) is None From 98a1d6626d165e4cc48ba47a3e0772c59c4a1a90 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 24 Apr 2026 15:07:41 +0200 Subject: [PATCH 16/17] Remove un-used copy from building dockers (#18868) The removed copy seems to be stale, it is never used. --- .ci/docker/build.sh | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 08cab0587e4..7c4a80044e4 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -94,11 +95,6 @@ BUILD_DOCS=1 # Copy requirements-lintrunner.txt from root to here cp ../../requirements-lintrunner.txt ./ -# Copy arm setup script from root to here -# TODO(huydhn): Figure out a way to rebuild the Docker image automatically -# with a new image hash when the content here is updated -cp -r ../../examples/arm/ ./arm - docker build \ --no-cache \ --progress=plain \ From 476a7ef427cc4b78c8767b7ed6f3b7db82642867 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Fri, 24 Apr 2026 10:47:55 -0400 Subject: [PATCH 17/17] Add recurrent gated delta rule custom op for Qwen3.5 attention (#18088) ## Summary This PR adds a fused `llama::recurrent_gated_delta_rule` custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available. It also tightens local custom-op loading so we no longer implicitly scan repo-local `cmake-out*` directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection. ## What changed - added `llama::recurrent_gated_delta_rule` runtime and AOT registrations - updated Qwen3.5 GatedDeltaNet attention to use the fused op with Python fallback preserved - tightened `custom_ops_aot_lib` discovery: - default to package-local discovery - allow explicit override via `EXECUTORCH_CUSTOM_OPS_AOT_LIB` - removed implicit repo-local `cmake-out*` scanning - added tests for: - recurrent op parity vs reference - `.out` variant behavior - chunked-state parity vs full-sequence execution - custom-op vs fallback attention parity - tiny Qwen3.5 export selecting `llama.recurrent_gated_delta_rule` ## Validation ### Linux CPU-only (aarch64) Built `custom_ops_aot_lib` successfully and loaded it via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`. Passed: - `pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q` - `3 passed` - `pytest examples/models/llama/tests/test_qwen3_5_attention.py -q` - `7 passed` - `pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q` - `1 passed` ### Real-model CPU validation On a real `Qwen3.5-0.8B` CPU run, fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt. Observed on local CPU validation: - same next token from fused path vs fallback - max logit diff on the order of `1e-5` - eager prefill speedup about `1.6x` on the tested prompt ### Windows note A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally **not** included in this PR. ## Non-goals / separate issues I did not treat the local `program.fbs` serialization issue as part of this change. This branch does not modify `exir/_serialize/*` or `schema/program.fbs`, and serialization-focused checks passed on both this branch and clean `main` once the local environment was set up correctly. A separate end-to-end tiny Qwen3.5 `.pte` export probe hit: - `RuntimeError: Missing out variants: {'aten::alias'}` That appears to be a separate pre-existing export issue outside this change set. cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng --------- Co-authored-by: Digant Desai Co-authored-by: Nikhil Viswanath Sivakumar <68182521+nil-is-all@users.noreply.github.com> --- examples/models/llama/attention.py | 106 ++++- .../llama/tests/test_export_llama_lib.py | 72 ++++ .../llama/tests/test_qwen3_5_attention.py | 105 +++++ .../make_aten_functor_from_et_functor.h | 3 +- extension/llm/custom_ops/custom_ops.py | 176 +++++++- .../op_fast_hadamard_transform_aten.cpp | 33 +- extension/llm/custom_ops/op_sdpa.cpp | 226 ++++++++++ extension/llm/custom_ops/op_sdpa.h | 10 + extension/llm/custom_ops/op_sdpa_aot.cpp | 399 +++++++++++++++--- extension/llm/custom_ops/op_tile_crop_aot.cpp | 39 +- extension/llm/custom_ops/test_update_cache.py | 152 +++++++ 11 files changed, 1227 insertions(+), 94 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index d6dff173072..7556ef60e19 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, TypedDict @@ -52,6 +53,8 @@ def forward( ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} +_RECURRENT_GATED_DELTA_RULE_OP = None +_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False def register_attention(name: str): @@ -64,6 +67,38 @@ def decorator(cls: Type[Attention]): return decorator +def _get_recurrent_gated_delta_rule_op(): + global _RECURRENT_GATED_DELTA_RULE_OP + global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + + if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP: + return _RECURRENT_GATED_DELTA_RULE_OP + + _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + return _RECURRENT_GATED_DELTA_RULE_OP + except (AttributeError, RuntimeError): + pass + + try: + from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + except (ImportError, OSError, RuntimeError): + logging.debug("Failed to import custom ops library", exc_info=True) + return None + + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + except (AttributeError, RuntimeError): + _RECURRENT_GATED_DELTA_RULE_OP = None + + return _RECURRENT_GATED_DELTA_RULE_OP + + class KVCache(nn.Module): def __init__( self, @@ -725,7 +760,7 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor: out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype) return out.transpose(1, 2).contiguous() - def _recurrent_gated_delta_rule( + def _gated_delta_rule_op( self, query: torch.Tensor, key: torch.Tensor, @@ -733,20 +768,35 @@ def _recurrent_gated_delta_rule( g: torch.Tensor, beta: torch.Tensor, ) -> torch.Tensor: - # query/key/value: (batch, seq_len, num_heads, head_dim) - # g/beta: (batch, seq_len, num_heads) - initial_dtype = query.dtype - query = _l2norm(query, dim=-1, eps=1e-6) - key = _l2norm(key, dim=-1, eps=1e-6) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] + batch_size = query.shape[0] + recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op() + if recurrent_gated_delta_rule_op is not None: + return recurrent_gated_delta_rule_op( + query, + key, + value, + g, + beta, + self.recurrent_state[:batch_size], + ) + return self._naive_gated_delta_rule_op( + query, + key, + value, + g, + beta, + ) - batch_size, num_heads, sequence_length, k_head_dim = key.shape + def _naive_gated_delta_rule_op( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + batch_size, num_heads, sequence_length, _ = key.shape v_head_dim = value.shape[-1] - scale = 1.0 / (query.shape[-1] ** 0.5) - query = query * scale core_attn_out = torch.zeros( batch_size, @@ -780,6 +830,36 @@ def _recurrent_gated_delta_rule( last_recurrent_state.to(self.recurrent_state.dtype) ) + return core_attn_out + + def _recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # query/key/value: (batch, seq_len, num_heads, head_dim) + # g/beta: (batch, seq_len, num_heads) + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = self._gated_delta_rule_op( + query, + key, + value, + g, + beta, + ) return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) def forward( diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 130a55f658c..c96fea8c215 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -5,7 +5,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json +import tempfile import unittest +from pathlib import Path from executorch.devtools.backend_debug import get_delegation_info @@ -25,6 +28,7 @@ from executorch.examples.models.llama.export_llama_lib import ( _export_llama, + _prepare_for_llama_export, build_args_parser, get_quantizer_and_quant_params, ) @@ -37,6 +41,39 @@ class ExportLlamaLibTest(unittest.TestCase): + def _make_tiny_qwen35_params(self) -> dict: + return { + "dim": 64, + "hidden_dim": 128, + "n_heads": 4, + "head_dim": 16, + "n_kv_heads": 2, + "n_layers": 4, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": False, + "vocab_size": 256, + "use_hf_rope": True, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": False, + "use_qk_norm": True, + "qk_norm_before_rope": True, + "attention_type": "mha", + "use_q_gate": True, + "rms_norm_add_unit_offset": True, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 4, + "linear_num_value_heads": 4, + "layer_types": [ + "linear_attention", + "full_attention", + "linear_attention", + "full_attention", + ], + } + def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops. @@ -66,6 +103,41 @@ def test_has_expected_ops_and_op_counts(self): for op, _op_info in delegation_info.delegation_by_operator.items(): self.assertTrue(op not in UNWANTED_OPS) + def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self): + with tempfile.TemporaryDirectory() as temp_dir: + params_path = Path(temp_dir) / "tiny_qwen35.json" + params_path.write_text(json.dumps(self._make_tiny_qwen35_params())) + + parser = build_args_parser() + args = parser.parse_args( + [ + "--model", + "qwen3_5_0_8b", + "--params", + str(params_path), + "--use_kv_cache", + "--disable_dynamic_shape", + "--max_seq_length", + "8", + "--max_context_length", + "8", + ] + ) + + llm_config = LlmConfig.from_args(args) + builder = _prepare_for_llama_export(llm_config).export() + assert builder.pre_autograd_graph_module is not None + + recurrent_nodes = [ + node + for node in builder.pre_autograd_graph_module.graph.nodes + if "auto_functionalized_v2" in str(node.target) + and node.args + and "llama.recurrent_gated_delta_rule" in str(node.args[0]) + ] + + self.assertEqual(len(recurrent_nodes), 2) + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self): llm_config = LlmConfig() diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 5a9f67d57cf..ba96a96aa43 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -6,7 +6,9 @@ import unittest +import executorch.examples.models.llama.attention as attention_module import torch + from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm @@ -123,6 +125,109 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self): torch.allclose(state_after_first, state_after_second, atol=1e-5) ) + def test_gated_deltanet_chunked_prefill_matches_full_sequence(self): + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked.load_state_dict(attn_full.state_dict()) + + x = torch.randn(1, 5, args.dim) + dummy_freq = torch.zeros(1, 1) + + full_output, _ = attn_full( + x, + dummy_freq, + dummy_freq, + input_pos=torch.tensor([0], dtype=torch.long), + ) + + chunk_outputs = [] + for start, end in ((0, 3), (3, 4), (4, 5)): + output, _ = attn_chunked( + x[:, start:end], + dummy_freq, + dummy_freq, + input_pos=torch.tensor([start], dtype=torch.long), + ) + chunk_outputs.append(output) + + chunked_output = torch.cat(chunk_outputs, dim=1) + + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5 + ) + ) + self.assertTrue( + torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5) + ) + + def test_gated_deltanet_custom_op_matches_fallback(self): + recurrent_op = attention_module._get_recurrent_gated_delta_rule_op() + if recurrent_op is None: + self.skipTest("llama::recurrent_gated_delta_rule is not available") + + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback.load_state_dict(attn_custom.state_dict()) + + query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim) + g = torch.randn(1, 3, attn_custom.num_v_heads) + beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads)) + + original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP + original_tried_loading = ( + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + ) + try: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + custom_output = attn_custom._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + + attention_module._RECURRENT_GATED_DELTA_RULE_OP = None + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + fallback_output = attn_fallback._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + finally: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = ( + original_tried_loading + ) + + self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5 + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 8e1c2bf0143..67e7344330e 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -15,7 +15,8 @@ #pragma once #include #include -#if __cplusplus < 201703L +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) #error "This header requires C++17" #endif #include diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 9aacded4b4c..e0b009d7a13 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -11,7 +11,9 @@ # pyre-unsafe import logging +import os +from pathlib import Path from typing import Tuple import torch @@ -21,33 +23,84 @@ from torch.library import impl aten = torch.ops.aten +_CUSTOM_OPS_DLL_DIR_HANDLES = [] -try: - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None -except: - # This is needed to ensure that custom ops are registered - from executorch.extension.pybindings import portable_lib # noqa # usort: skip - # Ideally package is installed in only one location but usage of - # PYATHONPATH can result in multiple locations. - # ATM this is mainly used in CI for qnn runner. Will need to revisit this - from pathlib import Path +def _is_custom_ops_registered() -> bool: + try: + torch.ops.llama.sdpa_with_kv_cache.default + torch.ops.llama.fast_hadamard_transform.default + return True + except (AttributeError, RuntimeError): + return False + + +def _get_custom_ops_library_override() -> Path | None: + override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB") + if override is None: + return None + + lib_path = Path(override).expanduser().resolve() + if not lib_path.is_file(): + raise FileNotFoundError( + "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing " + f"custom_ops_aot_lib, but got {lib_path}" + ) + return lib_path + + +def _find_custom_ops_library() -> Path: + override = _get_custom_ops_library_override() + if override is not None: + return override package_path = Path(__file__).parent.resolve() - logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}") + candidates = [] + patterns = ( + "**/custom_ops_aot_lib.dll", + "**/libcustom_ops_aot_lib.so", + "**/libcustom_ops_aot_lib.dylib", + ) + + for pattern in patterns: + candidates.extend(package_path.glob(pattern)) + + libs = sorted({path.resolve() for path in candidates if path.is_file()}) + if not libs: + raise FileNotFoundError( + f"Could not find custom_ops_aot_lib under {package_path}" + ) + return max(libs, key=lambda path: path.stat().st_mtime) + + +def _load_custom_ops_library() -> None: + try: + # This is needed to ensure that custom ops are registered when + # portable_lib is available in the current environment. + from executorch.extension.pybindings import portable_lib # noqa # usort: skip + except ImportError: + portable_lib = None + + lib_path = _find_custom_ops_library() + logging.info(f"Loading custom ops library: {lib_path}") + + if os.name == "nt": + _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(lib_path.parent))) + torch_lib_dir = Path(torch.__file__).resolve().parent / "lib" + if torch_lib_dir.is_dir(): + _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(str(torch_lib_dir))) - libs = list(package_path.glob("**/*custom_ops_aot_lib.*")) + torch.ops.load_library(lib_path) - assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" - logging.info(f"Loading custom ops library: {libs[0]}") - torch.ops.load_library(libs[0]) - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None + # Keep the import alive to avoid lint complaints in environments where + # portable_lib is needed for symbol resolution. + _ = portable_lib + + +if not _is_custom_ops_registered(): + _load_custom_ops_library() + if not _is_custom_ops_registered(): + raise RuntimeError("Failed to register ExecuTorch custom ops library") custom_ops_lib = torch.library.Library("llama", "IMPL") @@ -271,6 +324,87 @@ def update_cache_with_indices_meta( return torch.empty((1,), dtype=value.dtype, device="meta") +def _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, +): + assert ( + query.dim() == 4 + ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions." + assert ( + key.dim() == 4 + ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions." + assert ( + value.dim() == 4 + ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions." + assert g.dim() == 3, f"Expected g to be 3 dimensional but got {g.dim()} dimensions." + assert ( + beta.dim() == 3 + ), f"Expected beta to be 3 dimensional but got {beta.dim()} dimensions." + assert ( + recurrent_state.dim() == 4 + ), f"Expected recurrent_state to be 4 dimensional but got {recurrent_state.dim()} dimensions." + + for name, tensor in { + "query": query, + "key": key, + "value": value, + "g": g, + "beta": beta, + "recurrent_state": recurrent_state, + }.items(): + assert ( + tensor.dtype == torch.float32 + ), f"Expected {name} to be float32 but got {tensor.dtype}" + + assert ( + query.shape == key.shape + ), f"Expected query and key to have matching shapes but got {query.shape} and {key.shape}" + assert ( + query.shape[:3] == value.shape[:3] + ), f"Expected query and value to match in batch/head/sequence dims but got {query.shape} and {value.shape}" + assert ( + g.shape == query.shape[:3] + ), f"Expected g to match query batch/head/sequence dims but got {g.shape} and {query.shape}" + assert ( + beta.shape == query.shape[:3] + ), f"Expected beta to match query batch/head/sequence dims but got {beta.shape} and {query.shape}" + assert recurrent_state.shape == ( + query.size(0), + query.size(1), + query.size(3), + value.size(3), + ), ( + "Expected recurrent_state to have shape " + f"{(query.size(0), query.size(1), query.size(3), value.size(3))} " + f"but got {recurrent_state.shape}" + ) + + +@impl(custom_ops_lib, "recurrent_gated_delta_rule", "Meta") +def recurrent_gated_delta_rule_meta( + query, + key, + value, + g, + beta, + recurrent_state, +): + _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, + ) + return torch.empty_like(value) + + def _validate_quantized_sdpa_params( query, key, diff --git a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp index 146ac3cc298..d48f593868c 100644 --- a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp +++ b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp @@ -13,14 +13,40 @@ namespace torch::executor::native { namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} + Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) { executorch::aten::RuntimeContext context; return fast_hadamard_transform_out(context, vec, out); } + +at::Tensor& fast_hadamard_transform_out_aten( + const at::Tensor& vec, + at::Tensor& out) { + auto vec_et = to_et_arg(vec); + auto out_et = to_et_arg(out); + auto& et_result = + fast_hadamard_transform_out_no_context(vec_et.call(), out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) { auto out = at::empty_like(vec); - WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1) - (vec, out); + fast_hadamard_transform_out_aten(vec, out); return out; } } // namespace @@ -38,6 +64,5 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { torch::executor::native::fast_hadamard_transform_aten); m.impl( "fast_hadamard_transform.out", - WRAP_TO_ATEN( - torch::executor::native::fast_hadamard_transform_out_no_context, 1)); + torch::executor::native::fast_hadamard_transform_out_aten); } diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 72bddce7b5b..76ee9cb915f 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -15,6 +15,10 @@ #include // @lint-ignore CLANGTIDY facebook-unused-include-check #include +#include +#include +#include +#include #ifdef ET_USE_THREADPOOL #include @@ -178,6 +182,68 @@ bool validate_cache_params( return true; } +bool validate_recurrent_gated_delta_rule_args( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + const Tensor& recurrent_state) { + ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(g.dim() == 3, "g must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE(beta.dim() == 3, "beta must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.dim() == 4, "recurrent_state must be a 4D tensor"); + + ET_CHECK_OR_RETURN_FALSE( + query.scalar_type() == ScalarType::Float, "query must be float32"); + ET_CHECK_OR_RETURN_FALSE( + key.scalar_type() == ScalarType::Float, "key must be float32"); + ET_CHECK_OR_RETURN_FALSE( + value.scalar_type() == ScalarType::Float, "value must be float32"); + ET_CHECK_OR_RETURN_FALSE( + g.scalar_type() == ScalarType::Float, "g must be float32"); + ET_CHECK_OR_RETURN_FALSE( + beta.scalar_type() == ScalarType::Float, "beta must be float32"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.scalar_type() == ScalarType::Float, + "recurrent_state must be float32"); + + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == key.size(0) && query.size(1) == key.size(1) && + query.size(2) == key.size(2) && query.size(3) == key.size(3), + "query and key must have matching shapes"); + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == value.size(0) && query.size(1) == value.size(1) && + query.size(2) == value.size(2), + "query and value must match in batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + g.size(0) == query.size(0) && g.size(1) == query.size(1) && + g.size(2) == query.size(2), + "g must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + beta.size(0) == query.size(0) && beta.size(1) == query.size(1) && + beta.size(2) == query.size(2), + "beta must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.size(0) == query.size(0) && + recurrent_state.size(1) == query.size(1) && + recurrent_state.size(2) == query.size(3) && + recurrent_state.size(3) == value.size(3), + "recurrent_state shape must match [B, H, K, V]"); + + for (const Tensor* tensor : + {&query, &key, &value, &g, &beta, &recurrent_state}) { + ET_CHECK_OR_RETURN_FALSE( + is_contiguous_dim_order((*tensor).dim_order().data(), (*tensor).dim()), + "recurrent gated delta rule expects contiguous inputs"); + } + + return true; +} + // TODO: seq_length is not yet used for copy void update_cache( const Tensor& projected_value, @@ -610,6 +676,133 @@ Tensor& sdpa_with_kv_cache_out( return output; } + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor(output, value.sizes()) == Error::Ok, + InvalidArgument, + output, + "Failed to resize recurrent_gated_delta_rule output tensor."); + ET_KERNEL_CHECK( + ctx, + validate_recurrent_gated_delta_rule_args( + query, key, value, g, beta, recurrent_state), + InvalidArgument, + output); + ET_KERNEL_CHECK( + ctx, output.scalar_type() == ScalarType::Float, InvalidArgument, output); + ET_KERNEL_CHECK( + ctx, + is_contiguous_dim_order(output.dim_order().data(), output.dim()), + InvalidArgument, + output); + + const auto batch_size = query.size(0); + const auto num_heads = query.size(1); + const auto sequence_length = query.size(2); + const auto k_head_dim = query.size(3); + const auto v_head_dim = value.size(3); + + const auto q_batch_stride = num_heads * sequence_length * k_head_dim; + const auto q_head_stride = sequence_length * k_head_dim; + const auto q_seq_stride = k_head_dim; + + const auto value_batch_stride = num_heads * sequence_length * v_head_dim; + const auto value_head_stride = sequence_length * v_head_dim; + const auto value_seq_stride = v_head_dim; + + const auto gv_batch_stride = num_heads * sequence_length; + const auto gv_head_stride = sequence_length; + + const auto state_batch_stride = num_heads * k_head_dim * v_head_dim; + const auto state_head_stride = k_head_dim * v_head_dim; + + const auto* query_data = query.const_data_ptr(); + const auto* key_data = key.const_data_ptr(); + const auto* value_data = value.const_data_ptr(); + const auto* g_data = g.const_data_ptr(); + const auto* beta_data = beta.const_data_ptr(); + auto* recurrent_state_data = recurrent_state.mutable_data_ptr(); + auto* output_data = output.mutable_data_ptr(); + std::vector kv_mem(v_head_dim); + std::vector delta(v_head_dim); + + for (int64_t batch = 0; batch < batch_size; ++batch) { + for (int64_t head = 0; head < num_heads; ++head) { + const auto q_offset = batch * q_batch_stride + head * q_head_stride; + const auto value_offset = + batch * value_batch_stride + head * value_head_stride; + const auto gv_offset = batch * gv_batch_stride + head * gv_head_stride; + const auto state_offset = + batch * state_batch_stride + head * state_head_stride; + + const auto* q_head = query_data + q_offset; + const auto* k_head = key_data + q_offset; + const auto* value_head = value_data + value_offset; + const auto* g_head = g_data + gv_offset; + const auto* beta_head = beta_data + gv_offset; + auto* state_head = recurrent_state_data + state_offset; + auto* output_head = output_data + value_offset; + + for (int64_t token = 0; token < sequence_length; ++token) { + const auto* q_t = q_head + token * q_seq_stride; + const auto* k_t = k_head + token * q_seq_stride; + const auto* v_t = value_head + token * value_seq_stride; + auto* output_t = output_head + token * value_seq_stride; + + const float g_t = std::exp(g_head[token]); + const float beta_t = beta_head[token]; + + if (g_t != 1.0f) { + for (int64_t idx = 0; idx < state_head_stride; ++idx) { + state_head[idx] *= g_t; + } + } + + std::fill(kv_mem.begin(), kv_mem.end(), 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + kv_mem[v_idx] += state_row[v_idx] * key_value; + } + } + + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + delta[v_idx] = (v_t[v_idx] - kv_mem[v_idx]) * beta_t; + } + + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + state_row[v_idx] += key_value * delta[v_idx]; + } + } + + std::fill(output_t, output_t + v_head_dim, 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float query_value = q_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + output_t[v_idx] += state_row[v_idx] * query_value; + } + } + } + } + } + + return output; +} } // namespace native } // namespace executor } // namespace torch @@ -628,3 +821,36 @@ EXECUTORCH_LIBRARY( llama, "custom_quantized_sdpa.out", torch::executor::native::custom_quantized_sdpa_out); + +namespace { + +void recurrent_gated_delta_rule_out_boxed( + executorch::runtime::KernelRuntimeContext& ctx, + executorch::runtime::Span stack) { + ET_KERNEL_CHECK_MSG( + ctx, + stack.size() == 7, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + static_cast(7), + stack.size()); + + auto& query = stack[0]->toTensor(); + auto& key = stack[1]->toTensor(); + auto& value = stack[2]->toTensor(); + auto& g = stack[3]->toTensor(); + auto& beta = stack[4]->toTensor(); + auto& recurrent_state = stack[5]->toTensor(); + auto& output = stack[6]->toTensor(); + + (void)torch::executor::native::recurrent_gated_delta_rule_out( + ctx, query, key, value, g, beta, recurrent_state, output); +} + +const auto recurrent_gated_delta_rule_out_registration = + executorch::runtime::register_kernel(executorch::runtime::Kernel( + "llama::recurrent_gated_delta_rule.out", + recurrent_gated_delta_rule_out_boxed)); + +} // namespace diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 9d357eb6ea1..9f029f52f31 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -75,6 +75,16 @@ Tensor& custom_quantized_sdpa_out( const optional& v_scales, const bool is_seq_at_dim_1, Tensor& output); + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); } // namespace native } // namespace executor } // namespace torch diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 5bbf22d336e..d4d1122f614 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -17,6 +17,24 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -50,6 +68,20 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -77,6 +109,17 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -118,6 +161,24 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2); +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output); + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -129,6 +190,12 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos); +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output); + // New functions for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -143,6 +210,39 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices); +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output); + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state); + +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output); + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -192,22 +292,59 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty_like(q_projected); - WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11) - (q_projected, - k_projected, - v_projected, - key_cache, - value_cache, - start_pos, - seq_len, - attn_mask, - dropout_p, - is_causal, - scale, - output); + sdpa_with_kv_cache_out_aten( + q_projected, + k_projected, + v_projected, + key_cache, + value_cache, + start_pos, + seq_len, + attn_mask, + dropout_p, + is_causal, + scale, + output); return output; } +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output) { + auto q_et = to_et_arg(q_projected); + auto k_et = to_et_arg(k_projected); + auto v_et = to_et_arg(v_projected); + auto key_cache_et = to_et_arg(key_cache); + auto value_cache_et = to_et_arg(value_cache); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = sdpa_with_kv_cache_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + key_cache_et.call(), + value_cache_et.call(), + start_pos, + seq_len, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -248,11 +385,40 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_sdpa_out_no_context, 8) - (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); + custom_sdpa_out_aten( + q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); return output; } +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = custom_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -314,26 +480,75 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15) - (q, - k, - v, - start_pos, - attn_mask, - dropout_p, - is_causal, - scale, - q_zero_points, - q_scales, - k_zero_points, - k_scales, - v_zero_points, - v_scales, - is_seq_at_dim_2, - output); + custom_quantized_sdpa_out_aten( + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales, + is_seq_at_dim_2, + output); return output; } +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto q_zero_points_et = to_et_arg>(q_zero_points); + auto q_scales_et = to_et_arg>(q_scales); + auto k_zero_points_et = to_et_arg>(k_zero_points); + auto k_scales_et = to_et_arg>(k_scales); + auto v_zero_points_et = to_et_arg>(v_zero_points); + auto v_scales_et = to_et_arg>(v_scales); + auto output_et = to_et_arg(output); + auto& et_result = custom_quantized_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + q_zero_points_et.call(), + q_scales_et.call(), + k_zero_points_et.call(), + k_scales_et.call(), + v_zero_points_et.call(), + v_scales_et.call(), + is_seq_at_dim_2, + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -349,11 +564,23 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_out_no_context, 3) - (value, cache, start_pos, output); + update_cache_out_aten(value, cache, start_pos, output); return output; } +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_out_no_context( + value_et.call(), cache_et.call(), start_pos, output_et.call()); + return copy_et_result_to_out(et_result, output); +} + // Implementations for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -372,11 +599,81 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4) - (value, cache, start_pos, indices, output); + update_cache_with_indices_out_aten(value, cache, start_pos, indices, output); return output; } +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto indices_et = to_et_arg(indices); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_with_indices_out_no_context( + value_et.call(), + cache_et.call(), + start_pos, + indices_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + executorch::aten::RuntimeContext context{}; + return torch::executor::native::recurrent_gated_delta_rule_out( + context, query, key, value, g, beta, recurrent_state, output); +} + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state) { + auto output = at::empty_like(value); + recurrent_gated_delta_rule_out_aten( + query, key, value, g, beta, recurrent_state, output); + return output; +} + +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output) { + auto query_et = to_et_arg(query); + auto key_et = to_et_arg(key); + auto value_et = to_et_arg(value); + auto g_et = to_et_arg(g); + auto beta_et = to_et_arg(beta); + auto recurrent_state_et = to_et_arg(recurrent_state); + auto output_et = to_et_arg(output); + auto& et_result = recurrent_gated_delta_rule_out_no_context( + query_et.call(), + key_et.call(), + value_et.call(), + g_et.call(), + beta_et.call(), + recurrent_state_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + } // namespace native } // namespace executor } // namespace torch @@ -410,6 +707,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, " "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)"); + m.def( + "recurrent_gated_delta_rule(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state) -> Tensor"); + m.def( + "recurrent_gated_delta_rule.out(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -430,29 +733,27 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten); m.impl( "sdpa_with_kv_cache.out", - WRAP_TO_ATEN( - torch::executor::native::sdpa_with_kv_cache_out_no_context, 11)); + torch::executor::native::sdpa_with_kv_cache_out_aten); m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten); - m.impl( - "custom_sdpa.out", - WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8)); + m.impl("custom_sdpa.out", torch::executor::native::custom_sdpa_out_aten); m.impl("update_cache", torch::executor::native::update_cache_aten); - m.impl( - "update_cache.out", - WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); + m.impl("update_cache.out", torch::executor::native::update_cache_out_aten); m.impl( "update_cache_with_indices", torch::executor::native::update_cache_with_indices_aten); m.impl( "update_cache_with_indices.out", - WRAP_TO_ATEN( - torch::executor::native::update_cache_with_indices_out_no_context, - 4)); + torch::executor::native::update_cache_with_indices_out_aten); + m.impl( + "recurrent_gated_delta_rule", + torch::executor::native::recurrent_gated_delta_rule_aten); + m.impl( + "recurrent_gated_delta_rule.out", + torch::executor::native::recurrent_gated_delta_rule_out_aten); m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); m.impl( "custom_quantized_sdpa.out", - WRAP_TO_ATEN( - torch::executor::native::custom_quantized_sdpa_out_no_context, 15)); + torch::executor::native::custom_quantized_sdpa_out_aten); } diff --git a/extension/llm/custom_ops/op_tile_crop_aot.cpp b/extension/llm/custom_ops/op_tile_crop_aot.cpp index 5aa98ee8d4a..7d89c462e1d 100644 --- a/extension/llm/custom_ops/op_tile_crop_aot.cpp +++ b/extension/llm/custom_ops/op_tile_crop_aot.cpp @@ -16,10 +16,30 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out); +at::Tensor& +tile_crop_out_aten(const at::Tensor& input, int64_t tile_size, at::Tensor& out); + Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { executorch::aten::RuntimeContext context{}; @@ -28,12 +48,21 @@ tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size); +at::Tensor& tile_crop_out_aten( + const at::Tensor& input, + int64_t tile_size, + at::Tensor& out) { + auto input_et = to_et_arg(input); + auto out_et = to_et_arg(out); + auto& et_result = + tile_crop_out_no_context(input_et.call(), tile_size, out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) { // max_num_tiles = 4, num_channels = 3. auto output = at::empty({4, 3, tile_size, tile_size}); - - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2) - (input, tile_size, output); + tile_crop_out_aten(input, tile_size, output); return output; } @@ -49,7 +78,5 @@ TORCH_LIBRARY(preprocess, m) { TORCH_LIBRARY_IMPL(preprocess, CompositeExplicitAutograd, m) { m.impl("tile_crop", torch::executor::native::tile_crop_aten); - m.impl( - "tile_crop.out", - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2)); + m.impl("tile_crop.out", torch::executor::native::tile_crop_out_aten); } diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py index 84a349c97f0..7edd273d8b9 100644 --- a/extension/llm/custom_ops/test_update_cache.py +++ b/extension/llm/custom_ops/test_update_cache.py @@ -431,3 +431,155 @@ def test_batched_update_kv_cache_more_updates(self): self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) + + +class RecurrentGatedDeltaRuleTest(unittest.TestCase): + def _make_inputs( + self, + batch_size: int = 2, + num_heads: int = 3, + seq_len: int = 4, + k_head_dim: int = 5, + v_head_dim: int = 6, + ): + query = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + key = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + value = torch.randn(batch_size, num_heads, seq_len, v_head_dim) + g = torch.randn(batch_size, num_heads, seq_len) + beta = torch.sigmoid(torch.randn(batch_size, num_heads, seq_len)) + recurrent_state = torch.randn(batch_size, num_heads, k_head_dim, v_head_dim) + return query, key, value, g, beta, recurrent_state + + def _reference_recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + recurrent_state: torch.Tensor, + ): + state = recurrent_state.clone() + output = torch.zeros_like(value) + + for token in range(query.size(2)): + g_t = g[:, :, token].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, token].unsqueeze(-1) + k_t = key[:, :, token] + v_t = value[:, :, token] + q_t = query[:, :, token] + + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output[:, :, token] = (state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output, state + + def test_recurrent_gated_delta_rule_matches_reference(self): + torch.manual_seed(0) + + test_cases = ( + (2, 3, 4, 5, 6), + (1, 4, 7, 8, 3), + ) + + for case in test_cases: + with self.subTest(case=case): + ( + query, + key, + value, + g, + beta, + recurrent_state, + ) = self._make_inputs(*case) + + expected_output, expected_state = ( + self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + ) + + actual_state = recurrent_state.clone() + actual_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + actual_state, + ) + + self.assertTrue( + torch.allclose(actual_output, expected_output, atol=1e-5) + ) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_out_matches_reference(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs() + expected_output, expected_state = self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + + actual_state = recurrent_state.clone() + actual_output = torch.empty_like(value) + returned_output = torch.ops.llama.recurrent_gated_delta_rule.out( + query, + key, + value, + g, + beta, + actual_state, + out=actual_output, + ) + + self.assertEqual(returned_output.data_ptr(), actual_output.data_ptr()) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-5)) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_chunked_matches_full_sequence(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs(seq_len=6) + + full_state = recurrent_state.clone() + full_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + full_state, + ) + + chunk_state = recurrent_state.clone() + chunk_outputs = [] + for start, end in ((0, 2), (2, 5), (5, 6)): + chunk_outputs.append( + torch.ops.llama.recurrent_gated_delta_rule( + query[:, :, start:end, :], + key[:, :, start:end, :], + value[:, :, start:end, :], + g[:, :, start:end], + beta[:, :, start:end], + chunk_state, + ) + ) + + chunked_output = torch.cat(chunk_outputs, dim=2) + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue(torch.allclose(chunk_state, full_state, atol=1e-5))