Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions BUILDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Follow these steps:
1. Fork and clone the GitHub [coremltools repository](https://github.com/apple/coremltools).

2. Run the [build.sh](scripts/build.sh) script to build `coremltools`.
* By default this script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`) as a argument to change the Python version.
* By default this script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`, `3.13`) as a argument to change the Python version.
* The script creates a new `build` folder with the coremltools distribution, and a `dist` folder with Python wheel files.

3. Run the [test.sh](scripts/test.sh) script to test the build.
Expand All @@ -45,7 +45,7 @@ The following build targets help you configure the development environment. If y
* `test_slow` | Run all non-fast tests.
* `wheel` | Build wheels in release mode.

The script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`) as a argument to change the Python version.
The script uses Python 3.7, but you can include `--python=3.8` (or `3.9`, `3.10`, `3.11`, `3.12`, `3.13`) as a argument to change the Python version.

## Resources

Expand Down
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ add_library(modelpackage
modelpackage/src/utils/JsonMap.cpp
modelpackage/src/ModelPackagePython.cpp
)

target_compile_definitions(modelpackage
PRIVATE
CPU_ONLY=1
Expand Down Expand Up @@ -197,8 +197,13 @@ set(KMEANS_DIR "${PROJECT_SOURCE_DIR}/deps/kmeans1d")
execute_process(
COMMAND python3 setup.py build_ext --inplace
WORKING_DIRECTORY ${KMEANS_DIR}
RESULT_VARIABLE KMEANS1D_BUILD_STATUS
)

if(NOT KMEANS1D_BUILD_STATUS EQUAL 0)
message(FATAL_ERROR "Could not build kmeans1d dependency")
endif()

# Somehow Python's setuptools is building this shared object file so that it tries to load the C++
# standard library using an rpath that only exist on the build machine. Change that so it gets
# loaded from the standard location.
Expand Down
7 changes: 6 additions & 1 deletion coremltools/converters/mil/frontend/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@ def pymil_broadcast_to(tensor: Var, shape: Union[Var, VARIABLE_SHAPE_TYPE], name

if any_symbolic(tensor.shape) or shape_var.val is None:
tensor_shape = mb.shape(x=tensor)
reps = mb.real_div(x=shape_var, y=tensor_shape)
reps = mb.select(
cond = mb.equal(x=shape_var, y=-1),
a = tensor_shape,
b = shape_var,
)
reps = mb.real_div(x=reps, y=tensor_shape)
reps = mb.cast(x=reps, dtype="int32")
res = mb.tile(x=tensor, reps=reps, name=name)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,67 @@ def _torch_upsample_to_core_upsample_block(block):

if op.op_type in target_ops:
if _try_replace_with_core_upsample(op):
logger.info("Successfully map {} to core upsample".format(op.op_type))
msg = f"Successfully map {op.op_type} to core upsample"
logger.info(msg)
else:
raise ValueError("Unable to map {} to core upsample".format(op.op_type))


def _try_get_upsample_factor(output_size):
def _try_get_upsample_factor_pattern_2(output_size, expected_gather_indices, target_op):
"""
This is the pattern corresponds to the python source code:

class UpsampleBilinear(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
return F.interpolate(x, size=(h*2, w*2), mode='bilinear', align_corners=False)

The resulting pymil program is:

function[CoreML5](%x: (1, 3, is0, is1, fp32)(Tensor)) {
block0() {
%3_shape: (4,int32)^(Tensor) = shape(x=%x, name="3_shape")
%gather_0: (int32)^(Scalar) = gather(x=%3_shape, indices=2, axis=0, name="gather_0")
%6_shape: (4,int32)^(Tensor) = shape(x=%x, name="6_shape")
%gather_1: (int32)^(Scalar) = gather(x=%6_shape, indices=3, axis=0, name="gather_1")
%9: (int32)(Scalar) = mul(x=%gather_0, y=2, name="9")
%10: (int32)(Scalar) = cast(x=%9, dtype="int32", name="10")
%12: (int32)(Scalar) = mul(x=%gather_1, y=2, name="12")
%13: (int32)(Scalar) = cast(x=%12, dtype="int32", name="13")
%17: (1, 3, is2, is3, fp32)(Tensor) = torch_upsample_bilinear(x=%x, output_height=%10, output_width=%13, align_corners=False, name="17")

We do a pattern matching to extract the scale value.
"""
# cast op
op = output_size
if op.op_type != "cast" or op.dtype.val != "int32":
return None

# mul op
mul_op = op.x.op
if mul_op.op_type != "mul":
return None
mul_op_y = mul_op.y

# gather op
gather_op = mul_op.x.op
if gather_op.op_type != "gather":
return None
if gather_op.indices.val != expected_gather_indices:
return None
if gather_op.axis.val != 0:
return None

# shape op
shape_op = gather_op.x.op
if shape_op.op_type != "shape":
return None
if shape_op.x != target_op:
return None

return mul_op_y.val


def _try_get_upsample_factor_pattern_1(output_size):
op = output_size
# If the output has value, then the upsample op itself is derived from the upsample_1d op,
# so we can just return scale factor 1 for that case
Expand Down Expand Up @@ -103,9 +158,15 @@ def _try_replace_with_core_upsample(op):
assert op.op_type in target_ops

# 2d upsampling
if op.op_type in ["torch_upsample_nearest_neighbor", "torch_upsample_bilinear"]:
scales_h = _try_get_upsample_factor(op.output_height.op)
scales_w = _try_get_upsample_factor(op.output_width.op)
if op.op_type in target_ops:

# try to resolve the scaling factor - pattern 1
scales_h = _try_get_upsample_factor_pattern_1(op.output_height.op)
scales_w = _try_get_upsample_factor_pattern_1(op.output_width.op)

if scales_h is None or scales_w is None:
scales_h = _try_get_upsample_factor_pattern_2(op.output_height.op, 2, op.x)
scales_w = _try_get_upsample_factor_pattern_2(op.output_width.op, 3, op.x)

if scales_h is None or scales_w is None:
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def forward(self, x):
assert mlmodel.user_defined_metadata[_METADATA_SOURCE_DIALECT] == dialect_name


@pytest.mark.skipif((version_info.major, version_info.minor) == (3, 13), reason="rdar://158079341")
class TestExecuTorchExamples(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, dynamic",
Expand Down
70 changes: 70 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2668,6 +2668,41 @@ def forward(self, x):
compute_unit=compute_unit,
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend",
itertools.product(compute_units, backends, frontends),
)
def test_upsample_with_shape_gather_pattern(self, compute_unit, backend, frontend):
if frontend == TorchFrontend.TORCHEXPORT:
pytest.xfail("CoreML model not runnable for the torch export frontend.")

input_shape = (1, 3, 32, 32)

class UpsampleBilinear(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
return nn.functional.interpolate(
x, size=(h * 2, w * 2), mode="bilinear", align_corners=False
)

model = UpsampleBilinear().eval()

h_dim = torch.export.Dim(name="height", min=16, max=128)
w_dim = torch.export.Dim(name="width", min=16, max=128)
torch_export_dynamic_shapes = {"x": {2: h_dim, 3: w_dim}}

self.run_compare_torch(
input_shape,
model,
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
converter_input_type=[
TensorType(shape=(1, 3, ct.RangeDim(16, 128), ct.RangeDim(16, 128))),
],
torch_export_dynamic_shapes=torch_export_dynamic_shapes,
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, output_size",
itertools.product(compute_units, backends, frontends, [10, 170]),
Expand Down Expand Up @@ -5234,7 +5269,42 @@ def forward(self, x, y):
input_shapes, model, compute_unit=compute_unit, backend=backend, frontend=frontend
)

@pytest.mark.parametrize(
"compute_unit, backend, frontend, input_shape",
itertools.product(
compute_units,
backends,
frontends,
[
(ct.RangeDim(3, 21), ),
(15, )
]
),
)
def test_expand_dynamic_shape4(self, compute_unit, backend, frontend, input_shape):
if frontend in TORCH_EXPORT_BASED_FRONTENDS:
pytest.xfail(
"torch.export refuses to make size-1 dim dynamic, "
"and cannot expand one dynamic dimension into another dynamic dimension"
)

class TestModel(nn.Module):
def forward(self, x):
return x.reshape(-1, 1, 3).expand(-1, 7, -1)

converter_input_type = [ct.TensorType(name = "x", shape=input_shape, dtype=types.fp32)]
model = TestModel()

self.run_compare_torch(
torch.rand(15),
model,
input_as_shape=False,
converter_input_type=converter_input_type,
frontend=frontend,
backend=backend,
compute_unit=compute_unit,
)

class TestExpandDims(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, rank_and_axis",
Expand Down
6 changes: 5 additions & 1 deletion coremltools/test/ml_program/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform
import shutil
import tempfile
from sys import version_info
from typing import Dict, Tuple

import numpy as np
Expand Down Expand Up @@ -1496,8 +1497,11 @@ def validate_inference(multifunction_mlpackage_path: str) -> None:
shutil.rmtree(multifunction_mlpackage_path)


@pytest.mark.skipif(
(version_info.major, version_info.minor) == (3, 13),
reason="rdar://157488825 (Python 3.13 Unit Test Segmentation Fault)",
)
class TestBisectModel:

@staticmethod
def check_spec_op_type(model_path, expected_ops):
spec = load_spec(model_path)
Expand Down
2 changes: 1 addition & 1 deletion coremltools/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


__version__ = "9.0b1" # VERSION_STRING
__version__ = "9.0" # VERSION_STRING
14 changes: 7 additions & 7 deletions docs-guides/source/installing-coremltools.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This page describes how to install the [`coremltools`](https://github.com/apple/

```{admonition} Supported Python and MacOS Versions

