diff --git a/BUILDING.md b/BUILDING.md index 7c3177866..b188ca6c4 100644 --- a/BUILDING.md +++ b/BUILDING.md @@ -19,7 +19,7 @@ Follow these steps: 1. Fork and clone the GitHub [coremltools repository](https://github.com/apple/coremltools). 2. Run the [build.sh](scripts/build.sh) script to build `coremltools`. - * By default this script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`) as a argument to change the Python version. + * By default this script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`, `3.13`) as a argument to change the Python version. * The script creates a new `build` folder with the coremltools distribution, and a `dist` folder with Python wheel files. 3. Run the [test.sh](scripts/test.sh) script to test the build. @@ -45,7 +45,7 @@ The following build targets help you configure the development environment. If y * `test_slow` | Run all non-fast tests. * `wheel` | Build wheels in release mode. -The script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`) as a argument to change the Python version. +The script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`, `3.13`) as a argument to change the Python version. ## Resources diff --git a/CMakeLists.txt b/CMakeLists.txt index d4625252a..5777bac01 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,7 +70,7 @@ add_library(modelpackage modelpackage/src/utils/JsonMap.cpp modelpackage/src/ModelPackagePython.cpp ) - + target_compile_definitions(modelpackage PRIVATE CPU_ONLY=1 @@ -197,8 +197,13 @@ set(KMEANS_DIR "${PROJECT_SOURCE_DIR}/deps/kmeans1d") execute_process( COMMAND python3 setup.py build_ext --inplace WORKING_DIRECTORY ${KMEANS_DIR} + RESULT_VARIABLE KMEANS1D_BUILD_STATUS ) +if(NOT KMEANS1D_BUILD_STATUS EQUAL 0) + message(FATAL_ERROR "Could not build kmeans1d dependency") +endif() + # Somehow Python's setuptools is building this shared object file so that it tries to load the C++ # standard library using an rpath that only exist on the build machine. Change that so it gets # loaded from the standard location. diff --git a/coremltools/converters/mil/frontend/_utils.py b/coremltools/converters/mil/frontend/_utils.py index 4eb481ef4..a4a764cec 100644 --- a/coremltools/converters/mil/frontend/_utils.py +++ b/coremltools/converters/mil/frontend/_utils.py @@ -126,7 +126,12 @@ def pymil_broadcast_to(tensor: Var, shape: Union[Var, VARIABLE_SHAPE_TYPE], name if any_symbolic(tensor.shape) or shape_var.val is None: tensor_shape = mb.shape(x=tensor) - reps = mb.real_div(x=shape_var, y=tensor_shape) + reps = mb.select( + cond = mb.equal(x=shape_var, y=-1), + a = tensor_shape, + b = shape_var, + ) + reps = mb.real_div(x=reps, y=tensor_shape) reps = mb.cast(x=reps, dtype="int32") res = mb.tile(x=tensor, reps=reps, name=name) else: diff --git a/coremltools/converters/mil/frontend/torch/ssa_passes/torch_upsample_to_core_upsample.py b/coremltools/converters/mil/frontend/torch/ssa_passes/torch_upsample_to_core_upsample.py index 9344c38f2..9cac8ba27 100644 --- a/coremltools/converters/mil/frontend/torch/ssa_passes/torch_upsample_to_core_upsample.py +++ b/coremltools/converters/mil/frontend/torch/ssa_passes/torch_upsample_to_core_upsample.py @@ -42,12 +42,67 @@ def _torch_upsample_to_core_upsample_block(block): if op.op_type in target_ops: if _try_replace_with_core_upsample(op): - logger.info("Successfully map {} to core upsample".format(op.op_type)) + msg = f"Successfully map {op.op_type} to core upsample" + logger.info(msg) else: raise ValueError("Unable to map {} to core upsample".format(op.op_type)) - -def _try_get_upsample_factor(output_size): +def _try_get_upsample_factor_pattern_2(output_size, expected_gather_indices, target_op): + """ + This is the pattern corresponds to the python source code: + + class UpsampleBilinear(nn.Module): + def forward(self, x): + b, c, h, w = x.shape + return F.interpolate(x, size=(h*2, w*2), mode='bilinear', align_corners=False) + + The resulting pymil program is: + + function[CoreML5](%x: (1, 3, is0, is1, fp32)(Tensor)) { + block0() { + %3_shape: (4,int32)^(Tensor) = shape(x=%x, name="3_shape") + %gather_0: (int32)^(Scalar) = gather(x=%3_shape, indices=2, axis=0, name="gather_0") + %6_shape: (4,int32)^(Tensor) = shape(x=%x, name="6_shape") + %gather_1: (int32)^(Scalar) = gather(x=%6_shape, indices=3, axis=0, name="gather_1") + %9: (int32)(Scalar) = mul(x=%gather_0, y=2, name="9") + %10: (int32)(Scalar) = cast(x=%9, dtype="int32", name="10") + %12: (int32)(Scalar) = mul(x=%gather_1, y=2, name="12") + %13: (int32)(Scalar) = cast(x=%12, dtype="int32", name="13") + %17: (1, 3, is2, is3, fp32)(Tensor) = torch_upsample_bilinear(x=%x, output_height=%10, output_width=%13, align_corners=False, name="17") + + We do a pattern matching to extract the scale value. + """ + # cast op + op = output_size + if op.op_type != "cast" or op.dtype.val != "int32": + return None + + # mul op + mul_op = op.x.op + if mul_op.op_type != "mul": + return None + mul_op_y = mul_op.y + + # gather op + gather_op = mul_op.x.op + if gather_op.op_type != "gather": + return None + if gather_op.indices.val != expected_gather_indices: + return None + if gather_op.axis.val != 0: + return None + + # shape op + shape_op = gather_op.x.op + if shape_op.op_type != "shape": + return None + if shape_op.x != target_op: + return None + + return mul_op_y.val + + +def _try_get_upsample_factor_pattern_1(output_size): op = output_size # If the output has value, then the upsample op itself is derived from the upsample_1d op, # so we can just return scale factor 1 for that case @@ -103,9 +158,15 @@ def _try_replace_with_core_upsample(op): assert op.op_type in target_ops # 2d upsampling - if op.op_type in ["torch_upsample_nearest_neighbor", "torch_upsample_bilinear"]: - scales_h = _try_get_upsample_factor(op.output_height.op) - scales_w = _try_get_upsample_factor(op.output_width.op) + if op.op_type in target_ops: + + # try to resolve the scaling factor - pattern 1 + scales_h = _try_get_upsample_factor_pattern_1(op.output_height.op) + scales_w = _try_get_upsample_factor_pattern_1(op.output_width.op) + + if scales_h is None or scales_w is None: + scales_h = _try_get_upsample_factor_pattern_2(op.output_height.op, 2, op.x) + scales_w = _try_get_upsample_factor_pattern_2(op.output_width.op, 3, op.x) if scales_h is None or scales_w is None: return False diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py b/coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py index 2242885b8..324cc7cfb 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_export_conversion_api.py @@ -334,6 +334,7 @@ def forward(self, x): assert mlmodel.user_defined_metadata[_METADATA_SOURCE_DIALECT] == dialect_name +@pytest.mark.skipif((version_info.major, version_info.minor) == (3, 13), reason="rdar://158079341") class TestExecuTorchExamples(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, frontend, dynamic", diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index 3e40351b9..fa3c398a6 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -2668,6 +2668,41 @@ def forward(self, x): compute_unit=compute_unit, ) + @pytest.mark.parametrize( + "compute_unit, backend, frontend", + itertools.product(compute_units, backends, frontends), + ) + def test_upsample_with_shape_gather_pattern(self, compute_unit, backend, frontend): + if frontend == TorchFrontend.TORCHEXPORT: + pytest.xfail("CoreML model not runnable for the torch export frontend.") + + input_shape = (1, 3, 32, 32) + + class UpsampleBilinear(nn.Module): + def forward(self, x): + b, c, h, w = x.shape + return nn.functional.interpolate( + x, size=(h * 2, w * 2), mode="bilinear", align_corners=False + ) + + model = UpsampleBilinear().eval() + + h_dim = torch.export.Dim(name="height", min=16, max=128) + w_dim = torch.export.Dim(name="width", min=16, max=128) + torch_export_dynamic_shapes = {"x": {2: h_dim, 3: w_dim}} + + self.run_compare_torch( + input_shape, + model, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + converter_input_type=[ + TensorType(shape=(1, 3, ct.RangeDim(16, 128), ct.RangeDim(16, 128))), + ], + torch_export_dynamic_shapes=torch_export_dynamic_shapes, + ) + @pytest.mark.parametrize( "compute_unit, backend, frontend, output_size", itertools.product(compute_units, backends, frontends, [10, 170]), @@ -5234,7 +5269,42 @@ def forward(self, x, y): input_shapes, model, compute_unit=compute_unit, backend=backend, frontend=frontend ) + @pytest.mark.parametrize( + "compute_unit, backend, frontend, input_shape", + itertools.product( + compute_units, + backends, + frontends, + [ + (ct.RangeDim(3, 21), ), + (15, ) + ] + ), + ) + def test_expand_dynamic_shape4(self, compute_unit, backend, frontend, input_shape): + if frontend in TORCH_EXPORT_BASED_FRONTENDS: + pytest.xfail( + "torch.export refuses to make size-1 dim dynamic, " + "and cannot expand one dynamic dimension into another dynamic dimension" + ) + + class TestModel(nn.Module): + def forward(self, x): + return x.reshape(-1, 1, 3).expand(-1, 7, -1) + + converter_input_type = [ct.TensorType(name = "x", shape=input_shape, dtype=types.fp32)] + model = TestModel() + self.run_compare_torch( + torch.rand(15), + model, + input_as_shape=False, + converter_input_type=converter_input_type, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + ) + class TestExpandDims(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, frontend, rank_and_axis", diff --git a/coremltools/test/ml_program/test_utils.py b/coremltools/test/ml_program/test_utils.py index 1cbfde3b8..a0878d602 100644 --- a/coremltools/test/ml_program/test_utils.py +++ b/coremltools/test/ml_program/test_utils.py @@ -9,6 +9,7 @@ import platform import shutil import tempfile +from sys import version_info from typing import Dict, Tuple import numpy as np @@ -1496,8 +1497,11 @@ def validate_inference(multifunction_mlpackage_path: str) -> None: shutil.rmtree(multifunction_mlpackage_path) +@pytest.mark.skipif( + (version_info.major, version_info.minor) == (3, 13), + reason="rdar://157488825 (Python 3.13 Unit Test Segmentation Fault)", +) class TestBisectModel: - @staticmethod def check_spec_op_type(model_path, expected_ops): spec = load_spec(model_path) diff --git a/coremltools/version.py b/coremltools/version.py index e480a3e21..c52268d7d 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "9.0b1" # VERSION_STRING +__version__ = "9.0" # VERSION_STRING diff --git a/docs-guides/source/installing-coremltools.md b/docs-guides/source/installing-coremltools.md index 82209a4ba..9dc569fe2 100644 --- a/docs-guides/source/installing-coremltools.md +++ b/docs-guides/source/installing-coremltools.md @@ -4,7 +4,7 @@ This page describes how to install the [`coremltools`](https://github.com/apple/ ```{admonition} Supported Python and MacOS Versions -The current version of coremltools ([version 8.0](https://github.com/apple/coremltools)) includes wheels for Python 3.7, 3.8, 3.9, 3.10, 3.11, and 3.12. The last stable release of coremltools to support Python 2 is version 4.0. +The current version of coremltools ([version 9.0b1](https://github.com/apple/coremltools)) includes wheels for Python 3.7, 3.8, 3.9, 3.10, 3.11, 3.12, 3.13. The last stable release of coremltools to support Python 2 is version 4.0. The supported MacOS versions are as follows: @@ -19,7 +19,7 @@ The supported MacOS versions are as follows: If you are using macOS, you should already be familiar with the [Mac Terminal app command line](https://developer.apple.com/library/archive/documentation/OpenSource/Conceptual/ShellScripting/CommandLInePrimer/CommandLine.html#//apple_ref/doc/uid/TP40004268-CH271-BBCBEAJD "Command Line Primer") to perform tasks such as installations and updates. If you are using Linux, you should already be familiar with [basic Shell commands in Linux](https://www.geeksforgeeks.org/basic-shell-commands-in-linux/). ``` -Before installing coremltools, you need [Python](https://www.python.org/downloads/ "Python Downloads") and the [`pip`](https://pip.pypa.io/en/stable/) installer. +Before installing coremltools, you need [Python](https://www.python.org/downloads/ "Python Downloads") and the [`pip`](https://pip.pypa.io/en/stable/) installer. The `coremltools` package supports [Python 3](https://www.python.org/download/releases/3.0/). We recommend that you install Python 3.6 or newer. Use a Python package manager such as [Conda](https://docs.conda.io/en/latest/index.html) or [venv](https://docs.python.org/3/library/venv.html) to install the newest version of Python and other dependencies. [Conda](https://docs.conda.io/en/latest/index.html) is recommended because it is the most reliable way to install all required dependencies. @@ -82,7 +82,7 @@ python -m venv coremltools-venv source coremltools-venv/bin/activate ``` -4. Follow the instructions in [Install Core ML Tools](#install-core-ml-tools). +4. Follow the instructions in [Install Core ML Tools](#install-core-ml-tools). ## Install Core ML Tools @@ -103,17 +103,17 @@ The continuous integration (CI) system linked to the `coremltools` repo builds a To access the wheel for a particular `coremltools` release, follow these steps: 1. Go to the [`coremltools` repository](https://github.com/apple/coremltools) on GitHub, scroll down to the **README.md** heading, and click the **build passing** button. The **Branches** tab appears: - + ![Branches tab](images/repo-readme-build-passing-button-annot.png) - + ![Branches passed](images/repo-branches-passed-button.png) 2. Click the **passed** button to show the **Pipeline** tab: - + ![Pipeline tab](images/repo-build-wheel-selected.png) 3. Click a wheel in the **Build** column. For example, in the previous figure, the **build_wheel_macos_py38** wheel is highlighted for clicking. After clicking a wheel, the raw job log appears, with the **Download** and **Browse** buttons in the right column: - + ![Download and Browse](images/repo-job-artifacts.png) 4. Click the **Download** button to download the `dist` folder with the wheel files. diff --git a/reqs/build.pip b/reqs/build.pip index 9e99109ac..4160eb2e5 100644 --- a/reqs/build.pip +++ b/reqs/build.pip @@ -1,8 +1,11 @@ numpy==1.21.0; platform_machine == "arm64" and python_version < "3.9" numpy<1.20; platform_machine != "arm64" and python_version < "3.9" -numpy==2.0.0; python_version >= "3.9" +numpy==2.0.0; python_version >= "3.9" and python_version < "3.13" +numpy==2.1.0; python_version >= "3.13" +protobuf pytest +setuptools; python_version >= "3.13" six sympy tqdm diff --git a/reqs/common_test_packages.pip b/reqs/common_test_packages.pip index fa0ee6887..c7209f85a 100644 --- a/reqs/common_test_packages.pip +++ b/reqs/common_test_packages.pip @@ -1,5 +1,6 @@ boto3==1.14.8; python_version < '3.12' -boto3==1.24.22; python_version >= '3.12' +boto3==1.24.22; python_version == '3.12' +boto3==1.39.3; python_version >= '3.13' configparser future @@ -7,7 +8,6 @@ future olefile==0.44 pandas parameterized==0.8.1 -protobuf pillow pytest==7.1.2 pytest-cov diff --git a/reqs/pytorch.pip b/reqs/pytorch.pip index a16bdf1ba..4b94e3a53 100644 --- a/reqs/pytorch.pip +++ b/reqs/pytorch.pip @@ -5,7 +5,7 @@ torchaudio==2.2.0; platform_machine != "arm64" torchvision==0.17.0; platform_machine != "arm64" # Torch dependencies for ARM -torch==2.7.0; platform_machine == "arm64" +torch>=2.2.0,<=2.7.0; platform_machine == "arm64" torchaudio>=2.2.0; platform_machine == "arm64" torchvision>=0.17.0; platform_machine == "arm64" torchsr==1.0.4; platform_machine == "arm64" diff --git a/reqs/test_executorch.pip b/reqs/test_executorch.pip index 90a9cb834..2bffd56a3 100644 --- a/reqs/test_executorch.pip +++ b/reqs/test_executorch.pip @@ -3,4 +3,4 @@ # Warning: Starting from ExecuTorch 0.6.0, coremltools is added as a dependency # so we need to re-install built-from-source coremltools after pip install ExecuTorch -executorch>=0.6.0; platform_machine == "arm64" and python_version >= '3.10' +executorch>=0.6.0; platform_machine == "arm64" and python_version >= '3.10' and python_version < '3.13' diff --git a/setup.py b/setup.py index c65f4f0e6..665d254a5 100755 --- a/setup.py +++ b/setup.py @@ -106,6 +106,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering", "Topic :: Software Development", ],