Merged
Commits
18 commits
e7b38a3
Cortex-M backend: Support standalone clamp-type activations (#18767)
xingguo01 Apr 23, 2026
2d995bc
Arm backend: Fix quantized constant-folding for aten.cat lists (#18971)
perheld Apr 17, 2026
8a77f9b
Format third-party/CMakeLists.txt using cmake-format (#18533)
aksharabhardwaj766-commits Apr 23, 2026
3ec63f4
Ignored Module tests: provide required input tensor (#19028)
psiddh Apr 23, 2026
6d23e41
Extract shared multifunction PTE utilities to utils.py (#19035)
YIWENX14 Apr 23, 2026
7b5dcc1
Add add-relu fusion in the quantizer
mcremon-meta Apr 23, 2026
f9f29e7
Android: improve error diagnostics for LlmModule and exceptions (#19092)
psiddh Apr 23, 2026
4a69750
Add Half (float16) support to slim ScalarType enum (#18959)
digantdesai Apr 23, 2026
edb8c98
Validate XNNPACK tensor flags are valid (#19102)
lucylq Apr 24, 2026
75b31bb
Fix smollm2 alias to point at SmolLM2-135M (v2) instead of SmolLM-135…
Copilot Apr 24, 2026
c3f3d12
Add tryTo evalue accessors (#19039)
lucylq Apr 24, 2026
eef7921
Widen resolve_max_new_tokens parameters to int64_t and rename for cla…
kirklandsign Apr 24, 2026
b6cec38
skip cuda operations when running qwen 3.5 moe on other backend (#19095)
Gasoonjia Apr 24, 2026
b6a47aa
Arm backend: Disable fusing of TOSA ops (#19066)
oscarandersson8218 Apr 24, 2026
c5c5b3a
Arm backend: Add util for symbolic range eval (#19108)
oscarandersson8218 Apr 24, 2026
98a1d66
Remove un-used copy from building dockers (#18868)
Erik-Lundell Apr 24, 2026
476a7ef
Add recurrent gated delta rule custom op for Qwen3.5 attention (#18088)
Phineas1500 Apr 24, 2026
cd45d6a
Merge branch 'polycam' into main
jgibson2 Apr 24, 2026
6 changes: 1 addition & 5 deletions .ci/docker/build.sh
@@ -1,6 +1,7 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -94,11 +95,6 @@ BUILD_DOCS=1
# Copy requirements-lintrunner.txt from root to here
cp ../../requirements-lintrunner.txt ./

# Copy arm setup script from root to here
# TODO(huydhn): Figure out a way to rebuild the Docker image automatically
# with a new image hash when the content here is updated
cp -r ../../examples/arm/ ./arm

docker build \
--no-cache \
--progress=plain \
11 changes: 9 additions & 2 deletions backends/aoti/slim/c10/core/ScalarType.h
@@ -28,7 +28,7 @@ enum class ScalarType : int8_t {
Short = 2, // int16_t
Int = 3, // int32_t
Long = 4, // int64_t
// Half = 5, // float16 - not currently needed
Half = 5, // float16
Float = 6, // float
// Double = 7, // double - not currently needed
// ComplexHalf = 8,
@@ -48,6 +48,7 @@ constexpr ScalarType kChar = ScalarType::Char;
constexpr ScalarType kShort = ScalarType::Short;
constexpr ScalarType kInt = ScalarType::Int;
constexpr ScalarType kLong = ScalarType::Long;
constexpr ScalarType kHalf = ScalarType::Half;
constexpr ScalarType kFloat = ScalarType::Float;
constexpr ScalarType kBool = ScalarType::Bool;
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;
@@ -67,6 +68,8 @@ inline size_t elementSize(ScalarType t) {
return sizeof(int32_t);
case ScalarType::Long:
return sizeof(int64_t);
case ScalarType::Half:
return 2; // sizeof(__half) = 2 bytes
case ScalarType::Float:
return sizeof(float);
case ScalarType::Bool:
@@ -93,6 +96,8 @@ inline const char* toString(ScalarType t) {
return "Int";
case ScalarType::Long:
return "Long";
case ScalarType::Half:
return "Half";
case ScalarType::Float:
return "Float";
case ScalarType::Bool:
@@ -110,7 +115,8 @@ inline const char* toString(ScalarType t) {
/// @param t The scalar type to check.
/// @return true if the scalar type is floating point, false otherwise.
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Float || t == ScalarType::BFloat16;
return t == ScalarType::Half || t == ScalarType::Float ||
t == ScalarType::BFloat16;
}

/// Checks if the scalar type is an integral type (including bool optionally).
@@ -149,6 +155,7 @@ inline bool isValidScalarType(ScalarType t) {
case ScalarType::Short:
case ScalarType::Int:
case ScalarType::Long:
case ScalarType::Half:
case ScalarType::Float:
case ScalarType::Bool:
case ScalarType::BFloat16:
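Note: the enum codes above (Half = 5, Float = 6, Bool = 11, BFloat16 = 15) mirror PyTorch's c10::ScalarType numbering, and the hard-coded 2-byte size matches torch.float16. A quick Python cross-check (illustrative only, not part of this diff):

import torch

# float16 is a 16-bit (2-byte) type, matching elementSize(ScalarType::Half) == 2
assert torch.finfo(torch.float16).bits == 16
assert torch.tensor([], dtype=torch.float16).element_size() == 2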
35 changes: 35 additions & 0 deletions backends/aoti/slim/c10/core/test/test_scalar_type.cpp
@@ -36,6 +36,7 @@ const std::vector<ScalarTypeTestData> kAllScalarTypes = {
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
{ScalarType::Half, 5, 2, "Half", true, false, false, false},
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
@@ -128,6 +129,10 @@ TEST_F(ScalarTypeConstantsTest, KLongConstant) {
EXPECT_EQ(kLong, ScalarType::Long);
}

TEST_F(ScalarTypeConstantsTest, KHalfConstant) {
EXPECT_EQ(kHalf, ScalarType::Half);
}

TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
EXPECT_EQ(kFloat, ScalarType::Float);
}
@@ -185,6 +190,10 @@ TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
}

TEST_F(ElementSizeConsistencyTest, HalfIs2Bytes) {
EXPECT_EQ(elementSize(ScalarType::Half), 2);
}

TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
}
@@ -196,3 +205,29 @@ TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
}

// =============================================================================
// isValidScalarType Tests
// =============================================================================

class IsValidScalarTypeTest : public ::testing::Test {};

TEST_F(IsValidScalarTypeTest, HalfIsValid) {
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
}

TEST_F(IsValidScalarTypeTest, AllSupportedTypesAreValid) {
EXPECT_TRUE(isValidScalarType(ScalarType::Byte));
EXPECT_TRUE(isValidScalarType(ScalarType::Char));
EXPECT_TRUE(isValidScalarType(ScalarType::Short));
EXPECT_TRUE(isValidScalarType(ScalarType::Int));
EXPECT_TRUE(isValidScalarType(ScalarType::Long));
EXPECT_TRUE(isValidScalarType(ScalarType::Half));
EXPECT_TRUE(isValidScalarType(ScalarType::Float));
EXPECT_TRUE(isValidScalarType(ScalarType::Bool));
EXPECT_TRUE(isValidScalarType(ScalarType::BFloat16));
}

TEST_F(IsValidScalarTypeTest, UndefinedIsNotValid) {
EXPECT_FALSE(isValidScalarType(ScalarType::Undefined));
}
68 changes: 51 additions & 17 deletions backends/arm/_passes/fuse_constant_ops_pass.py
@@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type
from collections.abc import Mapping
from typing import Sequence, Set, Type

import torch._export.utils
import torch.fx
Expand All @@ -18,6 +19,7 @@
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
FuseEqualPlaceholdersPass,
)
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import (
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.exported_program = exported_program

@staticmethod
def _is_tosa_dialect_op(target) -> bool:
target_str = str(target)
return (
"executorch.exir.dialects.backend._ops.tosa." in target_str
or "<EdgeOpOverload: tosa." in target_str
)

@staticmethod
def _arg_contains_symbolic_shape(arg) -> bool:
if isinstance(arg, torch.fx.Node):
if meta_has_shape_mark(arg.meta):
return True
return FuseConstantArgsPass._arg_contains_symbolic_shape(
arg.meta.get("val")
)
if isinstance(arg, torch.SymInt):
return True
if isinstance(arg, Mapping):
return any(
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
for k, v in arg.items()
)
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
return any(
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
)
return False

def _propagate_special_dtype(self, from_nodes, to_node, data):
"""Propagate special dtype meta if it exists."""
special_dtypes = set()
@@ -83,21 +115,24 @@ def _fuse_nodes(self, node) -> bool:
input_nodes = list(node.all_input_nodes)
qparams = node.meta.get("input_qparams", None)

def resolve_arg(arg):
def resolve_arg(arg, arg_index=None):
qparam = (
qparams.get(arg_index) if qparams and arg_index is not None else None
)
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
idx = input_nodes.index(arg)
t = get_param_tensor(self.exported_program, arg)
# Check if qparams exist for this arg
if qparams and idx in qparams.keys():
t = qparams[idx].dequantize_value(t)
if qparam is not None:
t = qparam.dequantize_value(t)
return t
if isinstance(arg, tuple):
return tuple(resolve_arg(x) for x in arg)
return tuple(resolve_arg(x, arg_index) for x in arg)
if isinstance(arg, list):
return [resolve_arg(x) for x in arg]
return [resolve_arg(x, arg_index) for x in arg]
return arg

new_args = tuple(resolve_arg(a) for a in node.args)
new_args = tuple(
resolve_arg(arg, arg_index) for arg_index, arg in enumerate(node.args)
)
new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}

data = node.target(*new_args, **new_kwargs)
@@ -139,13 +174,13 @@ def call(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target in [
exir_ops.backend.tosa.MATMUL.default,
exir_ops.backend.tosa.RESCALE.default,
exir_ops.backend.tosa.RESIZE.default,
exir_ops.backend.tosa.TABLE.default,
exir_ops.backend.tosa.TRANSPOSE.default,
]:
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
if (
self._is_tosa_dialect_op(node.target)
or self._arg_contains_symbolic_shape(node.args)
or self._arg_contains_symbolic_shape(node.kwargs)
):
continue

input_nodes = node.all_input_nodes
Expand All @@ -161,7 +196,6 @@ def call(self, graph_module):
)
if not all(input_nodes_constant):
continue

try:
did_fuse = self._fuse_nodes(node)
if did_fuse:
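The new _arg_contains_symbolic_shape guard recurses through arbitrarily nested args and kwargs so that shape-dependent ops are skipped by constant folding. A simplified stand-alone sketch of the recursion (it drops the torch.fx.Node / meta handling of the real pass):

import torch

def contains_symint(arg) -> bool:
    # Simplified mirror of _arg_contains_symbolic_shape: walk mappings and
    # sequences recursively and report any torch.SymInt leaf.
    if isinstance(arg, torch.SymInt):
        return True
    if isinstance(arg, dict):
        return any(contains_symint(k) or contains_symint(v) for k, v in arg.items())
    if isinstance(arg, (list, tuple)):
        return any(contains_symint(v) for v in arg)
    return False

# Purely concrete args are safe to fold.
assert contains_symint([1, (2, {"pad": 3})]) is False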
15 changes: 12 additions & 3 deletions backends/arm/_passes/rewrite_conv_pass.py
@@ -21,6 +21,9 @@
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm._passes.symbolic_value_range import (
evaluate_symbolic_expr_values,
)
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.arm.tosa.specification import get_context_shape_env
@@ -83,16 +86,22 @@ def _adjust_pad_if_needed(

if isinstance(mod_remainder, torch.SymInt):
shape_env = get_context_shape_env()
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
mod_remainder_upper = int(value_ranges.upper)
exact_values = evaluate_symbolic_expr_values(
mod_remainder.node.expr, shape_env
)
if exact_values is not None:
mod_remainder_upper = max(exact_values)
else:
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
mod_remainder_upper = int(value_ranges.upper)
if mod_remainder_upper == 0:
mod_remainder = 0
else:
mod_remainder_upper = mod_remainder

if mod_remainder_upper > pad:
raise RuntimeError(
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
)
return pad - mod_remainder

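The idea behind the new branch: when evaluate_symbolic_expr_values can enumerate the exact values the remainder expression may take, use their maximum; otherwise fall back to the looser upper bound from shape_env.bound_sympy. A minimal sketch of that selection, with enumerate_values as a stand-in for the real helper:

def symbolic_upper(expr, shape_env, enumerate_values):
    # Prefer exact candidate values when the expression can be enumerated...
    exact_values = enumerate_values(expr, shape_env)
    if exact_values is not None:
        return max(exact_values)
    # ...otherwise fall back to the symbolic range's upper bound.
    return int(shape_env.bound_sympy(expr).upper)

# With an enumerator that knows the remainder is only ever 0 or 1:
assert symbolic_upper("expr", None, lambda e, env: {0, 1}) == 1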
7 changes: 7 additions & 0 deletions backends/arm/_passes/size_adjust_input_pass.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import cast, Sequence, Set, Type, TypeAlias

import torch
import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
Expand All @@ -12,6 +13,9 @@
)
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
from executorch.backends.arm._passes.symbolic_value_range import (
evaluate_symbolic_expr_values,
)
from executorch.backends.arm.tosa.specification import get_context_shape_env
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
"""Returns whether an int or SymInt is greater than another value."""
if isinstance(input, torch.SymInt):
shape_env = get_context_shape_env()
exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env)
if exact_values is not None:
return max(exact_values) > other
value_ranges = shape_env.bound_sympy(input.node.expr)
return value_ranges.upper > other
else:
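_greater_than follows the same pattern, and its fallback to value_ranges.upper is deliberately conservative: it answers "could this be greater?", which is the safe direction when deciding whether the input size needs adjusting. A toy illustration under assumed bounds:

def greater_than(exact_values, range_upper, other):
    # Mirrors the symbolic branch of _greater_than.
    if exact_values is not None:
        return max(exact_values) > other
    return range_upper > other

assert greater_than({2, 4, 6}, None, 5)  # exact values known: 6 > 5
assert greater_than(None, 8, 5)          # fallback: upper bound 8 > 5, even if the runtime value is smaller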