The current version of coremltools ([version 8.0](https://github.com/apple/coremltools)) includes wheels for Python 3.7, 3.8, 3.9, 3.10, 3.11, and 3.12. The last stable release of coremltools to support Python 2 is version 4.0.
The current version of coremltools ([version 9.0b1](https://github.com/apple/coremltools)) includes wheels for Python 3.7, 3.8, 3.9, 3.10, 3.11, 3.12, 3.13. The last stable release of coremltools to support Python 2 is version 4.0.

The supported MacOS versions are as follows:

Expand All @@ -19,7 +19,7 @@ The supported MacOS versions are as follows:
If you are using macOS, you should already be familiar with the [Mac Terminal app command line](https://developer.apple.com/library/archive/documentation/OpenSource/Conceptual/ShellScripting/CommandLInePrimer/CommandLine.html#//apple_ref/doc/uid/TP40004268-CH271-BBCBEAJD "Command Line Primer") to perform tasks such as installations and updates. If you are using Linux, you should already be familiar with [basic Shell commands in Linux](https://www.geeksforgeeks.org/basic-shell-commands-in-linux/).
```

Before installing coremltools, you need [Python](https://www.python.org/downloads/ "Python Downloads") and the [`pip`](https://pip.pypa.io/en/stable/) installer.
Before installing coremltools, you need [Python](https://www.python.org/downloads/ "Python Downloads") and the [`pip`](https://pip.pypa.io/en/stable/) installer.

The `coremltools` package supports [Python 3](https://www.python.org/download/releases/3.0/). We recommend that you install Python 3.6 or newer. Use a Python package manager such as [Conda](https://docs.conda.io/en/latest/index.html) or [venv](https://docs.python.org/3/library/venv.html) to install the newest version of Python and other dependencies. [Conda](https://docs.conda.io/en/latest/index.html) is recommended because it is the most reliable way to install all required dependencies.

Expand Down Expand Up @@ -82,7 +82,7 @@ python -m venv coremltools-venv
source coremltools-venv/bin/activate
```

4. Follow the instructions in [Install Core ML Tools](#install-core-ml-tools).
4. Follow the instructions in [Install Core ML Tools](#install-core-ml-tools).

## Install Core ML Tools

Expand All @@ -103,17 +103,17 @@ The continuous integration (CI) system linked to the `coremltools` repo builds a
To access the wheel for a particular `coremltools` release, follow these steps:

1. Go to the [`coremltools` repository](https://github.com/apple/coremltools) on GitHub, scroll down to the **README.md** heading, and click the **build passing** button. The **Branches** tab appears:

![Branches tab](images/repo-readme-build-passing-button-annot.png)

![Branches passed](images/repo-branches-passed-button.png)

2. Click the **passed** button to show the **Pipeline** tab:

![Pipeline tab](images/repo-build-wheel-selected.png)

3. Click a wheel in the **Build** column. For example, in the previous figure, the **build_wheel_macos_py38** wheel is highlighted for clicking. After clicking a wheel, the raw job log appears, with the **Download** and **Browse** buttons in the right column:

![Download and Browse](images/repo-job-artifacts.png)

4. Click the **Download** button to download the `dist` folder with the wheel files.
Expand Down
5 changes: 4 additions & 1 deletion reqs/build.pip
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
numpy==1.21.0; platform_machine == "arm64" and python_version < "3.9"
numpy<1.20; platform_machine != "arm64" and python_version < "3.9"
numpy==2.0.0; python_version >= "3.9"
numpy==2.0.0; python_version >= "3.9" and python_version < "3.13"
numpy==2.1.0; python_version >= "3.13"

protobuf
pytest
setuptools; python_version >= "3.13"
six
sympy
tqdm
Expand Down
4 changes: 2 additions & 2 deletions reqs/common_test_packages.pip
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
boto3==1.14.8; python_version < '3.12'
boto3==1.24.22; python_version >= '3.12'
boto3==1.24.22; python_version == '3.12'
boto3==1.39.3; python_version >= '3.13'

configparser
future

olefile==0.44
pandas
parameterized==0.8.1
protobuf
pillow
pytest==7.1.2
pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion reqs/pytorch.pip
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ torchaudio==2.2.0; platform_machine != "arm64"
torchvision==0.17.0; platform_machine != "arm64"

# Torch dependencies for ARM
torch==2.7.0; platform_machine == "arm64"
torch>=2.2.0,<=2.7.0; platform_machine == "arm64"
torchaudio>=2.2.0; platform_machine == "arm64"
torchvision>=0.17.0; platform_machine == "arm64"
torchsr==1.0.4; platform_machine == "arm64"
Expand Down
2 changes: 1 addition & 1 deletion reqs/test_executorch.pip
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

# Warning: Starting from ExecuTorch 0.6.0, coremltools is added as a dependency
# so we need to re-install built-from-source coremltools after pip install ExecuTorch
executorch>=0.6.0; platform_machine == "arm64" and python_version >= '3.10'
executorch>=0.6.0; platform_machine == "arm64" and python_version >= '3.10' and python_version < '3.13'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering",
"Topic :: Software Development",
],
Expand Down